# SMS Spam Classifier - DistilBERT ONNX Pipeline
## Overview
This directory contains the training pipeline for a fine-tuned DistilBERT SMS spam classifier, exported to ONNX format for fast CPU inference in the Kordant spamshield service.
## Architecture
```
┌─────────────────────────────────────────────────────────────┐
│                      Training (Python)                      │
│  ┌─────────────┐   ┌──────────────┐   ┌───────────────────┐ │
│  │ SMS Spam    │   │ DistilBERT   │   │ ONNX Export       │ │
│  │ Collection  │→  │ Fine-tuning  │→  │ + INT8 Quantize   │ │
│  │ Dataset     │   │ (3 epochs)   │   │                   │ │
│  └─────────────┘   └──────────────┘   └───────────────────┘ │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│                     Inference (Node.js)                     │
│  ┌─────────────┐   ┌──────────────┐   ┌───────────────────┐ │
│  │ Text Input  │   │ BertTokenizer│   │ ONNX Runtime      │ │
│  │             │→  │ (JS impl)    │→  │ (onnxruntime-node)│ │
│  └─────────────┘   └──────────────┘   └───────────────────┘ │
│                              ↓                              │
│   ┌─────────────────────────────────────────────────────┐   │
│   │ Result: { isSpam, confidence, score, modelVersion } │   │
│   └─────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘
```
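The last inference step in the diagram, turning the ONNX Runtime output into a `score`, can be sketched as a two-class softmax. This is a minimal sketch, not the code in `onnx.inference.ts`: the function name `spamProbability` is hypothetical, and it assumes the model emits two logits ordered `[ham, spam]`, which this README does not confirm.

```typescript
// Sketch only: assumes the model emits two logits ordered [ham, spam].
// The label order and the function name are assumptions, not taken from the repo.
function spamProbability(logits: readonly [number, number]): number {
  const [ham, spam] = logits;
  // Subtract the larger logit before exponentiating for numerical stability.
  const max = Math.max(ham, spam);
  const expHam = Math.exp(ham - max);
  const expSpam = Math.exp(spam - max);
  // Softmax probability of the (assumed) spam class.
  return expSpam / (expHam + expSpam);
}
```

Subtracting the maximum keeps `Math.exp` from overflowing on large logits without changing the resulting probability.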
## Training
### Prerequisites
- Python 3.10+ (tested with 3.14)
- Virtual environment: `python3 -m venv .venv-ml && source .venv-ml/bin/activate`
- Dependencies: `pip install transformers datasets torch "optimum[onnxruntime]" onnxruntime` (the quotes stop shells such as zsh from treating the brackets as a glob pattern)
### Run Training
```bash
# Full pipeline with data augmentation
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --epochs 3 \
  --batch-size 32

# Skip training (use existing model for export only)
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --skip-training

# With INT8 quantization
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --quantize
```
### Training Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--epochs` | 3 | Number of training epochs |
| `--batch-size` | 32 | Training batch size |
| `--lr` | 2e-5 | Learning rate |
| `--augment` | false | Include synthetic data augmentation |
| `--quantize` | false | Quantize to INT8 for a smaller model |
| `--skip-training` | false | Skip training, only export |
### Expected Results
- **Validation accuracy**: 97-99% on SMS Spam Collection
- **Training time**: ~5-10 minutes on M-series Mac, ~2-3 minutes on GPU
- **Model size**: ~257MB (FP32), ~128MB (INT8 quantized)
## Inference
### Setup
The trained ONNX model is deployed to `web/src/server/models/spam-classifier/`.
The Node.js inference wrapper is in `web/src/server/services/spamshield/onnx.inference.ts`.
### Usage
```typescript
import { classifyTextBERT, initSpamModel } from "~/server/services/spamshield/onnx.inference";

// Initialize at startup (loads model once)
await initSpamModel();

// Classify text
const result = await classifyTextBERT("Hello world");
// { isSpam: false, confidence: 0.97, score: 0.03, modelVersion: "1.0.0" }

// With threshold mode
const strict = await classifyTextBERT("Get free money!", "strict");
const lenient = await classifyTextBERT("Get free money!", "lenient");
```
### Threshold Modes
| Mode | Threshold | Description |
|------|-----------|-------------|
| `strict` | 0.3 | Aggressive spam detection, more false positives |
| `moderate` | 0.5 | Balanced (default) |
| `lenient` | 0.7 | Conservative, fewer false positives |
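
To make the table concrete, the following is a minimal sketch of how a mode might map a spam `score` to a decision. The threshold values come from the table; everything else is an assumption: `applyThreshold` is a hypothetical helper, the comparison is assumed to be inclusive (`>=`), and `confidence` is assumed to be the probability of the chosen label, which matches the `score: 0.03` / `confidence: 0.97` pair in the usage example but is not confirmed by the wrapper's source.

```typescript
type ThresholdMode = "strict" | "moderate" | "lenient";

// Threshold values copied from the table above.
const THRESHOLDS: Record<ThresholdMode, number> = {
  strict: 0.3,
  moderate: 0.5,
  lenient: 0.7,
};

// Hypothetical helper: the real wrapper may compute these fields differently.
function applyThreshold(
  score: number,
  mode: ThresholdMode = "moderate",
): { isSpam: boolean; confidence: number; score: number } {
  // Assumption: a score exactly at the threshold counts as spam.
  const isSpam = score >= THRESHOLDS[mode];
  // Assumption: confidence is the probability of the chosen label.
  const confidence = isSpam ? score : 1 - score;
  return { isSpam, confidence, score };
}
```

Under these assumptions a message with `score` 0.4 is flagged only in `strict` mode, which is where the extra false positives in that mode come from.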
### Performance
| Metric | Value |
|--------|-------|
| p50 latency | <5ms |
| p95 latency | <30ms |
| p99 latency | <50ms |
| Throughput | ~240 inferences/sec |
| Model load time | ~2s (cold start) |
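
For readers reproducing these figures, one common way to derive p50/p95/p99 from raw timings is the nearest-rank method, sketched below. This is illustrative only: `percentile` is a hypothetical helper, and the project's benchmark script may use a different definition (for example, linear interpolation).

```typescript
// Nearest-rank percentile over latency samples in milliseconds.
// Illustrative only; the project's benchmark script may use another method.
function percentile(samplesMs: readonly number[], p: number): number {
  if (samplesMs.length === 0) throw new Error("no samples");
  if (p <= 0 || p > 100) throw new Error("p must be in (0, 100]");
  // Copy before sorting so the caller's array is left untouched.
  const sorted = [...samplesMs].sort((a, b) => a - b);
  // 1-based rank of the smallest sample that covers p percent of the data.
  const rank = Math.ceil((p * sorted.length) / 100);
  return sorted[rank - 1] as number;
}
```

With 1000 iterations, `percentile(samples, 99)` returns the 990th smallest timing, so the p99 figure is determined by the slowest 1% of runs.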
### Benchmark
```bash
cd web
pnpm benchmark:spamshield --iterations 1000
```
## Model Versioning
The model version is tracked in `model_metadata.json` alongside the ONNX files.
To deploy a new model:
1. Run training with new data/hyperparameters
2. Update version in `model_metadata.json`
3. Copy ONNX files to `web/src/server/models/spam-classifier/`
4. Deploy
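
Step 2 is easy to forget, so a pre-deploy check along the following lines may help. It is a sketch under assumptions: the layout of `model_metadata.json` is not documented here, so the helpers only compare two version strings, assumed to follow the `MAJOR.MINOR.PATCH` pattern of the `modelVersion: "1.0.0"` value shown in the usage example. Both function names are hypothetical.

```typescript
// Hypothetical pre-deploy check. Assumes versions look like "1.0.0"
// (MAJOR.MINOR.PATCH); the real metadata format may differ.
function parseVersion(version: string): [number, number, number] {
  const match = /^(\d+)\.(\d+)\.(\d+)$/.exec(version);
  if (!match) throw new Error(`not a MAJOR.MINOR.PATCH version: ${version}`);
  return [Number(match[1]), Number(match[2]), Number(match[3])];
}

// Returns true only when `candidate` is strictly newer than `deployed`.
function isNewerVersion(candidate: string, deployed: string): boolean {
  const a = parseVersion(candidate);
  const b = parseVersion(deployed);
  for (let i = 0; i < 3; i++) {
    // Compare numerically, so that 1.0.10 is newer than 1.0.9.
    if (a[i] !== b[i]) return (a[i] as number) > (b[i] as number);
  }
  return false; // identical versions are not "newer"
}
```

Comparing the parts as numbers rather than strings avoids the lexicographic trap where `"1.0.10"` sorts before `"1.0.9"`.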
## Feedback Loop
User feedback is stored in the `spam_feedback` database table and can be used for periodic retraining:
```sql
SELECT text, is_spam, feedback_type FROM spam_feedback
WHERE feedback_type = 'user_rejection'
ORDER BY created_at DESC;
```
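Rows returned by this query still need cleaning before they can be used for retraining. The sketch below shows one possible preparation step, deduplicating by text and keeping the most recent feedback. It rests on assumptions that should be checked against the schema: that `is_spam` holds the user-corrected label rather than the model's original prediction, and that labels are encoded as 1 = spam, 0 = ham. `FeedbackRow`, `TrainingExample`, and `toTrainingExamples` are hypothetical names.

```typescript
// Row shape mirrors the columns selected in the query above.
interface FeedbackRow {
  text: string;
  is_spam: boolean; // Assumption: this is the user-corrected label.
  feedback_type: string;
}

interface TrainingExample {
  text: string;
  label: 0 | 1; // Assumed encoding: 1 = spam, 0 = ham.
}

// Rows arrive newest first (ORDER BY created_at DESC), so the first row seen
// for a given text is the most recent feedback and wins over older duplicates.
function toTrainingExamples(rows: readonly FeedbackRow[]): TrainingExample[] {
  const seen = new Set<string>();
  const examples: TrainingExample[] = [];
  for (const row of rows) {
    const text = row.text.trim();
    const key = text.toLowerCase();
    // Skip empty messages and texts already covered by newer feedback.
    if (key === "" || seen.has(key)) continue;
    seen.add(key);
    examples.push({ text, label: row.is_spam ? 1 : 0 });
  }
  return examples;
}
```

Keeping only the newest row per text means a user who changes their mind does not contribute two contradictory labels to the retraining set.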
## Files
```
ml/spam-classifier/
├── train.py              # Training pipeline
├── README.md             # This file
└── output/               # Training output (gitignored except onnx_model)
    ├── data/             # Raw dataset (gitignored)
    ├── final_model/      # PyTorch model (gitignored)
    ├── best_model/       # Best checkpoint (gitignored)
    └── onnx_model/       # ONNX export (copied to web/)
        ├── model.onnx
        ├── model.onnx.data
        ├── vocab.txt
        ├── tokenizer.json
        ├── tokenizer_config.json
        └── model_metadata.json
```