onnx, fix depl issue
This commit is contained in:
@@ -21,7 +21,7 @@ Status legend: [ ] todo, [~] in-progress, [x] done
|
|||||||
### Performance Optimization
|
### Performance Optimization
|
||||||
- [x] 09 — Image Caching & Lazy Loading → `09-image-caching.md`
|
- [x] 09 — Image Caching & Lazy Loading → `09-image-caching.md`
|
||||||
- [x] 10 — Memory Management & Leak Audit → `10-memory-leak-audit.md`
|
- [x] 10 — Memory Management & Leak Audit → `10-memory-leak-audit.md`
|
||||||
- [~] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md`
|
- [x] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md`
|
||||||
- [x] 12 — App Launch Time Optimization → `12-launch-time.md`
|
- [x] 12 — App Launch Time Optimization → `12-launch-time.md`
|
||||||
|
|
||||||
### Native Features
|
### Native Features
|
||||||
|
|||||||
1
web/.gitignore
vendored
1
web/.gitignore
vendored
@@ -5,6 +5,7 @@ dist
|
|||||||
.netlify
|
.netlify
|
||||||
.vinxi
|
.vinxi
|
||||||
app.config.timestamp_*.js
|
app.config.timestamp_*.js
|
||||||
|
.pi-lens
|
||||||
|
|
||||||
# Environment
|
# Environment
|
||||||
.env*
|
.env*
|
||||||
|
|||||||
51
web/.vercelignore
Normal file
51
web/.vercelignore
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# ── ML Model (255MB ONNX model — too large for Vercel, downloaded at runtime) ──
|
||||||
|
src/server/models/spam-classifier/
|
||||||
|
|
||||||
|
# ── Build Artifacts ──
|
||||||
|
.output/
|
||||||
|
.nitro/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# ── Test Files (not needed in production) ──
|
||||||
|
e2e/
|
||||||
|
test/
|
||||||
|
**/*.test.ts
|
||||||
|
**/*.test.tsx
|
||||||
|
**/*.spec.ts
|
||||||
|
**/*.spec.tsx
|
||||||
|
|
||||||
|
# ── Development / Config ──
|
||||||
|
.dockerignore
|
||||||
|
Dockerfile
|
||||||
|
docker-compose.yml
|
||||||
|
docker-compose.yaml
|
||||||
|
vitest.config.ts
|
||||||
|
vitest.node.config.ts
|
||||||
|
playwright.config.ts
|
||||||
|
drizzle.config.ts
|
||||||
|
drizzle/
|
||||||
|
|
||||||
|
# ── Version Control ──
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.github/
|
||||||
|
.husky/
|
||||||
|
|
||||||
|
# ── Environment (already in .gitignore, being explicit) ──
|
||||||
|
.env
|
||||||
|
.env.development
|
||||||
|
.env.production
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# ── Editors / OS ──
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# ── Pi agent / dev tooling ──
|
||||||
|
.pi-lens/
|
||||||
|
.agents/
|
||||||
@@ -17,67 +17,67 @@ const __dirname = path.dirname(__filename);
|
|||||||
// ── Types ──────────────────────────────────────────────────────────────────
|
// ── Types ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
export interface TextClassification {
|
export interface TextClassification {
|
||||||
isSpam: boolean;
|
isSpam: boolean;
|
||||||
confidence: number;
|
confidence: number;
|
||||||
score: number;
|
score: number;
|
||||||
modelVersion?: string;
|
modelVersion?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ClassificationThresholds {
|
export interface ClassificationThresholds {
|
||||||
strict: number; // 0.3 - flag more aggressively
|
strict: number; // 0.3 - flag more aggressively
|
||||||
moderate: number; // 0.5 - balanced
|
moderate: number; // 0.5 - balanced
|
||||||
lenient: number; // 0.7 - fewer false positives
|
lenient: number; // 0.7 - fewer false positives
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ThresholdMode = "strict" | "moderate" | "lenient";
|
export type ThresholdMode = "strict" | "moderate" | "lenient";
|
||||||
|
|
||||||
const DEFAULT_THRESHOLDS: ClassificationThresholds = {
|
const DEFAULT_THRESHOLDS: ClassificationThresholds = {
|
||||||
strict: 0.3,
|
strict: 0.3,
|
||||||
moderate: 0.5,
|
moderate: 0.5,
|
||||||
lenient: 0.7,
|
lenient: 0.7,
|
||||||
};
|
};
|
||||||
|
|
||||||
// ── Model Singleton ────────────────────────────────────────────────────────
|
// ── Model Singleton ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
interface ModelState {
|
interface ModelState {
|
||||||
session: InferenceSession | null;
|
session: InferenceSession | null;
|
||||||
tokenizer: BertTokenizer;
|
tokenizer: BertTokenizer;
|
||||||
metadata: ModelMetadata;
|
metadata: ModelMetadata;
|
||||||
loaded: boolean;
|
loaded: boolean;
|
||||||
loadError: Error | null;
|
loadError: Error | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ModelMetadata {
|
interface ModelMetadata {
|
||||||
version: string;
|
version: string;
|
||||||
model_name: string;
|
model_name: string;
|
||||||
task: string;
|
task: string;
|
||||||
max_length: number;
|
max_length: number;
|
||||||
num_labels: number;
|
num_labels: number;
|
||||||
label2id: Record<string, number>;
|
label2id: Record<string, number>;
|
||||||
id2label: Record<number, string>;
|
id2label: Record<number, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelState: ModelState = {
|
const modelState: ModelState = {
|
||||||
session: null,
|
session: null,
|
||||||
tokenizer: null as unknown as BertTokenizer,
|
tokenizer: null as unknown as BertTokenizer,
|
||||||
metadata: {
|
metadata: {
|
||||||
version: "0.0.0",
|
version: "0.0.0",
|
||||||
model_name: "",
|
model_name: "",
|
||||||
task: "",
|
task: "",
|
||||||
max_length: 128,
|
max_length: 128,
|
||||||
num_labels: 2,
|
num_labels: 2,
|
||||||
label2id: {},
|
label2id: {},
|
||||||
id2label: {},
|
id2label: {},
|
||||||
},
|
},
|
||||||
loaded: false,
|
loaded: false,
|
||||||
loadError: null,
|
loadError: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
// ── Result Cache ───────────────────────────────────────────────────────────
|
// ── Result Cache ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
interface CacheEntry {
|
interface CacheEntry {
|
||||||
result: TextClassification;
|
result: TextClassification;
|
||||||
timestamp: number;
|
timestamp: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
const resultCache = new Map<string, CacheEntry>();
|
const resultCache = new Map<string, CacheEntry>();
|
||||||
@@ -85,305 +85,471 @@ const CACHE_MAX_SIZE = 1000;
|
|||||||
const CACHE_TTL_MS = 5 * 60 * 1000; // 5 minutes
|
const CACHE_TTL_MS = 5 * 60 * 1000; // 5 minutes
|
||||||
|
|
||||||
function cacheKey(text: string): string {
|
function cacheKey(text: string): string {
|
||||||
// Simple hash of normalized text
|
// Simple hash of normalized text
|
||||||
const normalized = text.toLowerCase().trim();
|
const normalized = text.toLowerCase().trim();
|
||||||
let hash = 0;
|
let hash = 0;
|
||||||
for (let i = 0; i < normalized.length; i++) {
|
for (let i = 0; i < normalized.length; i++) {
|
||||||
const char = normalized.charCodeAt(i);
|
const char = normalized.charCodeAt(i);
|
||||||
hash = ((hash << 5) - hash) + char;
|
hash = (hash << 5) - hash + char;
|
||||||
hash |= 0; // Convert to 32bit integer
|
hash |= 0; // Convert to 32bit integer
|
||||||
}
|
}
|
||||||
return String(hash);
|
return String(hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
function getCached(text: string): TextClassification | null {
|
function getCached(text: string): TextClassification | null {
|
||||||
const key = cacheKey(text);
|
const key = cacheKey(text);
|
||||||
const entry = resultCache.get(key);
|
const entry = resultCache.get(key);
|
||||||
if (!entry) return null;
|
if (!entry) return null;
|
||||||
if (Date.now() - entry.timestamp > CACHE_TTL_MS) {
|
if (Date.now() - entry.timestamp > CACHE_TTL_MS) {
|
||||||
resultCache.delete(key);
|
resultCache.delete(key);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return entry.result;
|
return entry.result;
|
||||||
}
|
}
|
||||||
|
|
||||||
function setCache(text: string, result: TextClassification): void {
|
function setCache(text: string, result: TextClassification): void {
|
||||||
if (resultCache.size >= CACHE_MAX_SIZE) {
|
if (resultCache.size >= CACHE_MAX_SIZE) {
|
||||||
// Evict oldest entry
|
// Evict oldest entry
|
||||||
const oldestKey = resultCache.keys().next().value;
|
const oldestKey = resultCache.keys().next().value;
|
||||||
resultCache.delete(oldestKey);
|
resultCache.delete(oldestKey);
|
||||||
}
|
}
|
||||||
resultCache.set(cacheKey(text), { result, timestamp: Date.now() });
|
resultCache.set(cacheKey(text), { result, timestamp: Date.now() });
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── BertTokenizer (JavaScript implementation) ──────────────────────────────
|
// ── BertTokenizer (JavaScript implementation) ──────────────────────────────
|
||||||
|
|
||||||
interface TokenizerConfig {
|
interface TokenizerConfig {
|
||||||
vocab: Map<string, number>;
|
vocab: Map<string, number>;
|
||||||
inv_vocab: Map<number, string>;
|
inv_vocab: Map<number, string>;
|
||||||
max_len: number;
|
max_len: number;
|
||||||
do_lower_case: boolean;
|
do_lower_case: boolean;
|
||||||
tokenizers: Record<string, unknown>;
|
tokenizers: Record<string, unknown>;
|
||||||
model_max_length: number;
|
model_max_length: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
class BertTokenizer {
|
class BertTokenizer {
|
||||||
private config: TokenizerConfig;
|
private config: TokenizerConfig;
|
||||||
|
|
||||||
constructor(configPath: string) {
|
constructor(configPath: string) {
|
||||||
this.config = this.loadConfig(configPath);
|
this.config = this.loadConfig(configPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
private loadConfig(configPath: string): TokenizerConfig {
|
private loadConfig(configPath: string): TokenizerConfig {
|
||||||
const vocabPath = path.join(configPath, "vocab.txt");
|
const vocabPath = path.join(configPath, "vocab.txt");
|
||||||
const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json");
|
const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json");
|
||||||
|
|
||||||
// Load vocabulary
|
// Load vocabulary
|
||||||
const vocab = new Map<string, number>();
|
const vocab = new Map<string, number>();
|
||||||
const inv_vocab = new Map<number, string>();
|
const inv_vocab = new Map<number, string>();
|
||||||
const vocabText = fs.readFileSync(vocabPath, "utf-8");
|
const vocabText = fs.readFileSync(vocabPath, "utf-8");
|
||||||
const lines = vocabText.split("\n");
|
const lines = vocabText.split("\n");
|
||||||
for (let i = 0; i < lines.length; i++) {
|
for (let i = 0; i < lines.length; i++) {
|
||||||
const token = lines[i].trim();
|
const token = lines[i].trim();
|
||||||
if (token) {
|
if (token) {
|
||||||
vocab.set(token, i);
|
vocab.set(token, i);
|
||||||
inv_vocab.set(i, token);
|
inv_vocab.set(i, token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load tokenizer config
|
// Load tokenizer config
|
||||||
let doLowercase = true;
|
let doLowercase = true;
|
||||||
let modelMaxLength = 512;
|
let modelMaxLength = 512;
|
||||||
try {
|
try {
|
||||||
const configData = JSON.parse(fs.readFileSync(tokenizerConfigPath, "utf-8"));
|
const configData = JSON.parse(
|
||||||
doLowercase = configData.do_lower_case ?? true;
|
fs.readFileSync(tokenizerConfigPath, "utf-8"),
|
||||||
modelMaxLength = configData.model_max_length ?? 512;
|
);
|
||||||
} catch {
|
doLowercase = configData.do_lower_case ?? true;
|
||||||
// Use defaults
|
modelMaxLength = configData.model_max_length ?? 512;
|
||||||
}
|
} catch {
|
||||||
|
// Use defaults
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
vocab,
|
vocab,
|
||||||
inv_vocab,
|
inv_vocab,
|
||||||
max_len: 512,
|
max_len: 512,
|
||||||
do_lower_case: doLowercase,
|
do_lower_case: doLowercase,
|
||||||
tokenizers: {},
|
tokenizers: {},
|
||||||
model_max_length: modelMaxLength,
|
model_max_length: modelMaxLength,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private whitespace_tokenize(text: string): string[] {
|
private whitespace_tokenize(text: string): string[] {
|
||||||
if (this.config.do_lower_case) {
|
if (this.config.do_lower_case) {
|
||||||
text = text.toLowerCase();
|
text = text.toLowerCase();
|
||||||
}
|
}
|
||||||
// Split on whitespace, keeping punctuation attached
|
// Split on whitespace, keeping punctuation attached
|
||||||
return text.split(/\s+/).filter((t) => t.length > 0);
|
return text.split(/\s+/).filter((t) => t.length > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
private wordpiece_tokenize(token: string, maxOutputTokens: number = 20): string[] {
|
private wordpiece_tokenize(
|
||||||
const outputTokens: string[] = [];
|
token: string,
|
||||||
let isBad = false;
|
maxOutputTokens: number = 20,
|
||||||
let start = 0;
|
): string[] {
|
||||||
let subToken: string | null = null;
|
const outputTokens: string[] = [];
|
||||||
|
let isBad = false;
|
||||||
|
let start = 0;
|
||||||
|
let subToken: string | null = null;
|
||||||
|
|
||||||
while (start < token.length && !isBad && outputTokens.length < maxOutputTokens) {
|
while (
|
||||||
let found = false;
|
start < token.length &&
|
||||||
|
!isBad &&
|
||||||
|
outputTokens.length < maxOutputTokens
|
||||||
|
) {
|
||||||
|
let found = false;
|
||||||
|
|
||||||
for (let end = token.length; end > start; end--) {
|
for (let end = token.length; end > start; end--) {
|
||||||
let substr = token.substring(start, end);
|
let substr = token.substring(start, end);
|
||||||
if (start > 0) {
|
if (start > 0) {
|
||||||
substr = "##" + substr;
|
substr = "##" + substr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.config.vocab.has(substr)) {
|
if (this.config.vocab.has(substr)) {
|
||||||
outputTokens.push(substr);
|
outputTokens.push(substr);
|
||||||
subToken = substr;
|
subToken = substr;
|
||||||
start = end;
|
start = end;
|
||||||
found = true;
|
found = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!found) {
|
if (!found) {
|
||||||
isBad = true;
|
isBad = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isBad) {
|
if (isBad) {
|
||||||
outputTokens.push("[UNK]");
|
outputTokens.push("[UNK]");
|
||||||
} else if (subToken === null) {
|
} else if (subToken === null) {
|
||||||
outputTokens.push("[UNK]");
|
outputTokens.push("[UNK]");
|
||||||
}
|
}
|
||||||
|
|
||||||
return outputTokens;
|
return outputTokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
private tokenize(text: string): string[] {
|
private tokenize(text: string): string[] {
|
||||||
const tokens = [];
|
const tokens = [];
|
||||||
const whitespaceTokens = this.whitespace_tokenize(text);
|
const whitespaceTokens = this.whitespace_tokenize(text);
|
||||||
|
|
||||||
for (const token of whitespaceTokens) {
|
for (const token of whitespaceTokens) {
|
||||||
const subTokens = this.wordpiece_tokenize(token);
|
const subTokens = this.wordpiece_tokenize(token);
|
||||||
tokens.push(...subTokens);
|
tokens.push(...subTokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
encode(text: string, maxLen: number = 128): { inputIds: number[]; attentionMask: number[] } {
|
encode(
|
||||||
const tokens = this.tokenize(text);
|
text: string,
|
||||||
|
maxLen: number = 128,
|
||||||
|
): { inputIds: number[]; attentionMask: number[] } {
|
||||||
|
const tokens = this.tokenize(text);
|
||||||
|
|
||||||
// Add [CLS] and [SEP]
|
// Add [CLS] and [SEP]
|
||||||
const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"];
|
const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"];
|
||||||
|
|
||||||
const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK]
|
const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK]
|
||||||
const attentionMask = new Array(inputIds.length).fill(1);
|
const attentionMask = new Array(inputIds.length).fill(1);
|
||||||
|
|
||||||
// Pad to maxLen if needed
|
// Pad to maxLen if needed
|
||||||
while (inputIds.length < maxLen) {
|
while (inputIds.length < maxLen) {
|
||||||
inputIds.push(0);
|
inputIds.push(0);
|
||||||
attentionMask.push(0);
|
attentionMask.push(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
return { inputIds, attentionMask };
|
return { inputIds, attentionMask };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Model Loading ──────────────────────────────────────────────────────────
|
// ── Model Loading ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
const MODEL_DIR_ENV = "SPAM_MODEL_DIR";
|
const MODEL_DIR_ENV = "SPAM_MODEL_DIR";
|
||||||
const DEFAULT_MODEL_DIR = path.join(__dirname, "..", "..", "models", "spam-classifier");
|
const DEFAULT_MODEL_DIR = path.join(
|
||||||
|
__dirname,
|
||||||
|
"..",
|
||||||
|
"..",
|
||||||
|
"models",
|
||||||
|
"spam-classifier",
|
||||||
|
);
|
||||||
|
|
||||||
function getModelDir(): string {
|
function getModelDir(): string {
|
||||||
return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR;
|
return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Remote Model Download ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const MODEL_DOWNLOAD_URL_ENV = "SPAM_MODEL_URL_BASE";
|
||||||
|
|
||||||
|
/** Model files that need to be available in the model directory. */
|
||||||
|
const MODEL_FILES = [
|
||||||
|
"model.onnx",
|
||||||
|
"model.onnx.data",
|
||||||
|
"tokenizer.json",
|
||||||
|
"vocab.txt",
|
||||||
|
"tokenizer_config.json",
|
||||||
|
"special_tokens_map.json",
|
||||||
|
"model_metadata.json",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if all required model files exist in the given directory.
|
||||||
|
*/
|
||||||
|
function modelFilesExist(dir: string): boolean {
|
||||||
|
try {
|
||||||
|
return MODEL_FILES.every((f) => fs.existsSync(path.join(dir, f)));
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Download a single model file from a remote URL to a local path.
|
||||||
|
* Uses streaming to handle large files (e.g., model.onnx.data at 255MB).
|
||||||
|
*/
|
||||||
|
async function downloadModelFile(url: string, destPath: string): Promise<void> {
|
||||||
|
const response = await fetch(url);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(
|
||||||
|
`Failed to download ${url}: ${response.status} ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body?.getReader();
|
||||||
|
if (!reader) {
|
||||||
|
throw new Error(`No response body stream for ${url}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure parent directory exists
|
||||||
|
const dir = path.dirname(destPath);
|
||||||
|
fs.mkdirSync(dir, { recursive: true });
|
||||||
|
|
||||||
|
// Stream to file
|
||||||
|
const writer = fs.createWriteStream(destPath);
|
||||||
|
try {
|
||||||
|
let totalBytes = 0;
|
||||||
|
let lastLog = 0;
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
writer.write(value);
|
||||||
|
totalBytes += value.length;
|
||||||
|
|
||||||
|
// Log progress every ~10MB
|
||||||
|
if (totalBytes - lastLog > 10 * 1024 * 1024) {
|
||||||
|
lastLog = totalBytes;
|
||||||
|
const mb = (totalBytes / (1024 * 1024)).toFixed(1);
|
||||||
|
console.log(
|
||||||
|
`[spamshield] Downloaded ${path.basename(destPath)}: ${mb}MB`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
writer.end();
|
||||||
|
await new Promise<void>((resolve) => writer.on("finish", resolve));
|
||||||
|
}
|
||||||
|
|
||||||
|
const totalMB = (fs.statSync(destPath).size / (1024 * 1024)).toFixed(1);
|
||||||
|
console.log(
|
||||||
|
`[spamshield] Downloaded ${path.basename(destPath)} (${totalMB}MB)`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Download all model files from a remote URL base to the model directory.
|
||||||
|
* Falls back gracefully — if the URL is not configured, returns false.
|
||||||
|
*/
|
||||||
|
async function downloadModelIfMissing(modelDir: string): Promise<boolean> {
|
||||||
|
// If model files already exist locally, nothing to do
|
||||||
|
if (modelFilesExist(modelDir)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseUrl = process.env[MODEL_DOWNLOAD_URL_ENV];
|
||||||
|
if (!baseUrl) {
|
||||||
|
console.log(
|
||||||
|
"[spamshield] Model files not found locally and SPAM_MODEL_URL_BASE not set — " +
|
||||||
|
"will use rule-engine fallback",
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalizedBase = baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`;
|
||||||
|
console.log(`[spamshield] Downloading model from: ${normalizedBase}`);
|
||||||
|
|
||||||
|
// Ensure model directory exists
|
||||||
|
fs.mkdirSync(modelDir, { recursive: true });
|
||||||
|
|
||||||
|
// Track which files we already have (for caching across cold starts)
|
||||||
|
const existing = new Set<string>();
|
||||||
|
for (const file of MODEL_FILES) {
|
||||||
|
const filePath = path.join(modelDir, file);
|
||||||
|
if (fs.existsSync(filePath) && fs.statSync(filePath).size > 0) {
|
||||||
|
existing.add(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download missing files
|
||||||
|
for (const file of MODEL_FILES) {
|
||||||
|
if (existing.has(file)) {
|
||||||
|
console.log(`[spamshield] Already have ${file}, skipping download`);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const url = `${normalizedBase}${file}`;
|
||||||
|
const destPath = path.join(modelDir, file);
|
||||||
|
console.log(`[spamshield] Downloading ${file}...`);
|
||||||
|
try {
|
||||||
|
await downloadModelFile(url, destPath);
|
||||||
|
} catch (err) {
|
||||||
|
console.error(`[spamshield] Failed to download ${file}:`, err);
|
||||||
|
// If the main model files fail, we can't use the model
|
||||||
|
if (file === "model.onnx" || file === "model.onnx.data") {
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelFilesExist(modelDir);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadModel(): Promise<void> {
|
async function loadModel(): Promise<void> {
|
||||||
if (modelState.loaded) return;
|
if (modelState.loaded) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const modelDir = getModelDir();
|
const modelDir = getModelDir();
|
||||||
console.log(`[spamshield] Loading ONNX model from: ${modelDir}`);
|
console.log(`[spamshield] Loading ONNX model from: ${modelDir}`);
|
||||||
|
|
||||||
// Load metadata
|
// Download model files if missing (production/Vercel path)
|
||||||
const metadataPath = path.join(modelDir, "model_metadata.json");
|
await downloadModelIfMissing(modelDir);
|
||||||
if (fs.existsSync(metadataPath)) {
|
|
||||||
modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8"));
|
|
||||||
console.log(`[spamshield] Model version: ${modelState.metadata.version}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load tokenizer
|
// Load metadata
|
||||||
modelState.tokenizer = new BertTokenizer(modelDir);
|
const metadataPath = path.join(modelDir, "model_metadata.json");
|
||||||
console.log("[spamshield] Tokenizer loaded");
|
if (fs.existsSync(metadataPath)) {
|
||||||
|
modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8"));
|
||||||
|
console.log(`[spamshield] Model version: ${modelState.metadata.version}`);
|
||||||
|
}
|
||||||
|
|
||||||
// Load ONNX model
|
// Load tokenizer
|
||||||
const modelPath = path.join(modelDir, "model.onnx");
|
modelState.tokenizer = new BertTokenizer(modelDir);
|
||||||
if (!fs.existsSync(modelPath)) {
|
console.log("[spamshield] Tokenizer loaded");
|
||||||
// Check for external data file
|
|
||||||
const modelDataPath = path.join(modelDir, "model.onnx.data");
|
|
||||||
if (!fs.existsSync(modelDataPath)) {
|
|
||||||
throw new Error(`ONNX model not found at ${modelPath}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
modelState.session = await ort.InferenceSession.create(modelPath);
|
// Load ONNX model
|
||||||
console.log("[spamshield] ONNX session created");
|
const modelPath = path.join(modelDir, "model.onnx");
|
||||||
console.log(`[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`);
|
if (!fs.existsSync(modelPath)) {
|
||||||
console.log(`[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`);
|
// Check for external data file
|
||||||
|
const modelDataPath = path.join(modelDir, "model.onnx.data");
|
||||||
|
if (!fs.existsSync(modelDataPath)) {
|
||||||
|
throw new Error(`ONNX model not found at ${modelPath}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
modelState.loaded = true;
|
modelState.session = await ort.InferenceSession.create(modelPath);
|
||||||
console.log("[spamshield] Model loaded successfully");
|
console.log("[spamshield] ONNX session created");
|
||||||
} catch (err) {
|
console.log(
|
||||||
modelState.loadError = err instanceof Error ? err : new Error(String(err));
|
`[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`,
|
||||||
console.error("[spamshield] Failed to load ONNX model:", modelState.loadError);
|
);
|
||||||
console.log("[spamshield] Falling back to rule engine for classification");
|
console.log(
|
||||||
}
|
`[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
modelState.loaded = true;
|
||||||
|
console.log("[spamshield] Model loaded successfully");
|
||||||
|
} catch (err) {
|
||||||
|
modelState.loadError = err instanceof Error ? err : new Error(String(err));
|
||||||
|
console.error(
|
||||||
|
"[spamshield] Failed to load ONNX model:",
|
||||||
|
modelState.loadError,
|
||||||
|
);
|
||||||
|
console.log("[spamshield] Falling back to rule engine for classification");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Inference ──────────────────────────────────────────────────────────────
|
// ── Inference ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
function sigmoid(x: number): number {
|
function sigmoid(x: number): number {
|
||||||
return 1 / (1 + Math.exp(-x));
|
return 1 / (1 + Math.exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runInference(
|
async function runInference(
|
||||||
text: string,
|
text: string,
|
||||||
thresholdMode: ThresholdMode = "moderate",
|
thresholdMode: ThresholdMode = "moderate",
|
||||||
): Promise<TextClassification> {
|
): Promise<TextClassification> {
|
||||||
const thresholds = DEFAULT_THRESHOLDS;
|
const thresholds = DEFAULT_THRESHOLDS;
|
||||||
const threshold = thresholds[thresholdMode];
|
const threshold = thresholds[thresholdMode];
|
||||||
|
|
||||||
// Check cache first
|
// Check cache first
|
||||||
const cached = getCached(text);
|
const cached = getCached(text);
|
||||||
if (cached) {
|
if (cached) {
|
||||||
return { ...cached, modelVersion: modelState.metadata.version };
|
return { ...cached, modelVersion: modelState.metadata.version };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure model is loaded
|
// Ensure model is loaded
|
||||||
if (!modelState.loaded || !modelState.session) {
|
if (!modelState.loaded || !modelState.session) {
|
||||||
await loadModel();
|
await loadModel();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If model still not loaded, return fallback
|
// If model still not loaded, return fallback
|
||||||
if (!modelState.loaded || !modelState.session) {
|
if (!modelState.loaded || !modelState.session) {
|
||||||
const fallback: TextClassification = {
|
const fallback: TextClassification = {
|
||||||
isSpam: false,
|
isSpam: false,
|
||||||
confidence: 0,
|
confidence: 0,
|
||||||
score: 0,
|
score: 0,
|
||||||
modelVersion: "fallback",
|
modelVersion: "fallback",
|
||||||
};
|
};
|
||||||
setCache(text, fallback);
|
setCache(text, fallback);
|
||||||
return fallback;
|
return fallback;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tokenize
|
// Tokenize
|
||||||
const maxLen = modelState.metadata.max_length || 128;
|
const maxLen = modelState.metadata.max_length || 128;
|
||||||
const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen);
|
const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen);
|
||||||
|
|
||||||
// Create ONNX tensors (int64 requires BigInt values)
|
// Create ONNX tensors (int64 requires BigInt values)
|
||||||
const inputIdsBigInt = new BigInt64Array(inputIds.length);
|
const inputIdsBigInt = new BigInt64Array(inputIds.length);
|
||||||
for (let i = 0; i < inputIds.length; i++) {
|
for (let i = 0; i < inputIds.length; i++) {
|
||||||
inputIdsBigInt[i] = BigInt(inputIds[i]);
|
inputIdsBigInt[i] = BigInt(inputIds[i]);
|
||||||
}
|
}
|
||||||
const attentionMaskBigInt = new BigInt64Array(attentionMask.length);
|
const attentionMaskBigInt = new BigInt64Array(attentionMask.length);
|
||||||
for (let i = 0; i < attentionMask.length; i++) {
|
for (let i = 0; i < attentionMask.length; i++) {
|
||||||
attentionMaskBigInt[i] = BigInt(attentionMask[i]);
|
attentionMaskBigInt[i] = BigInt(attentionMask[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]);
|
const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]);
|
||||||
const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [1, maxLen]);
|
const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [
|
||||||
|
1,
|
||||||
|
maxLen,
|
||||||
|
]);
|
||||||
|
|
||||||
// Run inference
|
// Run inference
|
||||||
const feeds: Record<string, Tensor> = {
|
const feeds: Record<string, Tensor> = {
|
||||||
input_ids: inputIdsTensor,
|
input_ids: inputIdsTensor,
|
||||||
attention_mask: attentionMaskTensor,
|
attention_mask: attentionMaskTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
const outputs = await modelState.session.run(feeds);
|
const outputs = await modelState.session.run(feeds);
|
||||||
const logits = outputs[modelState.session.outputNames[0]];
|
const logits = outputs[modelState.session.outputNames[0]];
|
||||||
|
|
||||||
// Extract logits (shape: [1, num_labels])
|
// Extract logits (shape: [1, num_labels])
|
||||||
const logitsData = logits.data as Float32Array | number[];
|
const logitsData = logits.data as Float32Array | number[];
|
||||||
const spamLogit = logitsData[1] ?? 0;
|
const spamLogit = logitsData[1] ?? 0;
|
||||||
const hamLogit = logitsData[0] ?? 0;
|
const hamLogit = logitsData[0] ?? 0;
|
||||||
|
|
||||||
// Apply sigmoid to get probability
|
// Apply sigmoid to get probability
|
||||||
const spamProb = sigmoid(spamLogit);
|
const spamProb = sigmoid(spamLogit);
|
||||||
const hamProb = sigmoid(hamLogit);
|
const hamProb = sigmoid(hamLogit);
|
||||||
|
|
||||||
// Binary decision based on threshold
|
// Binary decision based on threshold
|
||||||
const isSpam = spamProb >= threshold;
|
const isSpam = spamProb >= threshold;
|
||||||
const confidence = isSpam ? spamProb : 1 - spamProb;
|
const confidence = isSpam ? spamProb : 1 - spamProb;
|
||||||
|
|
||||||
const result: TextClassification = {
|
const result: TextClassification = {
|
||||||
isSpam,
|
isSpam,
|
||||||
confidence: Math.round(confidence * 10000) / 10000,
|
confidence: Math.round(confidence * 10000) / 10000,
|
||||||
score: Math.round(spamProb * 10000) / 10000,
|
score: Math.round(spamProb * 10000) / 10000,
|
||||||
modelVersion: modelState.metadata.version,
|
modelVersion: modelState.metadata.version,
|
||||||
};
|
};
|
||||||
|
|
||||||
setCache(text, result);
|
setCache(text, result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Public API ─────────────────────────────────────────────────────────────
|
// ── Public API ─────────────────────────────────────────────────────────────
|
||||||
@@ -393,21 +559,21 @@ async function runInference(
|
|||||||
* Falls back to returning a safe default if the model fails to load.
|
* Falls back to returning a safe default if the model fails to load.
|
||||||
*/
|
*/
|
||||||
export async function classifyTextBERT(
|
export async function classifyTextBERT(
|
||||||
text: string,
|
text: string,
|
||||||
thresholdMode: ThresholdMode = "moderate",
|
thresholdMode: ThresholdMode = "moderate",
|
||||||
): Promise<TextClassification> {
|
): Promise<TextClassification> {
|
||||||
try {
|
try {
|
||||||
return await runInference(text, thresholdMode);
|
return await runInference(text, thresholdMode);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("[spamshield] ONNX inference error:", err);
|
console.error("[spamshield] ONNX inference error:", err);
|
||||||
// Graceful fallback: return non-spam with low confidence
|
// Graceful fallback: return non-spam with low confidence
|
||||||
return {
|
return {
|
||||||
isSpam: false,
|
isSpam: false,
|
||||||
confidence: 0,
|
confidence: 0,
|
||||||
score: 0,
|
score: 0,
|
||||||
modelVersion: "error",
|
modelVersion: "error",
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -415,41 +581,41 @@ export async function classifyTextBERT(
|
|||||||
* Call this once during server initialization.
|
* Call this once during server initialization.
|
||||||
*/
|
*/
|
||||||
export async function initSpamModel(): Promise<boolean> {
|
export async function initSpamModel(): Promise<boolean> {
|
||||||
await loadModel();
|
await loadModel();
|
||||||
return modelState.loaded;
|
return modelState.loaded;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the model is loaded and ready.
|
* Check if the model is loaded and ready.
|
||||||
*/
|
*/
|
||||||
export function isModelLoaded(): boolean {
|
export function isModelLoaded(): boolean {
|
||||||
return modelState.loaded && modelState.session !== null;
|
return modelState.loaded && modelState.session !== null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get model metadata.
|
* Get model metadata.
|
||||||
*/
|
*/
|
||||||
export function getModelInfo(): ModelMetadata {
|
export function getModelInfo(): ModelMetadata {
|
||||||
return { ...modelState.metadata };
|
return { ...modelState.metadata };
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the current cache stats.
|
* Get the current cache stats.
|
||||||
*/
|
*/
|
||||||
export function getCacheStats(): { size: number; max: number } {
|
export function getCacheStats(): { size: number; max: number } {
|
||||||
return { size: resultCache.size, max: CACHE_MAX_SIZE };
|
return { size: resultCache.size, max: CACHE_MAX_SIZE };
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clear the result cache.
|
* Clear the result cache.
|
||||||
*/
|
*/
|
||||||
export function clearCache(): void {
|
export function clearCache(): void {
|
||||||
resultCache.clear();
|
resultCache.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get available threshold modes and their values.
|
* Get available threshold modes and their values.
|
||||||
*/
|
*/
|
||||||
export function getThresholds(): ClassificationThresholds {
|
export function getThresholds(): ClassificationThresholds {
|
||||||
return { ...DEFAULT_THRESHOLDS };
|
return { ...DEFAULT_THRESHOLDS };
|
||||||
}
|
}
|
||||||
|
|||||||
8
web/vercel.json
Normal file
8
web/vercel.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://openapi.vercel.sh/vercel.json",
|
||||||
|
"framework": "solidstart",
|
||||||
|
"buildCommand": "npm run build",
|
||||||
|
"installCommand": "npm install",
|
||||||
|
"outputDirectory": ".output/public",
|
||||||
|
"regions": ["iad1"]
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user