# SMS Spam Classifier - DistilBERT ONNX Pipeline

## Overview

This directory contains the training pipeline for a fine-tuned DistilBERT SMS spam classifier, exported to ONNX format for fast CPU inference in the Kordant spamshield service.

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                      Training (Python)                      │
│  ┌─────────────┐  ┌──────────────┐  ┌───────────────────┐   │
│  │ SMS Spam    │  │ DistilBERT   │  │ ONNX Export       │   │
│  │ Collection  │→ │ Fine-tuning  │→ │ + INT8 Quantize   │   │
│  │ Dataset     │  │ (3 epochs)   │  │                   │   │
│  └─────────────┘  └──────────────┘  └───────────────────┘   │
└─────────────────────────────────────────────────────────────┘
                               ↓
┌─────────────────────────────────────────────────────────────┐
│                     Inference (Node.js)                     │
│  ┌─────────────┐  ┌──────────────┐  ┌───────────────────┐   │
│  │ Text Input  │  │ BertTokenizer│  │ ONNX Runtime      │   │
│  │             │→ │ (JS impl)    │→ │ (onnxruntime-node)│   │
│  └─────────────┘  └──────────────┘  └───────────────────┘   │
│                              ↓                              │
│  ┌─────────────────────────────────────────────────────┐    │
│  │ Result: { isSpam, confidence, score, modelVersion } │    │
│  └─────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────┘
```

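The last step of the diagram turns raw model output into a `score`. A minimal sketch of that post-processing, written in Python for illustration (the production code is the TypeScript wrapper), assuming the model emits two logits ordered `[ham, spam]`. The function name `spam_score` and the label order are assumptions; check the exported model's label mapping before relying on it.

```python
import math


def spam_score(logits: list[float], spam_index: int = 1) -> float:
    """Softmax over the class logits, returning the spam-class probability.

    Assumption: `spam_index` points at the spam logit (index 1 here).
    """
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    return exps[spam_index] / sum(exps)


# Example: a strongly "ham" output yields a score close to 0.
score = spam_score([4.2, -3.1])
print(f"score={score:.4f}")
```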
## Training

### Prerequisites

- Python 3.10+ (tested with 3.14)
- Virtual environment: `python3 -m venv .venv-ml && source .venv-ml/bin/activate`
- Dependencies: `pip install transformers datasets torch "optimum[onnxruntime]" onnxruntime` (the quotes keep zsh from treating the brackets as a glob pattern)

### Run Training

```bash
# Full pipeline with data augmentation
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --epochs 3 \
  --batch-size 32

# Skip training (use existing model for export only)
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --skip-training

# With INT8 quantization
python ml/spam-classifier/train.py \
  --output-dir ml/spam-classifier/output \
  --augment \
  --quantize
```

### Training Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--epochs` | 3 | Number of training epochs |
| `--batch-size` | 32 | Training batch size |
| `--lr` | 2e-5 | Learning rate |
| `--augment` | false | Include synthetic data augmentation |
| `--quantize` | false | Quantize to INT8 for smaller model |
| `--skip-training` | false | Skip training, only export |

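The flags above map naturally onto an `argparse` parser. The following is a sketch of how such a parser could look, not a copy of the one in `train.py`: the function name `build_parser` is hypothetical, and treating `--output-dir` as required is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(
        description="Train and export the SMS spam classifier"
    )
    # Assumption: the output directory has no default and must be given.
    parser.add_argument("--output-dir", required=True,
                        help="Directory for checkpoints and the ONNX export")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Training batch size")
    parser.add_argument("--lr", type=float, default=2e-5,
                        help="Learning rate")
    # Boolean switches default to False and flip to True when present.
    parser.add_argument("--augment", action="store_true",
                        help="Include synthetic data augmentation")
    parser.add_argument("--quantize", action="store_true",
                        help="Quantize to INT8 for smaller model")
    parser.add_argument("--skip-training", action="store_true",
                        help="Skip training, only export")
    return parser


args = build_parser().parse_args(
    ["--output-dir", "ml/spam-classifier/output", "--augment"]
)
print(args.epochs, args.batch_size, args.lr, args.augment, args.quantize)
```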
### Expected Results

- **Validation accuracy**: 97-99% on SMS Spam Collection
- **Training time**: ~5-10 minutes on M-series Mac, ~2-3 minutes on GPU
- **Model size**: ~257MB (FP32), ~128MB (INT8 quantized)

## Inference

### Setup

The trained ONNX model is deployed to `web/src/server/models/spam-classifier/`.

The Node.js inference wrapper is in `web/src/server/services/spamshield/onnx.inference.ts`.

### Usage

```typescript
import { classifyTextBERT, initSpamModel } from "~/server/services/spamshield/onnx.inference";

// Initialize at startup (loads model once)
await initSpamModel();

// Classify text
const result = await classifyTextBERT("Hello world");
// { isSpam: false, confidence: 0.97, score: 0.03, modelVersion: "1.0.0" }

// With threshold mode
const strict = await classifyTextBERT("Get free money!", "strict");
const lenient = await classifyTextBERT("Get free money!", "lenient");
```

### Threshold Modes

| Mode | Threshold | Description |
|------|-----------|-------------|
| `strict` | 0.3 | Aggressive spam detection, more false positives |
| `moderate` | 0.5 | Balanced (default) |
| `lenient` | 0.7 | Conservative, fewer false positives |

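To make the table concrete, here is a small sketch of how a threshold mode could turn a spam score into a verdict. It is written in Python for illustration only; the real logic lives in `onnx.inference.ts`, which is not reproduced here. The `>=` comparison, the `decide` helper, and the definition of `confidence` as the probability of the predicted label are all assumptions, chosen to be consistent with the sample result in the usage example (`score: 0.03`, `confidence: 0.97`).

```python
# Thresholds taken from the table above.
THRESHOLDS = {"strict": 0.3, "moderate": 0.5, "lenient": 0.7}


def decide(score: float, mode: str = "moderate") -> dict:
    """Map a spam probability to a verdict under the given threshold mode.

    Assumptions: a score at or above the threshold counts as spam, and
    confidence is the probability assigned to the predicted label.
    """
    if mode not in THRESHOLDS:
        raise ValueError(f"unknown threshold mode: {mode!r}")
    is_spam = score >= THRESHOLDS[mode]
    confidence = score if is_spam else 1.0 - score
    return {"isSpam": is_spam, "confidence": confidence, "score": score}


# The same borderline score flips between modes.
for mode in ("strict", "moderate", "lenient"):
    print(mode, decide(0.4, mode)["isSpam"])
```

Under these assumptions a score of 0.4 is flagged only in `strict` mode, which is the trade-off the table describes.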
### Performance

| Metric | Value |
|--------|-------|
| p50 latency | <5ms |
| p95 latency | <30ms |
| p99 latency | <50ms |
| Throughput | ~240 inferences/sec |
| Model load time | ~2s (cold start) |

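For readers checking these figures against their own measurements, the sketch below shows one common way to derive p50/p95/p99 from a list of per-call latencies, using the nearest-rank method. It is an illustration in Python, not the method used by the benchmark script, and the `percentile` helper and the sample values are hypothetical.

```python
import math


def percentile(samples: list[float], pct: float) -> float:
    """Nearest-rank percentile: the smallest sample such that at least
    pct percent of all samples are less than or equal to it."""
    if not samples:
        raise ValueError("no samples")
    if not 0 < pct <= 100:
        raise ValueError("pct must be in (0, 100]")
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]


# Hypothetical latencies in milliseconds, for illustration only.
latencies_ms = [3.1, 3.4, 2.9, 4.0, 3.3, 28.5, 3.2, 3.0, 3.6, 41.7]
for pct in (50, 95, 99):
    print(f"p{pct}: {percentile(latencies_ms, pct)} ms")
```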
### Benchmark

```bash
cd web
pnpm benchmark:spamshield --iterations 1000
```

## Model Versioning

The model version is tracked in `model_metadata.json` alongside the ONNX files.

To deploy a new model:

1. Run training with the new data or hyperparameters
2. Update the version in `model_metadata.json`
3. Copy the ONNX files from `ml/spam-classifier/output/onnx_model/` to `web/src/server/models/spam-classifier/`
4. Deploy

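Steps 2 and 3 can be scripted so that a stale version is caught before anything is copied. The sketch below is one possible approach, not an existing tool in this repository: `deploy_model` is a hypothetical helper, and it assumes `model_metadata.json` stores the version under a top-level `"version"` key, which this README does not confirm.

```python
import json
import shutil
from pathlib import Path


def deploy_model(export_dir: Path, deploy_dir: Path, expected_version: str) -> list[str]:
    """Copy every file in export_dir to deploy_dir after a version check.

    Assumption: model_metadata.json has a top-level "version" field.
    Returns the names of the copied files in sorted order.
    """
    metadata = json.loads((export_dir / "model_metadata.json").read_text())
    found = metadata.get("version")
    if found != expected_version:
        raise ValueError(
            f"model_metadata.json has version {found!r}, expected {expected_version!r}"
        )
    deploy_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(export_dir.iterdir()):
        if src.is_file():
            shutil.copy2(src, deploy_dir / src.name)  # preserves timestamps
            copied.append(src.name)
    return copied


# Example call (paths from this README, version string is illustrative):
# deploy_model(
#     Path("ml/spam-classifier/output/onnx_model"),
#     Path("web/src/server/models/spam-classifier"),
#     expected_version="1.1.0",
# )
```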
## Feedback Loop

User feedback is stored in the `spam_feedback` database table and can be used for periodic retraining:

```sql
SELECT text, is_spam, feedback_type FROM spam_feedback
WHERE feedback_type = 'user_rejection'
ORDER BY created_at DESC;
```

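Rows returned by this query need light cleanup before they can be added to a training set. The following sketch shows one way to do that in Python. It assumes that `is_spam` holds the label the user considers correct (the README does not say whether it stores the corrected label or the model's original prediction) and that rows arrive newest first, as in the query above. The `rows_to_examples` helper is hypothetical.

```python
from typing import Iterable


def rows_to_examples(rows: Iterable[tuple]) -> list[dict]:
    """Convert (text, is_spam, feedback_type) rows into labeled examples.

    Assumptions: rows are ordered newest first, and is_spam is the label
    to train on. Duplicate texts keep only their most recent feedback.
    """
    seen: set[str] = set()
    examples = []
    for text, is_spam, _feedback_type in rows:
        cleaned = (text or "").strip()
        key = cleaned.lower()
        if not key or key in seen:
            continue  # skip empty messages and older duplicates
        seen.add(key)
        examples.append({"text": cleaned, "label": int(bool(is_spam))})
    return examples


# Illustrative rows, newest first.
rows = [
    ("Your parcel is waiting, click here", True, "user_rejection"),
    ("your parcel is waiting, click here ", False, "user_rejection"),
    ("Running late, see you at 7", False, "user_rejection"),
]
print(rows_to_examples(rows))
```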
## Files

```
ml/spam-classifier/
├── train.py              # Training pipeline
├── README.md             # This file
└── output/               # Training output (gitignored except onnx_model)
    ├── data/             # Raw dataset (gitignored)
    ├── final_model/      # PyTorch model (gitignored)
    ├── best_model/       # Best checkpoint (gitignored)
    └── onnx_model/       # ONNX export (copied to web/)
        ├── model.onnx
        ├── model.onnx.data
        ├── vocab.txt
        ├── tokenizer.json
        ├── tokenizer_config.json
        └── model_metadata.json
```