# SMS Spam Classifier - DistilBERT ONNX Pipeline

## Overview

This directory contains the training pipeline for a fine-tuned DistilBERT SMS spam classifier, exported to ONNX format for fast CPU inference in the Kordant spamshield service.

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                      Training (Python)                      │
│  ┌─────────────┐   ┌──────────────┐   ┌───────────────────┐ │
│  │ SMS Spam    │   │ DistilBERT   │   │ ONNX Export       │ │
│  │ Collection  │ → │ Fine-tuning  │ → │ + INT8 Quantize   │ │
│  │ Dataset     │   │ (3 epochs)   │   │                   │ │
│  └─────────────┘   └──────────────┘   └───────────────────┘ │
└─────────────────────────────────────────────────────────────┘
                               ↓
┌─────────────────────────────────────────────────────────────┐
│                     Inference (Node.js)                     │
│  ┌─────────────┐   ┌──────────────┐   ┌───────────────────┐ │
│  │ Text Input  │   │ BertTokenizer│   │ ONNX Runtime      │ │
│  │             │ → │ (JS impl)    │ → │ (onnxruntime-node)│ │
│  └─────────────┘   └──────────────┘   └───────────────────┘ │
│                                                 ↓           │
│  ┌─────────────────────────────────────────────────────┐    │
│  │ Result: { isSpam, confidence, score, modelVersion } │    │
│  └─────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────┘
```

## Training

### Prerequisites

- Python 3.10+ (tested with 3.14)
- Virtual environment: `python3 -m venv .venv-ml && source .venv-ml/bin/activate`
- Dependencies: `pip install transformers datasets torch "optimum[onnxruntime]" onnxruntime` (the quotes stop shells such as zsh from expanding the brackets)

### Run Training

```bash
# Full pipeline with data augmentation
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --epochs 3 \
  --batch-size 32

# Skip training (use an existing model for export only)
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --skip-training

# With INT8 quantization
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --quantize
```

### Training Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--epochs` | 3 | Number of training epochs |
| `--batch-size` | 32 | Training batch size |
| `--lr` | 2e-5 | Learning rate |
| `--augment` | false | Include synthetic data augmentation |
| `--quantize` | false | Quantize to INT8 for a smaller model |
| `--skip-training` | false | Skip training and only run the ONNX export |

### Expected Results

- **Validation accuracy**: 97-99% on the SMS Spam Collection
- **Training time**: ~5-10 minutes on an M-series Mac, ~2-3 minutes on a GPU
- **Model size**: ~257MB (FP32), ~128MB (INT8 quantized)

## Inference

### Setup

The trained ONNX model is deployed to `web/src/server/models/spam-classifier/`. The Node.js inference wrapper is in `web/src/server/services/spamshield/onnx.inference.ts`.

### Usage

```typescript
import { classifyTextBERT, initSpamModel } from "~/server/services/spamshield/onnx.inference";

// Initialize at startup (loads the model once)
await initSpamModel();

// Classify text with the default ("moderate") threshold mode
const result = await classifyTextBERT("Hello world");
// { isSpam: false, confidence: 0.97, score: 0.03, modelVersion: "1.0.0" }

// Classify with an explicit threshold mode
const strict = await classifyTextBERT("Get free money!", "strict");
const lenient = await classifyTextBERT("Get free money!", "lenient");
```

### Threshold Modes

| Mode | Spam-score threshold | Description |
|------|----------------------|-------------|
| `strict` | 0.3 | Aggressive spam detection; more false positives |
| `moderate` | 0.5 | Balanced (default) |
| `lenient` | 0.7 | Conservative; fewer false positives |

### Performance

| Metric | Value |
|--------|-------|
| p50 latency | <5ms |
| p95 latency | <30ms |
| p99 latency | <50ms |
| Throughput | ~240 inferences/sec |
| Model load time | ~2s (cold start) |

### Benchmark

```bash
cd web
pnpm benchmark:spamshield --iterations 1000
```

## Model Versioning

The model version is tracked in `model_metadata.json` alongside the ONNX files. To deploy a new model:

1. Run training with new data or hyperparameters.
2. Update the version in `model_metadata.json`.
3. Copy the ONNX files from `ml/spam-classifier/output/onnx_model/` to `web/src/server/models/spam-classifier/`.
4. Deploy.

## Feedback Loop

User feedback is stored in the `spam_feedback` database table and can be used for periodic retraining. For example, to list all `user_rejection` feedback, newest first:

```sql
SELECT text, is_spam, feedback_type
FROM spam_feedback
WHERE feedback_type = 'user_rejection'
ORDER BY created_at DESC;
```

## Files

```
ml/spam-classifier/
├── train.py                    # Training pipeline
├── README.md                   # This file
└── output/                     # Training output (gitignored except onnx_model)
    ├── data/                   # Raw dataset (gitignored)
    ├── final_model/            # PyTorch model (gitignored)
    ├── best_model/             # Best checkpoint (gitignored)
    └── onnx_model/             # ONNX export (copied to web/)
        ├── model.onnx
        ├── model.onnx.data
        ├── vocab.txt
        ├── tokenizer.json
        ├── tokenizer_config.json
        └── model_metadata.json
```
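The inference stage in the architecture diagram relies on a JavaScript `BertTokenizer` that turns text into model inputs using `vocab.txt`. The following is a simplified, hypothetical sketch of the WordPiece algorithm that BERT-family tokenizers use (greedy longest-match-first, with `##` marking word-internal pieces). It is not the service's actual implementation. It assumes an uncased vocabulary, splits only on whitespace and punctuation, skips accent stripping and padding, and uses an illustrative `maxLen` of 128. The names `loadVocab`, `wordPiece`, and `encode` are made up for this example.

```typescript
// Hypothetical WordPiece sketch; not the service's BertTokenizer.
// Assumptions: uncased vocab (input is lowercased), splitting on whitespace and
// punctuation only, no accent stripping, no padding, illustrative maxLen = 128.

interface Encoded {
  inputIds: number[];
  attentionMask: number[];
}

// vocab.txt holds one token per line; a token's id is its line index.
function loadVocab(lines: string[]): Map<string, number> {
  return new Map(lines.map((tok, i): [string, number] => [tok, i]));
}

// Lowercase, then split into runs of ASCII letters/digits and single
// non-space characters (punctuation, and in this sketch any non-ASCII letter).
function basicSplit(text: string): string[] {
  return text.toLowerCase().match(/[a-z0-9]+|[^\sa-z0-9]/g) ?? [];
}

// Greedy longest-match-first: repeatedly take the longest vocab entry that
// matches the start of the remaining word; non-initial pieces carry "##".
function wordPiece(word: string, vocab: Map<string, number>): string[] {
  const pieces: string[] = [];
  let start = 0;
  while (start < word.length) {
    let end = word.length;
    let match: string | null = null;
    while (start < end) {
      const candidate = (start > 0 ? "##" : "") + word.slice(start, end);
      if (vocab.has(candidate)) {
        match = candidate;
        break;
      }
      end -= 1;
    }
    // If any part of the word cannot be matched, the whole word becomes [UNK].
    if (match === null) return ["[UNK]"];
    pieces.push(match);
    start = end;
  }
  return pieces;
}

// Wrap tokens in [CLS] ... [SEP], truncate to maxLen, and map tokens to ids.
function encode(text: string, vocab: Map<string, number>, maxLen = 128): Encoded {
  const unkId = vocab.get("[UNK]");
  if (unkId === undefined) throw new Error("vocab is missing [UNK]");
  const tokens: string[] = [];
  for (const word of basicSplit(text)) tokens.push(...wordPiece(word, vocab));
  const body = tokens.slice(0, maxLen - 2); // reserve room for [CLS] and [SEP]
  const inputIds = ["[CLS]", ...body, "[SEP]"].map((t) => vocab.get(t) ?? unkId);
  // No padding in this sketch, so every position is attended to.
  return { inputIds, attentionMask: inputIds.map(() => 1) };
}
```

With a toy vocabulary containing `un`, `##want`, and `##ed`, `wordPiece("unwanted", vocab)` yields `["un", "##want", "##ed"]`, and a word with no matching pieces collapses to a single `[UNK]`, as in BERT's reference tokenizer.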
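The `score`, `confidence`, and threshold-mode behavior described under Inference can be related in a small sketch. It is hypothetical and rests on three assumptions this README does not confirm: the model emits two logits ordered `[ham, spam]`, a message is flagged when `score >= threshold`, and `confidence` is the probability of the more likely class (which matches the sample output `score: 0.03, confidence: 0.97`). `toResult` and `THRESHOLDS` are illustrative names, not exports of `onnx.inference.ts`.

```typescript
// Hypothetical sketch relating logits, score, confidence, and threshold modes.
// Assumptions (not confirmed by this README): logits are ordered [ham, spam],
// a message is spam when score >= threshold, and confidence is the
// probability of the more likely class.

type ThresholdMode = "strict" | "moderate" | "lenient";

// Values from the Threshold Modes table.
const THRESHOLDS: Record<ThresholdMode, number> = {
  strict: 0.3,
  moderate: 0.5,
  lenient: 0.7,
};

interface SpamResult {
  isSpam: boolean;
  confidence: number;
  score: number;
  modelVersion: string;
}

// Numerically stable softmax: subtract the max logit before exponentiating.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

function toResult(
  logits: [number, number],
  mode: ThresholdMode = "moderate",
  modelVersion = "1.0.0", // placeholder; presumably read from model_metadata.json
): SpamResult {
  const score = softmax(logits)[1]; // P(spam), assuming index 1 is the spam class
  return {
    isSpam: score >= THRESHOLDS[mode],
    confidence: Math.max(score, 1 - score),
    score,
    modelVersion,
  };
}
```

For example, logits that give a spam probability of 0.4 are flagged in `strict` mode (threshold 0.3) but not in `moderate` or `lenient` mode, which is the trade-off the Threshold Modes table describes.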
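The latency percentiles in the Performance table can be computed from per-inference timings in more than one way. The sketch below uses the nearest-rank definition; it is an assumption for illustration, not a description of what `pnpm benchmark:spamshield` actually does, and `summarize` is a hypothetical name. As a rough cross-check, ~240 inferences/sec corresponds to about 1000 / 240 ≈ 4.2 ms per inference if requests run one at a time, which is consistent with the p50 figure.

```typescript
// Hypothetical benchmark summary; the real benchmark script may differ.
// Percentiles use the nearest-rank method on latencies in milliseconds.

interface LatencySummary {
  p50: number;
  p95: number;
  p99: number;
  throughputPerSec: number;
}

// Nearest rank: the smallest sample such that at least p% of samples are <= it.
function percentile(sortedMs: number[], p: number): number {
  if (sortedMs.length === 0) throw new Error("no samples");
  const rank = Math.ceil((p * sortedMs.length) / 100);
  return sortedMs[Math.max(rank, 1) - 1];
}

function summarize(latenciesMs: number[]): LatencySummary {
  const sorted = [...latenciesMs].sort((a, b) => a - b);
  const totalMs = sorted.reduce((a, b) => a + b, 0);
  return {
    p50: percentile(sorted, 50),
    p95: percentile(sorted, 95),
    p99: percentile(sorted, 99),
    // Assumes inferences ran sequentially, so throughput = count / total time.
    throughputPerSec: (sorted.length * 1000) / totalMs,
  };
}
```

Nearest-rank always returns an observed sample rather than an interpolated value, which keeps reported tail latencies tied to real measurements.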