From 61d48d36482ced73ca135dfcce907fa39fe8248a Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Wed, 3 Jun 2026 13:35:37 -0400 Subject: [PATCH] onnx, fix depl issue --- tasks/ios-production/README.md | 2 +- web/.gitignore | 1 + web/.vercelignore | 51 ++ .../services/spamshield/onnx.inference.ts | 734 +++++++++++------- web/vercel.json | 8 + 5 files changed, 511 insertions(+), 285 deletions(-) create mode 100644 web/.vercelignore create mode 100644 web/vercel.json diff --git a/tasks/ios-production/README.md b/tasks/ios-production/README.md index 53b6202..d84c0e6 100644 --- a/tasks/ios-production/README.md +++ b/tasks/ios-production/README.md @@ -21,7 +21,7 @@ Status legend: [ ] todo, [~] in-progress, [x] done ### Performance Optimization - [x] 09 — Image Caching & Lazy Loading → `09-image-caching.md` - [x] 10 — Memory Management & Leak Audit → `10-memory-leak-audit.md` -- [~] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md` +- [x] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md` - [x] 12 — App Launch Time Optimization → `12-launch-time.md` ### Native Features diff --git a/web/.gitignore b/web/.gitignore index 4c0df46..d2311eb 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -5,6 +5,7 @@ dist .netlify .vinxi app.config.timestamp_*.js +.pi-lens # Environment .env* diff --git a/web/.vercelignore b/web/.vercelignore new file mode 100644 index 0000000..2a03727 --- /dev/null +++ b/web/.vercelignore @@ -0,0 +1,51 @@ +# ── ML Model (255MB ONNX model — too large for Vercel, downloaded at runtime) ── +src/server/models/spam-classifier/ + +# ── Build Artifacts ── +.output/ +.nitro/ +dist/ + +# ── Test Files (not needed in production) ── +e2e/ +test/ +**/*.test.ts +**/*.test.tsx +**/*.spec.ts +**/*.spec.tsx + +# ── Development / Config ── +.dockerignore +Dockerfile +docker-compose.yml +docker-compose.yaml +vitest.config.ts +vitest.node.config.ts +playwright.config.ts +drizzle.config.ts +drizzle/ + +# ── Version Control ── +.git/ +.gitignore +.github/ +.husky/ + +# ── Environment (already in .gitignore, being explicit) ── +.env +.env.development +.env.production +.env.local + +# ── Editors / OS ── +.idea/ +.vscode/ +*.swp +*.swo +*~ +.DS_Store +Thumbs.db + +# ── Pi agent / dev tooling ── +.pi-lens/ +.agents/ diff --git a/web/src/server/services/spamshield/onnx.inference.ts b/web/src/server/services/spamshield/onnx.inference.ts index ccbafce..a9fd16e 100644 --- a/web/src/server/services/spamshield/onnx.inference.ts +++ b/web/src/server/services/spamshield/onnx.inference.ts @@ -17,67 +17,67 @@ const __dirname = path.dirname(__filename); // ── Types ────────────────────────────────────────────────────────────────── export interface TextClassification { - isSpam: boolean; - confidence: number; - score: number; - modelVersion?: string; + isSpam: boolean; + confidence: number; + score: number; + modelVersion?: string; } export interface ClassificationThresholds { - strict: number; // 0.3 - flag more aggressively - moderate: number; // 0.5 - balanced - lenient: number; // 0.7 - fewer false positives + strict: number; // 0.3 - flag more aggressively + moderate: number; // 0.5 - balanced + lenient: number; // 0.7 - fewer false positives } export type ThresholdMode = "strict" | "moderate" | "lenient"; const DEFAULT_THRESHOLDS: ClassificationThresholds = { - strict: 0.3, - moderate: 0.5, - lenient: 0.7, + strict: 0.3, + moderate: 0.5, + lenient: 0.7, }; // ── Model Singleton ──────────────────────────────────────────────────────── interface ModelState { - session: InferenceSession | null; - tokenizer: BertTokenizer; - metadata: ModelMetadata; - loaded: boolean; - loadError: Error | null; + session: InferenceSession | null; + tokenizer: BertTokenizer; + metadata: ModelMetadata; + loaded: boolean; + loadError: Error | null; } interface ModelMetadata { - version: string; - model_name: string; - task: string; - max_length: number; - num_labels: number; - label2id: Record; - id2label: Record; + version: string; + model_name: string; + task: string; + max_length: number; + num_labels: number; + label2id: Record; + id2label: Record; } const modelState: ModelState = { - session: null, - tokenizer: null as unknown as BertTokenizer, - metadata: { - version: "0.0.0", - model_name: "", - task: "", - max_length: 128, - num_labels: 2, - label2id: {}, - id2label: {}, - }, - loaded: false, - loadError: null, + session: null, + tokenizer: null as unknown as BertTokenizer, + metadata: { + version: "0.0.0", + model_name: "", + task: "", + max_length: 128, + num_labels: 2, + label2id: {}, + id2label: {}, + }, + loaded: false, + loadError: null, }; // ── Result Cache ─────────────────────────────────────────────────────────── interface CacheEntry { - result: TextClassification; - timestamp: number; + result: TextClassification; + timestamp: number; } const resultCache = new Map(); @@ -85,305 +85,471 @@ const CACHE_MAX_SIZE = 1000; const CACHE_TTL_MS = 5 * 60 * 1000; // 5 minutes function cacheKey(text: string): string { - // Simple hash of normalized text - const normalized = text.toLowerCase().trim(); - let hash = 0; - for (let i = 0; i < normalized.length; i++) { - const char = normalized.charCodeAt(i); - hash = ((hash << 5) - hash) + char; - hash |= 0; // Convert to 32bit integer - } - return String(hash); + // Simple hash of normalized text + const normalized = text.toLowerCase().trim(); + let hash = 0; + for (let i = 0; i < normalized.length; i++) { + const char = normalized.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash |= 0; // Convert to 32bit integer + } + return String(hash); } function getCached(text: string): TextClassification | null { - const key = cacheKey(text); - const entry = resultCache.get(key); - if (!entry) return null; - if (Date.now() - entry.timestamp > CACHE_TTL_MS) { - resultCache.delete(key); - return null; - } - return entry.result; + const key = cacheKey(text); + const entry = resultCache.get(key); + if (!entry) return null; + if (Date.now() - entry.timestamp > CACHE_TTL_MS) { + resultCache.delete(key); + return null; + } + return entry.result; } function setCache(text: string, result: TextClassification): void { - if (resultCache.size >= CACHE_MAX_SIZE) { - // Evict oldest entry - const oldestKey = resultCache.keys().next().value; - resultCache.delete(oldestKey); - } - resultCache.set(cacheKey(text), { result, timestamp: Date.now() }); + if (resultCache.size >= CACHE_MAX_SIZE) { + // Evict oldest entry + const oldestKey = resultCache.keys().next().value; + resultCache.delete(oldestKey); + } + resultCache.set(cacheKey(text), { result, timestamp: Date.now() }); } // ── BertTokenizer (JavaScript implementation) ────────────────────────────── interface TokenizerConfig { - vocab: Map; - inv_vocab: Map; - max_len: number; - do_lower_case: boolean; - tokenizers: Record; - model_max_length: number; + vocab: Map; + inv_vocab: Map; + max_len: number; + do_lower_case: boolean; + tokenizers: Record; + model_max_length: number; } class BertTokenizer { - private config: TokenizerConfig; + private config: TokenizerConfig; - constructor(configPath: string) { - this.config = this.loadConfig(configPath); - } + constructor(configPath: string) { + this.config = this.loadConfig(configPath); + } - private loadConfig(configPath: string): TokenizerConfig { - const vocabPath = path.join(configPath, "vocab.txt"); - const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json"); + private loadConfig(configPath: string): TokenizerConfig { + const vocabPath = path.join(configPath, "vocab.txt"); + const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json"); - // Load vocabulary - const vocab = new Map(); - const inv_vocab = new Map(); - const vocabText = fs.readFileSync(vocabPath, "utf-8"); - const lines = vocabText.split("\n"); - for (let i = 0; i < lines.length; i++) { - const token = lines[i].trim(); - if (token) { - vocab.set(token, i); - inv_vocab.set(i, token); - } - } + // Load vocabulary + const vocab = new Map(); + const inv_vocab = new Map(); + const vocabText = fs.readFileSync(vocabPath, "utf-8"); + const lines = vocabText.split("\n"); + for (let i = 0; i < lines.length; i++) { + const token = lines[i].trim(); + if (token) { + vocab.set(token, i); + inv_vocab.set(i, token); + } + } - // Load tokenizer config - let doLowercase = true; - let modelMaxLength = 512; - try { - const configData = JSON.parse(fs.readFileSync(tokenizerConfigPath, "utf-8")); - doLowercase = configData.do_lower_case ?? true; - modelMaxLength = configData.model_max_length ?? 512; - } catch { - // Use defaults - } + // Load tokenizer config + let doLowercase = true; + let modelMaxLength = 512; + try { + const configData = JSON.parse( + fs.readFileSync(tokenizerConfigPath, "utf-8"), + ); + doLowercase = configData.do_lower_case ?? true; + modelMaxLength = configData.model_max_length ?? 512; + } catch { + // Use defaults + } - return { - vocab, - inv_vocab, - max_len: 512, - do_lower_case: doLowercase, - tokenizers: {}, - model_max_length: modelMaxLength, - }; - } + return { + vocab, + inv_vocab, + max_len: 512, + do_lower_case: doLowercase, + tokenizers: {}, + model_max_length: modelMaxLength, + }; + } - private whitespace_tokenize(text: string): string[] { - if (this.config.do_lower_case) { - text = text.toLowerCase(); - } - // Split on whitespace, keeping punctuation attached - return text.split(/\s+/).filter((t) => t.length > 0); - } + private whitespace_tokenize(text: string): string[] { + if (this.config.do_lower_case) { + text = text.toLowerCase(); + } + // Split on whitespace, keeping punctuation attached + return text.split(/\s+/).filter((t) => t.length > 0); + } - private wordpiece_tokenize(token: string, maxOutputTokens: number = 20): string[] { - const outputTokens: string[] = []; - let isBad = false; - let start = 0; - let subToken: string | null = null; + private wordpiece_tokenize( + token: string, + maxOutputTokens: number = 20, + ): string[] { + const outputTokens: string[] = []; + let isBad = false; + let start = 0; + let subToken: string | null = null; - while (start < token.length && !isBad && outputTokens.length < maxOutputTokens) { - let found = false; + while ( + start < token.length && + !isBad && + outputTokens.length < maxOutputTokens + ) { + let found = false; - for (let end = token.length; end > start; end--) { - let substr = token.substring(start, end); - if (start > 0) { - substr = "##" + substr; - } + for (let end = token.length; end > start; end--) { + let substr = token.substring(start, end); + if (start > 0) { + substr = "##" + substr; + } - if (this.config.vocab.has(substr)) { - outputTokens.push(substr); - subToken = substr; - start = end; - found = true; - break; - } - } + if (this.config.vocab.has(substr)) { + outputTokens.push(substr); + subToken = substr; + start = end; + found = true; + break; + } + } - if (!found) { - isBad = true; - } - } + if (!found) { + isBad = true; + } + } - if (isBad) { - outputTokens.push("[UNK]"); - } else if (subToken === null) { - outputTokens.push("[UNK]"); - } + if (isBad) { + outputTokens.push("[UNK]"); + } else if (subToken === null) { + outputTokens.push("[UNK]"); + } - return outputTokens; - } + return outputTokens; + } - private tokenize(text: string): string[] { - const tokens = []; - const whitespaceTokens = this.whitespace_tokenize(text); + private tokenize(text: string): string[] { + const tokens = []; + const whitespaceTokens = this.whitespace_tokenize(text); - for (const token of whitespaceTokens) { - const subTokens = this.wordpiece_tokenize(token); - tokens.push(...subTokens); - } + for (const token of whitespaceTokens) { + const subTokens = this.wordpiece_tokenize(token); + tokens.push(...subTokens); + } - return tokens; - } + return tokens; + } - encode(text: string, maxLen: number = 128): { inputIds: number[]; attentionMask: number[] } { - const tokens = this.tokenize(text); + encode( + text: string, + maxLen: number = 128, + ): { inputIds: number[]; attentionMask: number[] } { + const tokens = this.tokenize(text); - // Add [CLS] and [SEP] - const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"]; + // Add [CLS] and [SEP] + const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"]; - const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK] - const attentionMask = new Array(inputIds.length).fill(1); + const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK] + const attentionMask = new Array(inputIds.length).fill(1); - // Pad to maxLen if needed - while (inputIds.length < maxLen) { - inputIds.push(0); - attentionMask.push(0); - } + // Pad to maxLen if needed + while (inputIds.length < maxLen) { + inputIds.push(0); + attentionMask.push(0); + } - return { inputIds, attentionMask }; - } + return { inputIds, attentionMask }; + } } // ── Model Loading ────────────────────────────────────────────────────────── const MODEL_DIR_ENV = "SPAM_MODEL_DIR"; -const DEFAULT_MODEL_DIR = path.join(__dirname, "..", "..", "models", "spam-classifier"); +const DEFAULT_MODEL_DIR = path.join( + __dirname, + "..", + "..", + "models", + "spam-classifier", +); function getModelDir(): string { - return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR; + return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR; +} + +// ── Remote Model Download ──────────────────────────────────────────────────── + +const MODEL_DOWNLOAD_URL_ENV = "SPAM_MODEL_URL_BASE"; + +/** Model files that need to be available in the model directory. */ +const MODEL_FILES = [ + "model.onnx", + "model.onnx.data", + "tokenizer.json", + "vocab.txt", + "tokenizer_config.json", + "special_tokens_map.json", + "model_metadata.json", +] as const; + +/** + * Check if all required model files exist in the given directory. + */ +function modelFilesExist(dir: string): boolean { + try { + return MODEL_FILES.every((f) => fs.existsSync(path.join(dir, f))); + } catch { + return false; + } +} + +/** + * Download a single model file from a remote URL to a local path. + * Uses streaming to handle large files (e.g., model.onnx.data at 255MB). + */ +async function downloadModelFile(url: string, destPath: string): Promise { + const response = await fetch(url); + if (!response.ok) { + throw new Error( + `Failed to download ${url}: ${response.status} ${response.statusText}`, + ); + } + + const reader = response.body?.getReader(); + if (!reader) { + throw new Error(`No response body stream for ${url}`); + } + + // Ensure parent directory exists + const dir = path.dirname(destPath); + fs.mkdirSync(dir, { recursive: true }); + + // Stream to file + const writer = fs.createWriteStream(destPath); + try { + let totalBytes = 0; + let lastLog = 0; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + writer.write(value); + totalBytes += value.length; + + // Log progress every ~10MB + if (totalBytes - lastLog > 10 * 1024 * 1024) { + lastLog = totalBytes; + const mb = (totalBytes / (1024 * 1024)).toFixed(1); + console.log( + `[spamshield] Downloaded ${path.basename(destPath)}: ${mb}MB`, + ); + } + } + } finally { + writer.end(); + await new Promise((resolve) => writer.on("finish", resolve)); + } + + const totalMB = (fs.statSync(destPath).size / (1024 * 1024)).toFixed(1); + console.log( + `[spamshield] Downloaded ${path.basename(destPath)} (${totalMB}MB)`, + ); +} + +/** + * Download all model files from a remote URL base to the model directory. + * Falls back gracefully — if the URL is not configured, returns false. + */ +async function downloadModelIfMissing(modelDir: string): Promise { + // If model files already exist locally, nothing to do + if (modelFilesExist(modelDir)) { + return true; + } + + const baseUrl = process.env[MODEL_DOWNLOAD_URL_ENV]; + if (!baseUrl) { + console.log( + "[spamshield] Model files not found locally and SPAM_MODEL_URL_BASE not set — " + + "will use rule-engine fallback", + ); + return false; + } + + const normalizedBase = baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`; + console.log(`[spamshield] Downloading model from: ${normalizedBase}`); + + // Ensure model directory exists + fs.mkdirSync(modelDir, { recursive: true }); + + // Track which files we already have (for caching across cold starts) + const existing = new Set(); + for (const file of MODEL_FILES) { + const filePath = path.join(modelDir, file); + if (fs.existsSync(filePath) && fs.statSync(filePath).size > 0) { + existing.add(file); + } + } + + // Download missing files + for (const file of MODEL_FILES) { + if (existing.has(file)) { + console.log(`[spamshield] Already have ${file}, skipping download`); + continue; + } + const url = `${normalizedBase}${file}`; + const destPath = path.join(modelDir, file); + console.log(`[spamshield] Downloading ${file}...`); + try { + await downloadModelFile(url, destPath); + } catch (err) { + console.error(`[spamshield] Failed to download ${file}:`, err); + // If the main model files fail, we can't use the model + if (file === "model.onnx" || file === "model.onnx.data") { + throw err; + } + } + } + + return modelFilesExist(modelDir); } async function loadModel(): Promise { - if (modelState.loaded) return; + if (modelState.loaded) return; - try { - const modelDir = getModelDir(); - console.log(`[spamshield] Loading ONNX model from: ${modelDir}`); + try { + const modelDir = getModelDir(); + console.log(`[spamshield] Loading ONNX model from: ${modelDir}`); - // Load metadata - const metadataPath = path.join(modelDir, "model_metadata.json"); - if (fs.existsSync(metadataPath)) { - modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8")); - console.log(`[spamshield] Model version: ${modelState.metadata.version}`); - } + // Download model files if missing (production/Vercel path) + await downloadModelIfMissing(modelDir); - // Load tokenizer - modelState.tokenizer = new BertTokenizer(modelDir); - console.log("[spamshield] Tokenizer loaded"); + // Load metadata + const metadataPath = path.join(modelDir, "model_metadata.json"); + if (fs.existsSync(metadataPath)) { + modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8")); + console.log(`[spamshield] Model version: ${modelState.metadata.version}`); + } - // Load ONNX model - const modelPath = path.join(modelDir, "model.onnx"); - if (!fs.existsSync(modelPath)) { - // Check for external data file - const modelDataPath = path.join(modelDir, "model.onnx.data"); - if (!fs.existsSync(modelDataPath)) { - throw new Error(`ONNX model not found at ${modelPath}`); - } - } + // Load tokenizer + modelState.tokenizer = new BertTokenizer(modelDir); + console.log("[spamshield] Tokenizer loaded"); - modelState.session = await ort.InferenceSession.create(modelPath); - console.log("[spamshield] ONNX session created"); - console.log(`[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`); - console.log(`[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`); + // Load ONNX model + const modelPath = path.join(modelDir, "model.onnx"); + if (!fs.existsSync(modelPath)) { + // Check for external data file + const modelDataPath = path.join(modelDir, "model.onnx.data"); + if (!fs.existsSync(modelDataPath)) { + throw new Error(`ONNX model not found at ${modelPath}`); + } + } - modelState.loaded = true; - console.log("[spamshield] Model loaded successfully"); - } catch (err) { - modelState.loadError = err instanceof Error ? err : new Error(String(err)); - console.error("[spamshield] Failed to load ONNX model:", modelState.loadError); - console.log("[spamshield] Falling back to rule engine for classification"); - } + modelState.session = await ort.InferenceSession.create(modelPath); + console.log("[spamshield] ONNX session created"); + console.log( + `[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`, + ); + console.log( + `[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`, + ); + + modelState.loaded = true; + console.log("[spamshield] Model loaded successfully"); + } catch (err) { + modelState.loadError = err instanceof Error ? err : new Error(String(err)); + console.error( + "[spamshield] Failed to load ONNX model:", + modelState.loadError, + ); + console.log("[spamshield] Falling back to rule engine for classification"); + } } // ── Inference ────────────────────────────────────────────────────────────── function sigmoid(x: number): number { - return 1 / (1 + Math.exp(-x)); + return 1 / (1 + Math.exp(-x)); } async function runInference( - text: string, - thresholdMode: ThresholdMode = "moderate", + text: string, + thresholdMode: ThresholdMode = "moderate", ): Promise { - const thresholds = DEFAULT_THRESHOLDS; - const threshold = thresholds[thresholdMode]; + const thresholds = DEFAULT_THRESHOLDS; + const threshold = thresholds[thresholdMode]; - // Check cache first - const cached = getCached(text); - if (cached) { - return { ...cached, modelVersion: modelState.metadata.version }; - } + // Check cache first + const cached = getCached(text); + if (cached) { + return { ...cached, modelVersion: modelState.metadata.version }; + } - // Ensure model is loaded - if (!modelState.loaded || !modelState.session) { - await loadModel(); - } + // Ensure model is loaded + if (!modelState.loaded || !modelState.session) { + await loadModel(); + } - // If model still not loaded, return fallback - if (!modelState.loaded || !modelState.session) { - const fallback: TextClassification = { - isSpam: false, - confidence: 0, - score: 0, - modelVersion: "fallback", - }; - setCache(text, fallback); - return fallback; - } + // If model still not loaded, return fallback + if (!modelState.loaded || !modelState.session) { + const fallback: TextClassification = { + isSpam: false, + confidence: 0, + score: 0, + modelVersion: "fallback", + }; + setCache(text, fallback); + return fallback; + } - // Tokenize - const maxLen = modelState.metadata.max_length || 128; - const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen); + // Tokenize + const maxLen = modelState.metadata.max_length || 128; + const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen); - // Create ONNX tensors (int64 requires BigInt values) - const inputIdsBigInt = new BigInt64Array(inputIds.length); - for (let i = 0; i < inputIds.length; i++) { - inputIdsBigInt[i] = BigInt(inputIds[i]); - } - const attentionMaskBigInt = new BigInt64Array(attentionMask.length); - for (let i = 0; i < attentionMask.length; i++) { - attentionMaskBigInt[i] = BigInt(attentionMask[i]); - } + // Create ONNX tensors (int64 requires BigInt values) + const inputIdsBigInt = new BigInt64Array(inputIds.length); + for (let i = 0; i < inputIds.length; i++) { + inputIdsBigInt[i] = BigInt(inputIds[i]); + } + const attentionMaskBigInt = new BigInt64Array(attentionMask.length); + for (let i = 0; i < attentionMask.length; i++) { + attentionMaskBigInt[i] = BigInt(attentionMask[i]); + } - const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]); - const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [1, maxLen]); + const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]); + const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [ + 1, + maxLen, + ]); - // Run inference - const feeds: Record = { - input_ids: inputIdsTensor, - attention_mask: attentionMaskTensor, - }; + // Run inference + const feeds: Record = { + input_ids: inputIdsTensor, + attention_mask: attentionMaskTensor, + }; - const outputs = await modelState.session.run(feeds); - const logits = outputs[modelState.session.outputNames[0]]; + const outputs = await modelState.session.run(feeds); + const logits = outputs[modelState.session.outputNames[0]]; - // Extract logits (shape: [1, num_labels]) - const logitsData = logits.data as Float32Array | number[]; - const spamLogit = logitsData[1] ?? 0; - const hamLogit = logitsData[0] ?? 0; + // Extract logits (shape: [1, num_labels]) + const logitsData = logits.data as Float32Array | number[]; + const spamLogit = logitsData[1] ?? 0; + const hamLogit = logitsData[0] ?? 0; - // Apply sigmoid to get probability - const spamProb = sigmoid(spamLogit); - const hamProb = sigmoid(hamLogit); + // Apply sigmoid to get probability + const spamProb = sigmoid(spamLogit); + const hamProb = sigmoid(hamLogit); - // Binary decision based on threshold - const isSpam = spamProb >= threshold; - const confidence = isSpam ? spamProb : 1 - spamProb; + // Binary decision based on threshold + const isSpam = spamProb >= threshold; + const confidence = isSpam ? spamProb : 1 - spamProb; - const result: TextClassification = { - isSpam, - confidence: Math.round(confidence * 10000) / 10000, - score: Math.round(spamProb * 10000) / 10000, - modelVersion: modelState.metadata.version, - }; + const result: TextClassification = { + isSpam, + confidence: Math.round(confidence * 10000) / 10000, + score: Math.round(spamProb * 10000) / 10000, + modelVersion: modelState.metadata.version, + }; - setCache(text, result); - return result; + setCache(text, result); + return result; } // ── Public API ───────────────────────────────────────────────────────────── @@ -393,21 +559,21 @@ async function runInference( * Falls back to returning a safe default if the model fails to load. */ export async function classifyTextBERT( - text: string, - thresholdMode: ThresholdMode = "moderate", + text: string, + thresholdMode: ThresholdMode = "moderate", ): Promise { - try { - return await runInference(text, thresholdMode); - } catch (err) { - console.error("[spamshield] ONNX inference error:", err); - // Graceful fallback: return non-spam with low confidence - return { - isSpam: false, - confidence: 0, - score: 0, - modelVersion: "error", - }; - } + try { + return await runInference(text, thresholdMode); + } catch (err) { + console.error("[spamshield] ONNX inference error:", err); + // Graceful fallback: return non-spam with low confidence + return { + isSpam: false, + confidence: 0, + score: 0, + modelVersion: "error", + }; + } } /** @@ -415,41 +581,41 @@ export async function classifyTextBERT( * Call this once during server initialization. */ export async function initSpamModel(): Promise { - await loadModel(); - return modelState.loaded; + await loadModel(); + return modelState.loaded; } /** * Check if the model is loaded and ready. */ export function isModelLoaded(): boolean { - return modelState.loaded && modelState.session !== null; + return modelState.loaded && modelState.session !== null; } /** * Get model metadata. */ export function getModelInfo(): ModelMetadata { - return { ...modelState.metadata }; + return { ...modelState.metadata }; } /** * Get the current cache stats. */ export function getCacheStats(): { size: number; max: number } { - return { size: resultCache.size, max: CACHE_MAX_SIZE }; + return { size: resultCache.size, max: CACHE_MAX_SIZE }; } /** * Clear the result cache. */ export function clearCache(): void { - resultCache.clear(); + resultCache.clear(); } /** * Get available threshold modes and their values. */ export function getThresholds(): ClassificationThresholds { - return { ...DEFAULT_THRESHOLDS }; + return { ...DEFAULT_THRESHOLDS }; } diff --git a/web/vercel.json b/web/vercel.json new file mode 100644 index 0000000..ef14208 --- /dev/null +++ b/web/vercel.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://openapi.vercel.sh/vercel.json", + "framework": "solidstart", + "buildCommand": "npm run build", + "installCommand": "npm install", + "outputDirectory": ".output/public", + "regions": ["iad1"] +}