onnx, fix depl issue

This commit is contained in:
2026-06-03 13:35:37 -04:00
parent 1408d0cd1d
commit 61d48d3648
5 changed files with 511 additions and 285 deletions

View File

@@ -21,7 +21,7 @@ Status legend: [ ] todo, [~] in-progress, [x] done
### Performance Optimization ### Performance Optimization
- [x] 09 — Image Caching & Lazy Loading → `09-image-caching.md` - [x] 09 — Image Caching & Lazy Loading → `09-image-caching.md`
- [x] 10 — Memory Management & Leak Audit → `10-memory-leak-audit.md` - [x] 10 — Memory Management & Leak Audit → `10-memory-leak-audit.md`
- [~] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md` - [x] 11 — Background Fetch & Sync Optimization → `11-background-fetch.md`
- [x] 12 — App Launch Time Optimization → `12-launch-time.md` - [x] 12 — App Launch Time Optimization → `12-launch-time.md`
### Native Features ### Native Features

1
web/.gitignore vendored
View File

@@ -5,6 +5,7 @@ dist
.netlify .netlify
.vinxi .vinxi
app.config.timestamp_*.js app.config.timestamp_*.js
.pi-lens
# Environment # Environment
.env* .env*

51
web/.vercelignore Normal file
View File

@@ -0,0 +1,51 @@
# ── ML Model (255MB ONNX model — too large for Vercel, downloaded at runtime) ──
src/server/models/spam-classifier/
# ── Build Artifacts ──
.output/
.nitro/
dist/
# ── Test Files (not needed in production) ──
e2e/
test/
**/*.test.ts
**/*.test.tsx
**/*.spec.ts
**/*.spec.tsx
# ── Development / Config ──
.dockerignore
Dockerfile
docker-compose.yml
docker-compose.yaml
vitest.config.ts
vitest.node.config.ts
playwright.config.ts
drizzle.config.ts
drizzle/
# ── Version Control ──
.git/
.gitignore
.github/
.husky/
# ── Environment (already in .gitignore, being explicit) ──
.env
.env.development
.env.production
.env.local
# ── Editors / OS ──
.idea/
.vscode/
*.swp
*.swo
*~
.DS_Store
Thumbs.db
# ── Pi agent / dev tooling ──
.pi-lens/
.agents/

View File

@@ -17,67 +17,67 @@ const __dirname = path.dirname(__filename);
// ── Types ────────────────────────────────────────────────────────────────── // ── Types ──────────────────────────────────────────────────────────────────
export interface TextClassification { export interface TextClassification {
isSpam: boolean; isSpam: boolean;
confidence: number; confidence: number;
score: number; score: number;
modelVersion?: string; modelVersion?: string;
} }
export interface ClassificationThresholds { export interface ClassificationThresholds {
strict: number; // 0.3 - flag more aggressively strict: number; // 0.3 - flag more aggressively
moderate: number; // 0.5 - balanced moderate: number; // 0.5 - balanced
lenient: number; // 0.7 - fewer false positives lenient: number; // 0.7 - fewer false positives
} }
export type ThresholdMode = "strict" | "moderate" | "lenient"; export type ThresholdMode = "strict" | "moderate" | "lenient";
const DEFAULT_THRESHOLDS: ClassificationThresholds = { const DEFAULT_THRESHOLDS: ClassificationThresholds = {
strict: 0.3, strict: 0.3,
moderate: 0.5, moderate: 0.5,
lenient: 0.7, lenient: 0.7,
}; };
// ── Model Singleton ──────────────────────────────────────────────────────── // ── Model Singleton ────────────────────────────────────────────────────────
interface ModelState { interface ModelState {
session: InferenceSession | null; session: InferenceSession | null;
tokenizer: BertTokenizer; tokenizer: BertTokenizer;
metadata: ModelMetadata; metadata: ModelMetadata;
loaded: boolean; loaded: boolean;
loadError: Error | null; loadError: Error | null;
} }
interface ModelMetadata { interface ModelMetadata {
version: string; version: string;
model_name: string; model_name: string;
task: string; task: string;
max_length: number; max_length: number;
num_labels: number; num_labels: number;
label2id: Record<string, number>; label2id: Record<string, number>;
id2label: Record<number, string>; id2label: Record<number, string>;
} }
const modelState: ModelState = { const modelState: ModelState = {
session: null, session: null,
tokenizer: null as unknown as BertTokenizer, tokenizer: null as unknown as BertTokenizer,
metadata: { metadata: {
version: "0.0.0", version: "0.0.0",
model_name: "", model_name: "",
task: "", task: "",
max_length: 128, max_length: 128,
num_labels: 2, num_labels: 2,
label2id: {}, label2id: {},
id2label: {}, id2label: {},
}, },
loaded: false, loaded: false,
loadError: null, loadError: null,
}; };
// ── Result Cache ─────────────────────────────────────────────────────────── // ── Result Cache ───────────────────────────────────────────────────────────
interface CacheEntry { interface CacheEntry {
result: TextClassification; result: TextClassification;
timestamp: number; timestamp: number;
} }
const resultCache = new Map<string, CacheEntry>(); const resultCache = new Map<string, CacheEntry>();
@@ -85,305 +85,471 @@ const CACHE_MAX_SIZE = 1000;
const CACHE_TTL_MS = 5 * 60 * 1000; // 5 minutes const CACHE_TTL_MS = 5 * 60 * 1000; // 5 minutes
function cacheKey(text: string): string { function cacheKey(text: string): string {
// Simple hash of normalized text // Simple hash of normalized text
const normalized = text.toLowerCase().trim(); const normalized = text.toLowerCase().trim();
let hash = 0; let hash = 0;
for (let i = 0; i < normalized.length; i++) { for (let i = 0; i < normalized.length; i++) {
const char = normalized.charCodeAt(i); const char = normalized.charCodeAt(i);
hash = ((hash << 5) - hash) + char; hash = (hash << 5) - hash + char;
hash |= 0; // Convert to 32bit integer hash |= 0; // Convert to 32bit integer
} }
return String(hash); return String(hash);
} }
function getCached(text: string): TextClassification | null { function getCached(text: string): TextClassification | null {
const key = cacheKey(text); const key = cacheKey(text);
const entry = resultCache.get(key); const entry = resultCache.get(key);
if (!entry) return null; if (!entry) return null;
if (Date.now() - entry.timestamp > CACHE_TTL_MS) { if (Date.now() - entry.timestamp > CACHE_TTL_MS) {
resultCache.delete(key); resultCache.delete(key);
return null; return null;
} }
return entry.result; return entry.result;
} }
function setCache(text: string, result: TextClassification): void { function setCache(text: string, result: TextClassification): void {
if (resultCache.size >= CACHE_MAX_SIZE) { if (resultCache.size >= CACHE_MAX_SIZE) {
// Evict oldest entry // Evict oldest entry
const oldestKey = resultCache.keys().next().value; const oldestKey = resultCache.keys().next().value;
resultCache.delete(oldestKey); resultCache.delete(oldestKey);
} }
resultCache.set(cacheKey(text), { result, timestamp: Date.now() }); resultCache.set(cacheKey(text), { result, timestamp: Date.now() });
} }
// ── BertTokenizer (JavaScript implementation) ────────────────────────────── // ── BertTokenizer (JavaScript implementation) ──────────────────────────────
interface TokenizerConfig { interface TokenizerConfig {
vocab: Map<string, number>; vocab: Map<string, number>;
inv_vocab: Map<number, string>; inv_vocab: Map<number, string>;
max_len: number; max_len: number;
do_lower_case: boolean; do_lower_case: boolean;
tokenizers: Record<string, unknown>; tokenizers: Record<string, unknown>;
model_max_length: number; model_max_length: number;
} }
class BertTokenizer { class BertTokenizer {
private config: TokenizerConfig; private config: TokenizerConfig;
constructor(configPath: string) { constructor(configPath: string) {
this.config = this.loadConfig(configPath); this.config = this.loadConfig(configPath);
} }
private loadConfig(configPath: string): TokenizerConfig { private loadConfig(configPath: string): TokenizerConfig {
const vocabPath = path.join(configPath, "vocab.txt"); const vocabPath = path.join(configPath, "vocab.txt");
const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json"); const tokenizerConfigPath = path.join(configPath, "tokenizer_config.json");
// Load vocabulary // Load vocabulary
const vocab = new Map<string, number>(); const vocab = new Map<string, number>();
const inv_vocab = new Map<number, string>(); const inv_vocab = new Map<number, string>();
const vocabText = fs.readFileSync(vocabPath, "utf-8"); const vocabText = fs.readFileSync(vocabPath, "utf-8");
const lines = vocabText.split("\n"); const lines = vocabText.split("\n");
for (let i = 0; i < lines.length; i++) { for (let i = 0; i < lines.length; i++) {
const token = lines[i].trim(); const token = lines[i].trim();
if (token) { if (token) {
vocab.set(token, i); vocab.set(token, i);
inv_vocab.set(i, token); inv_vocab.set(i, token);
} }
} }
// Load tokenizer config // Load tokenizer config
let doLowercase = true; let doLowercase = true;
let modelMaxLength = 512; let modelMaxLength = 512;
try { try {
const configData = JSON.parse(fs.readFileSync(tokenizerConfigPath, "utf-8")); const configData = JSON.parse(
doLowercase = configData.do_lower_case ?? true; fs.readFileSync(tokenizerConfigPath, "utf-8"),
modelMaxLength = configData.model_max_length ?? 512; );
} catch { doLowercase = configData.do_lower_case ?? true;
// Use defaults modelMaxLength = configData.model_max_length ?? 512;
} } catch {
// Use defaults
}
return { return {
vocab, vocab,
inv_vocab, inv_vocab,
max_len: 512, max_len: 512,
do_lower_case: doLowercase, do_lower_case: doLowercase,
tokenizers: {}, tokenizers: {},
model_max_length: modelMaxLength, model_max_length: modelMaxLength,
}; };
} }
private whitespace_tokenize(text: string): string[] { private whitespace_tokenize(text: string): string[] {
if (this.config.do_lower_case) { if (this.config.do_lower_case) {
text = text.toLowerCase(); text = text.toLowerCase();
} }
// Split on whitespace, keeping punctuation attached // Split on whitespace, keeping punctuation attached
return text.split(/\s+/).filter((t) => t.length > 0); return text.split(/\s+/).filter((t) => t.length > 0);
} }
private wordpiece_tokenize(token: string, maxOutputTokens: number = 20): string[] { private wordpiece_tokenize(
const outputTokens: string[] = []; token: string,
let isBad = false; maxOutputTokens: number = 20,
let start = 0; ): string[] {
let subToken: string | null = null; const outputTokens: string[] = [];
let isBad = false;
let start = 0;
let subToken: string | null = null;
while (start < token.length && !isBad && outputTokens.length < maxOutputTokens) { while (
let found = false; start < token.length &&
!isBad &&
outputTokens.length < maxOutputTokens
) {
let found = false;
for (let end = token.length; end > start; end--) { for (let end = token.length; end > start; end--) {
let substr = token.substring(start, end); let substr = token.substring(start, end);
if (start > 0) { if (start > 0) {
substr = "##" + substr; substr = "##" + substr;
} }
if (this.config.vocab.has(substr)) { if (this.config.vocab.has(substr)) {
outputTokens.push(substr); outputTokens.push(substr);
subToken = substr; subToken = substr;
start = end; start = end;
found = true; found = true;
break; break;
} }
} }
if (!found) { if (!found) {
isBad = true; isBad = true;
} }
} }
if (isBad) { if (isBad) {
outputTokens.push("[UNK]"); outputTokens.push("[UNK]");
} else if (subToken === null) { } else if (subToken === null) {
outputTokens.push("[UNK]"); outputTokens.push("[UNK]");
} }
return outputTokens; return outputTokens;
} }
private tokenize(text: string): string[] { private tokenize(text: string): string[] {
const tokens = []; const tokens = [];
const whitespaceTokens = this.whitespace_tokenize(text); const whitespaceTokens = this.whitespace_tokenize(text);
for (const token of whitespaceTokens) { for (const token of whitespaceTokens) {
const subTokens = this.wordpiece_tokenize(token); const subTokens = this.wordpiece_tokenize(token);
tokens.push(...subTokens); tokens.push(...subTokens);
} }
return tokens; return tokens;
} }
encode(text: string, maxLen: number = 128): { inputIds: number[]; attentionMask: number[] } { encode(
const tokens = this.tokenize(text); text: string,
maxLen: number = 128,
): { inputIds: number[]; attentionMask: number[] } {
const tokens = this.tokenize(text);
// Add [CLS] and [SEP] // Add [CLS] and [SEP]
const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"]; const allTokens = ["[CLS]", ...tokens.slice(0, maxLen - 2), "[SEP]"];
const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK] const inputIds = allTokens.map((t) => this.config.vocab.get(t) ?? 100); // 100 = [UNK]
const attentionMask = new Array(inputIds.length).fill(1); const attentionMask = new Array(inputIds.length).fill(1);
// Pad to maxLen if needed // Pad to maxLen if needed
while (inputIds.length < maxLen) { while (inputIds.length < maxLen) {
inputIds.push(0); inputIds.push(0);
attentionMask.push(0); attentionMask.push(0);
} }
return { inputIds, attentionMask }; return { inputIds, attentionMask };
} }
} }
// ── Model Loading ────────────────────────────────────────────────────────── // ── Model Loading ──────────────────────────────────────────────────────────
const MODEL_DIR_ENV = "SPAM_MODEL_DIR"; const MODEL_DIR_ENV = "SPAM_MODEL_DIR";
const DEFAULT_MODEL_DIR = path.join(__dirname, "..", "..", "models", "spam-classifier"); const DEFAULT_MODEL_DIR = path.join(
__dirname,
"..",
"..",
"models",
"spam-classifier",
);
function getModelDir(): string { function getModelDir(): string {
return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR; return process.env[MODEL_DIR_ENV] || DEFAULT_MODEL_DIR;
}
// ── Remote Model Download ────────────────────────────────────────────────────
const MODEL_DOWNLOAD_URL_ENV = "SPAM_MODEL_URL_BASE";
/** Model files that need to be available in the model directory. */
const MODEL_FILES = [
"model.onnx",
"model.onnx.data",
"tokenizer.json",
"vocab.txt",
"tokenizer_config.json",
"special_tokens_map.json",
"model_metadata.json",
] as const;
/**
* Check if all required model files exist in the given directory.
*/
function modelFilesExist(dir: string): boolean {
try {
return MODEL_FILES.every((f) => fs.existsSync(path.join(dir, f)));
} catch {
return false;
}
}
/**
* Download a single model file from a remote URL to a local path.
* Uses streaming to handle large files (e.g., model.onnx.data at 255MB).
*/
async function downloadModelFile(url: string, destPath: string): Promise<void> {
const response = await fetch(url);
if (!response.ok) {
throw new Error(
`Failed to download ${url}: ${response.status} ${response.statusText}`,
);
}
const reader = response.body?.getReader();
if (!reader) {
throw new Error(`No response body stream for ${url}`);
}
// Ensure parent directory exists
const dir = path.dirname(destPath);
fs.mkdirSync(dir, { recursive: true });
// Stream to file
const writer = fs.createWriteStream(destPath);
try {
let totalBytes = 0;
let lastLog = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
writer.write(value);
totalBytes += value.length;
// Log progress every ~10MB
if (totalBytes - lastLog > 10 * 1024 * 1024) {
lastLog = totalBytes;
const mb = (totalBytes / (1024 * 1024)).toFixed(1);
console.log(
`[spamshield] Downloaded ${path.basename(destPath)}: ${mb}MB`,
);
}
}
} finally {
writer.end();
await new Promise<void>((resolve) => writer.on("finish", resolve));
}
const totalMB = (fs.statSync(destPath).size / (1024 * 1024)).toFixed(1);
console.log(
`[spamshield] Downloaded ${path.basename(destPath)} (${totalMB}MB)`,
);
}
/**
* Download all model files from a remote URL base to the model directory.
* Falls back gracefully — if the URL is not configured, returns false.
*/
async function downloadModelIfMissing(modelDir: string): Promise<boolean> {
// If model files already exist locally, nothing to do
if (modelFilesExist(modelDir)) {
return true;
}
const baseUrl = process.env[MODEL_DOWNLOAD_URL_ENV];
if (!baseUrl) {
console.log(
"[spamshield] Model files not found locally and SPAM_MODEL_URL_BASE not set — " +
"will use rule-engine fallback",
);
return false;
}
const normalizedBase = baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`;
console.log(`[spamshield] Downloading model from: ${normalizedBase}`);
// Ensure model directory exists
fs.mkdirSync(modelDir, { recursive: true });
// Track which files we already have (for caching across cold starts)
const existing = new Set<string>();
for (const file of MODEL_FILES) {
const filePath = path.join(modelDir, file);
if (fs.existsSync(filePath) && fs.statSync(filePath).size > 0) {
existing.add(file);
}
}
// Download missing files
for (const file of MODEL_FILES) {
if (existing.has(file)) {
console.log(`[spamshield] Already have ${file}, skipping download`);
continue;
}
const url = `${normalizedBase}${file}`;
const destPath = path.join(modelDir, file);
console.log(`[spamshield] Downloading ${file}...`);
try {
await downloadModelFile(url, destPath);
} catch (err) {
console.error(`[spamshield] Failed to download ${file}:`, err);
// If the main model files fail, we can't use the model
if (file === "model.onnx" || file === "model.onnx.data") {
throw err;
}
}
}
return modelFilesExist(modelDir);
} }
async function loadModel(): Promise<void> { async function loadModel(): Promise<void> {
if (modelState.loaded) return; if (modelState.loaded) return;
try { try {
const modelDir = getModelDir(); const modelDir = getModelDir();
console.log(`[spamshield] Loading ONNX model from: ${modelDir}`); console.log(`[spamshield] Loading ONNX model from: ${modelDir}`);
// Load metadata // Download model files if missing (production/Vercel path)
const metadataPath = path.join(modelDir, "model_metadata.json"); await downloadModelIfMissing(modelDir);
if (fs.existsSync(metadataPath)) {
modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8"));
console.log(`[spamshield] Model version: ${modelState.metadata.version}`);
}
// Load tokenizer // Load metadata
modelState.tokenizer = new BertTokenizer(modelDir); const metadataPath = path.join(modelDir, "model_metadata.json");
console.log("[spamshield] Tokenizer loaded"); if (fs.existsSync(metadataPath)) {
modelState.metadata = JSON.parse(fs.readFileSync(metadataPath, "utf-8"));
console.log(`[spamshield] Model version: ${modelState.metadata.version}`);
}
// Load ONNX model // Load tokenizer
const modelPath = path.join(modelDir, "model.onnx"); modelState.tokenizer = new BertTokenizer(modelDir);
if (!fs.existsSync(modelPath)) { console.log("[spamshield] Tokenizer loaded");
// Check for external data file
const modelDataPath = path.join(modelDir, "model.onnx.data");
if (!fs.existsSync(modelDataPath)) {
throw new Error(`ONNX model not found at ${modelPath}`);
}
}
modelState.session = await ort.InferenceSession.create(modelPath); // Load ONNX model
console.log("[spamshield] ONNX session created"); const modelPath = path.join(modelDir, "model.onnx");
console.log(`[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`); if (!fs.existsSync(modelPath)) {
console.log(`[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`); // Check for external data file
const modelDataPath = path.join(modelDir, "model.onnx.data");
if (!fs.existsSync(modelDataPath)) {
throw new Error(`ONNX model not found at ${modelPath}`);
}
}
modelState.loaded = true; modelState.session = await ort.InferenceSession.create(modelPath);
console.log("[spamshield] Model loaded successfully"); console.log("[spamshield] ONNX session created");
} catch (err) { console.log(
modelState.loadError = err instanceof Error ? err : new Error(String(err)); `[spamshield] Inputs: ${modelState.session.inputNames.join(", ")}`,
console.error("[spamshield] Failed to load ONNX model:", modelState.loadError); );
console.log("[spamshield] Falling back to rule engine for classification"); console.log(
} `[spamshield] Outputs: ${modelState.session.outputNames.join(", ")}`,
);
modelState.loaded = true;
console.log("[spamshield] Model loaded successfully");
} catch (err) {
modelState.loadError = err instanceof Error ? err : new Error(String(err));
console.error(
"[spamshield] Failed to load ONNX model:",
modelState.loadError,
);
console.log("[spamshield] Falling back to rule engine for classification");
}
} }
// ── Inference ────────────────────────────────────────────────────────────── // ── Inference ──────────────────────────────────────────────────────────────
function sigmoid(x: number): number { function sigmoid(x: number): number {
return 1 / (1 + Math.exp(-x)); return 1 / (1 + Math.exp(-x));
} }
async function runInference( async function runInference(
text: string, text: string,
thresholdMode: ThresholdMode = "moderate", thresholdMode: ThresholdMode = "moderate",
): Promise<TextClassification> { ): Promise<TextClassification> {
const thresholds = DEFAULT_THRESHOLDS; const thresholds = DEFAULT_THRESHOLDS;
const threshold = thresholds[thresholdMode]; const threshold = thresholds[thresholdMode];
// Check cache first // Check cache first
const cached = getCached(text); const cached = getCached(text);
if (cached) { if (cached) {
return { ...cached, modelVersion: modelState.metadata.version }; return { ...cached, modelVersion: modelState.metadata.version };
} }
// Ensure model is loaded // Ensure model is loaded
if (!modelState.loaded || !modelState.session) { if (!modelState.loaded || !modelState.session) {
await loadModel(); await loadModel();
} }
// If model still not loaded, return fallback // If model still not loaded, return fallback
if (!modelState.loaded || !modelState.session) { if (!modelState.loaded || !modelState.session) {
const fallback: TextClassification = { const fallback: TextClassification = {
isSpam: false, isSpam: false,
confidence: 0, confidence: 0,
score: 0, score: 0,
modelVersion: "fallback", modelVersion: "fallback",
}; };
setCache(text, fallback); setCache(text, fallback);
return fallback; return fallback;
} }
// Tokenize // Tokenize
const maxLen = modelState.metadata.max_length || 128; const maxLen = modelState.metadata.max_length || 128;
const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen); const { inputIds, attentionMask } = modelState.tokenizer.encode(text, maxLen);
// Create ONNX tensors (int64 requires BigInt values) // Create ONNX tensors (int64 requires BigInt values)
const inputIdsBigInt = new BigInt64Array(inputIds.length); const inputIdsBigInt = new BigInt64Array(inputIds.length);
for (let i = 0; i < inputIds.length; i++) { for (let i = 0; i < inputIds.length; i++) {
inputIdsBigInt[i] = BigInt(inputIds[i]); inputIdsBigInt[i] = BigInt(inputIds[i]);
} }
const attentionMaskBigInt = new BigInt64Array(attentionMask.length); const attentionMaskBigInt = new BigInt64Array(attentionMask.length);
for (let i = 0; i < attentionMask.length; i++) { for (let i = 0; i < attentionMask.length; i++) {
attentionMaskBigInt[i] = BigInt(attentionMask[i]); attentionMaskBigInt[i] = BigInt(attentionMask[i]);
} }
const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]); const inputIdsTensor = new ort.Tensor("int64", inputIdsBigInt, [1, maxLen]);
const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [1, maxLen]); const attentionMaskTensor = new ort.Tensor("int64", attentionMaskBigInt, [
1,
maxLen,
]);
// Run inference // Run inference
const feeds: Record<string, Tensor> = { const feeds: Record<string, Tensor> = {
input_ids: inputIdsTensor, input_ids: inputIdsTensor,
attention_mask: attentionMaskTensor, attention_mask: attentionMaskTensor,
}; };
const outputs = await modelState.session.run(feeds); const outputs = await modelState.session.run(feeds);
const logits = outputs[modelState.session.outputNames[0]]; const logits = outputs[modelState.session.outputNames[0]];
// Extract logits (shape: [1, num_labels]) // Extract logits (shape: [1, num_labels])
const logitsData = logits.data as Float32Array | number[]; const logitsData = logits.data as Float32Array | number[];
const spamLogit = logitsData[1] ?? 0; const spamLogit = logitsData[1] ?? 0;
const hamLogit = logitsData[0] ?? 0; const hamLogit = logitsData[0] ?? 0;
// Apply sigmoid to get probability // Apply sigmoid to get probability
const spamProb = sigmoid(spamLogit); const spamProb = sigmoid(spamLogit);
const hamProb = sigmoid(hamLogit); const hamProb = sigmoid(hamLogit);
// Binary decision based on threshold // Binary decision based on threshold
const isSpam = spamProb >= threshold; const isSpam = spamProb >= threshold;
const confidence = isSpam ? spamProb : 1 - spamProb; const confidence = isSpam ? spamProb : 1 - spamProb;
const result: TextClassification = { const result: TextClassification = {
isSpam, isSpam,
confidence: Math.round(confidence * 10000) / 10000, confidence: Math.round(confidence * 10000) / 10000,
score: Math.round(spamProb * 10000) / 10000, score: Math.round(spamProb * 10000) / 10000,
modelVersion: modelState.metadata.version, modelVersion: modelState.metadata.version,
}; };
setCache(text, result); setCache(text, result);
return result; return result;
} }
// ── Public API ───────────────────────────────────────────────────────────── // ── Public API ─────────────────────────────────────────────────────────────
@@ -393,21 +559,21 @@ async function runInference(
* Falls back to returning a safe default if the model fails to load. * Falls back to returning a safe default if the model fails to load.
*/ */
export async function classifyTextBERT( export async function classifyTextBERT(
text: string, text: string,
thresholdMode: ThresholdMode = "moderate", thresholdMode: ThresholdMode = "moderate",
): Promise<TextClassification> { ): Promise<TextClassification> {
try { try {
return await runInference(text, thresholdMode); return await runInference(text, thresholdMode);
} catch (err) { } catch (err) {
console.error("[spamshield] ONNX inference error:", err); console.error("[spamshield] ONNX inference error:", err);
// Graceful fallback: return non-spam with low confidence // Graceful fallback: return non-spam with low confidence
return { return {
isSpam: false, isSpam: false,
confidence: 0, confidence: 0,
score: 0, score: 0,
modelVersion: "error", modelVersion: "error",
}; };
} }
} }
/** /**
@@ -415,41 +581,41 @@ export async function classifyTextBERT(
* Call this once during server initialization. * Call this once during server initialization.
*/ */
export async function initSpamModel(): Promise<boolean> { export async function initSpamModel(): Promise<boolean> {
await loadModel(); await loadModel();
return modelState.loaded; return modelState.loaded;
} }
/** /**
* Check if the model is loaded and ready. * Check if the model is loaded and ready.
*/ */
export function isModelLoaded(): boolean { export function isModelLoaded(): boolean {
return modelState.loaded && modelState.session !== null; return modelState.loaded && modelState.session !== null;
} }
/** /**
* Get model metadata. * Get model metadata.
*/ */
export function getModelInfo(): ModelMetadata { export function getModelInfo(): ModelMetadata {
return { ...modelState.metadata }; return { ...modelState.metadata };
} }
/** /**
* Get the current cache stats. * Get the current cache stats.
*/ */
export function getCacheStats(): { size: number; max: number } { export function getCacheStats(): { size: number; max: number } {
return { size: resultCache.size, max: CACHE_MAX_SIZE }; return { size: resultCache.size, max: CACHE_MAX_SIZE };
} }
/** /**
* Clear the result cache. * Clear the result cache.
*/ */
export function clearCache(): void { export function clearCache(): void {
resultCache.clear(); resultCache.clear();
} }
/** /**
* Get available threshold modes and their values. * Get available threshold modes and their values.
*/ */
export function getThresholds(): ClassificationThresholds { export function getThresholds(): ClassificationThresholds {
return { ...DEFAULT_THRESHOLDS }; return { ...DEFAULT_THRESHOLDS };
} }

8
web/vercel.json Normal file
View File

@@ -0,0 +1,8 @@
{
"$schema": "https://openapi.vercel.sh/vercel.json",
"framework": "solidstart",
"buildCommand": "npm run build",
"installCommand": "npm install",
"outputDirectory": ".output/public",
"regions": ["iad1"]
}