# 07. Fine-Tuned DistilBERT SMS Spam Classifier with ONNX Deployment

meta:
id: core-services-07
feature: core-services-implementation
priority: P1
depends_on: [core-services-06]
tags: [spamshield, ml, nlp, distilbert, onnx, text-classification]

objective:
- Replace the stub `classifyTextBERT()` function that returns `{ isSpam: false, confidence: 1.0 }` with a production ML pipeline: fine-tune DistilBERT on SMS spam data, export it to ONNX for fast inference, and integrate it into the spam classification flow.

deliverables:
- Training pipeline for fine-tuning DistilBERT on the SMS spam dataset
- ONNX-exported model for low-latency CPU inference (~50ms per message)
- Inference server with batching and caching
- Integration with the existing spam classification service
- Model versioning and A/B testing framework
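
For the caching part of the inference-server deliverable, here is a minimal in-process sketch, assuming results are cached per normalized message text. `ClassificationResult`, `LruCache`, `normalizeMessage`, and `resultCache` are hypothetical names, not existing project code. The uncased DistilBERT tokenizer lowercases and splits on whitespace anyway, so normalizing case and whitespace should not change the model's output. It does let near-identical copies of a bulk SMS share one cache entry.

```typescript
// Hypothetical result shape, mirroring the stub's `{ isSpam, confidence }`.
interface ClassificationResult {
  isSpam: boolean;
  confidence: number;
}

// Collapse case and whitespace so near-duplicate messages map to one cache key.
function normalizeMessage(text: string): string {
  return text.trim().replace(/\s+/g, " ").toLowerCase();
}

// Minimal LRU cache. A Map iterates in insertion order, so re-inserting a key on
// every hit keeps the least recently used entry first in line for eviction.
class LruCache<K, V> {
  private readonly entries = new Map<K, V>();
  private readonly capacity: number;

  constructor(capacity: number) {
    if (!Number.isInteger(capacity) || capacity < 1) {
      throw new RangeError("capacity must be a positive integer");
    }
    this.capacity = capacity;
  }

  get(key: K): V | undefined {
    if (!this.entries.has(key)) return undefined;
    const value = this.entries.get(key) as V;
    this.entries.delete(key); // move to the most-recently-used position
    this.entries.set(key, value);
    return value;
  }

  set(key: K, value: V): void {
    this.entries.delete(key);
    this.entries.set(key, value);
    if (this.entries.size > this.capacity) {
      // Evict the least recently used entry (first in iteration order).
      const oldest = this.entries.keys().next().value as K;
      this.entries.delete(oldest);
    }
  }

  get size(): number {
    return this.entries.size;
  }
}

// Usage: consult the cache before running ONNX inference.
const resultCache = new LruCache<string, ClassificationResult>(10000);
```

Batching, the other half of the deliverable, is left out here because it depends on how the ONNX session is shared across requests.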

steps:
1. Set up the Python training environment:
- Install `transformers`, `datasets`, `onnxruntime`, `torch`, `optimum[onnxruntime]`
- Create an `ml/spam-classifier/` directory in the project root

2. Acquire training data:
- SMS Spam Collection Dataset (UCI ML Repository, 5,574 messages: 4,827 ham, 747 spam)
- Enron Spam Dataset (email corpus; filter to short, SMS-like messages)
- Custom labeled data from user feedback (Phase 2)
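
Whatever mix of sources is used, the accuracy and false-positive criteria need a fixed held-out test set. Below is a minimal sketch of carving one out. It assumes the UCI `SMSSpamCollection` file format: one message per line, `ham` or `spam`, a tab, then the text. `parseSmsCollection`, `stratifiedSplit`, the 20% test fraction, and the seed are illustrative choices, not part of this plan. The training itself runs in Python, and the same logic is easy to mirror there. Splitting each class separately keeps the spam/ham ratio the same in both splits. The seeded shuffle (mulberry32 PRNG) makes the split reproducible across retraining runs.

```typescript
type Label = "ham" | "spam";

interface LabeledMessage {
  label: Label;
  text: string;
}

// Parse "<label>\t<message>" lines, skipping blank or malformed ones.
function parseSmsCollection(raw: string): LabeledMessage[] {
  const messages: LabeledMessage[] = [];
  for (const line of raw.split(/\r?\n/)) {
    const tab = line.indexOf("\t");
    if (tab < 0) continue;
    const label = line.slice(0, tab);
    if (label !== "ham" && label !== "spam") continue;
    messages.push({ label: label as Label, text: line.slice(tab + 1) });
  }
  return messages;
}

// mulberry32: a tiny seeded PRNG returning floats in [0, 1).
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Fisher-Yates shuffle on a copy, driven by the seeded PRNG.
function shuffle<T>(items: T[], random: () => number): T[] {
  const out = items.slice();
  for (let i = out.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [out[i], out[j]] = [out[j], out[i]];
  }
  return out;
}

// Split each class separately so train and test keep the original class ratio.
function stratifiedSplit(
  messages: LabeledMessage[],
  testFraction = 0.2,
  seed = 42,
): { train: LabeledMessage[]; test: LabeledMessage[] } {
  const random = mulberry32(seed);
  const train: LabeledMessage[] = [];
  const test: LabeledMessage[] = [];
  for (const label of ["ham", "spam"] as const) {
    const group = shuffle(messages.filter((m) => m.label === label), random);
    const nTest = Math.round(group.length * testFraction);
    test.push(...group.slice(0, nTest));
    train.push(...group.slice(nTest));
  }
  return { train, test };
}
```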

3. Fine-tune `distilbert-base-uncased`:
- Binary classification: spam vs. ham
- 3 epochs, batch size 32, learning rate 2e-5
- Expected accuracy: 97–99% on a held-out split of the SMS Spam Collection
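
Two quick numbers help put the training targets above in context. They use the SMS Spam Collection's class counts (4,827 ham and 747 spam out of 5,574) and, purely as an assumption, an 80/20 train/test split on a single device without gradient accumulation. First, a classifier that always predicts "ham" already scores about 86.6% accuracy. Accuracy targets should therefore always be read together with the false positive rate. Second, batch size 32 over 3 epochs amounts to only a few hundred optimizer steps:

$$
\text{Acc}_{\text{always ham}} = \frac{4827}{5574} \approx 0.866
$$

$$
n_{\text{train}} \approx 0.8 \times 5574 \approx 4459, \qquad
\left\lceil \frac{4459}{32} \right\rceil = 140 \ \text{steps/epoch}, \qquad
3 \times 140 = 420 \ \text{steps}
$$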
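
4. Export to ONNX:
- Use the Optimum CLI on the fine-tuned checkpoint directory: `optimum-cli export onnx --model distilbert-spam --task text-classification ./onnx_model/`
- Quantize to INT8 for roughly a 2x CPU speedup with minimal accuracy loss
- Expected model size: roughly 265MB in FP32 (~66M parameters × 4 bytes) and roughly 67MB after INT8 quantization, so only the INT8 model meets the < 100MB target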
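
To make the INT8 step concrete, here is a minimal sketch of symmetric per-tensor INT8 weight quantization. Each float32 weight (4 bytes) becomes one int8 value (1 byte) plus a single shared scale per tensor. This illustrates the general technique only. ONNX Runtime's actual quantizer chooses its own scheme, data types, and per-operator details. `quantizeSymmetricInt8` and `dequantize` are hypothetical helpers.

```typescript
interface QuantizedTensor {
  values: Int8Array; // 1 byte per weight instead of 4
  scale: number; // one shared scale for the whole tensor
}

// Symmetric quantization: map [-maxAbs, maxAbs] linearly onto [-127, 127].
function quantizeSymmetricInt8(weights: Float32Array): QuantizedTensor {
  let maxAbs = 0;
  for (let i = 0; i < weights.length; i++) maxAbs = Math.max(maxAbs, Math.abs(weights[i]));
  const scale = maxAbs === 0 ? 1 : maxAbs / 127;
  const values = new Int8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    // Clamp defensively; with this scale |w / scale| is already <= 127.
    values[i] = Math.max(-127, Math.min(127, Math.round(weights[i] / scale)));
  }
  return { values, scale };
}

// Recover approximate weights; rounding error is at most about scale / 2 per weight.
function dequantize(q: QuantizedTensor): Float32Array {
  const out = new Float32Array(q.values.length);
  for (let i = 0; i < q.values.length; i++) out[i] = q.values[i] * q.scale;
  return out;
}
```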

5. Create a Node.js ONNX inference wrapper:
- Install `onnxruntime-node`
- Load the model once at startup and reuse the `InferenceSession`
- Preprocess: tokenize with the same DistilBERT (uncased WordPiece) tokenizer used in training (max length 128)
- Postprocess: softmax over the two logits → spam probability → binary decision (for a two-label head this equals a sigmoid of the logit difference, spam minus ham)
- Target latency: <50ms per message on CPU, <10ms on GPU
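
The pre- and postprocessing around the ONNX session can be sketched as below. The sketch rests on several assumptions:
- A WordPiece tokenizer has already produced token IDs without special tokens.
- The standard `distilbert-base-uncased` vocabulary is used ([CLS] = 101, [SEP] = 102, [PAD] = 0).
- The exported graph takes int64 `input_ids` and `attention_mask` of shape `[1, 128]`.
- It returns two logits with index 1 = spam.

Verify the input/output names and label order against your exported model and training config. The `BigInt64Array` buffers are what an `onnxruntime-node` `Tensor` of type `"int64"` expects. `buildModelInputs`, `spamProbability`, and `toDecision` are hypothetical helpers.

```typescript
const CLS_ID = 101; // [CLS] (assumes the stock uncased vocabulary)
const SEP_ID = 102; // [SEP]
const PAD_ID = 0; // [PAD]

interface ModelInputs {
  inputIds: BigInt64Array;
  attentionMask: BigInt64Array;
}

// Wrap in [CLS] ... [SEP], truncate to maxLength, and right-pad with [PAD].
function buildModelInputs(wordpieceIds: number[], maxLength = 128): ModelInputs {
  const body = wordpieceIds.slice(0, maxLength - 2); // leave room for [CLS] and [SEP]
  const ids = [CLS_ID, ...body, SEP_ID];
  const inputIds = new BigInt64Array(maxLength).fill(BigInt(PAD_ID));
  const attentionMask = new BigInt64Array(maxLength); // zeros mark padding
  for (let i = 0; i < ids.length; i++) {
    inputIds[i] = BigInt(ids[i]);
    attentionMask[i] = BigInt(1);
  }
  return { inputIds, attentionMask };
}

// Softmax over [hamLogit, spamLogit]; subtracting the max keeps exp() from overflowing.
// For two classes this equals sigmoid(spamLogit - hamLogit).
function spamProbability(logits: ArrayLike<number>): number {
  const ham = logits[0];
  const spam = logits[1];
  const m = Math.max(ham, spam);
  const eHam = Math.exp(ham - m);
  const eSpam = Math.exp(spam - m);
  return eSpam / (eHam + eSpam);
}

// Confidence is reported for the predicted class, matching the stub's field.
function toDecision(logits: ArrayLike<number>, threshold = 0.5): { isSpam: boolean; confidence: number } {
  const p = spamProbability(logits);
  const isSpam = p >= threshold;
  return { isSpam, confidence: isSpam ? p : 1 - p };
}
```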

6. Integrate into `spamshield.service.ts`:
- Replace the stubbed body of `classifyTextBERT()` with real ONNX inference
- Classification flow: reputation lookup → rule engine → ML classifier (ensemble)
- Threshold tuning: default 0.5, adjustable per user preference
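
The classification flow can be sketched as one pure decision function, assuming each stage has already run and produced:
- a reputation verdict;
- a rule-engine score in [0, 1];
- an ML spam probability that is `null` when ONNX inference failed, so the rule engine becomes the fallback.

The strict/moderate/lenient thresholds (0.3/0.5/0.7) and the 0.7/0.3 ML/rule weighting are placeholders to be tuned. `Signals`, `Verdict`, `THRESHOLDS`, and `classifyEnsemble` are hypothetical names.

```typescript
type Preference = "strict" | "moderate" | "lenient";
type Reputation = "known_spam" | "known_good" | "unknown";

interface Signals {
  reputation: Reputation;
  ruleScore: number; // rule-engine spam score in [0, 1]
  mlProbability: number | null; // null when ONNX inference failed or is unavailable
}

interface Verdict {
  isSpam: boolean;
  confidence: number;
  source: "reputation" | "ensemble" | "rules-fallback";
}

// Placeholder per-user thresholds; "moderate" matches the 0.5 default.
const THRESHOLDS: Record<Preference, number> = { strict: 0.3, moderate: 0.5, lenient: 0.7 };
const ML_WEIGHT = 0.7; // placeholder share of the ML score in the blend

function classifyEnsemble(signals: Signals, preference: Preference = "moderate"): Verdict {
  // 1. Reputation lookup short-circuits when the sender is already known.
  if (signals.reputation === "known_spam") return { isSpam: true, confidence: 1, source: "reputation" };
  if (signals.reputation === "known_good") return { isSpam: false, confidence: 1, source: "reputation" };

  // 2 + 3. Blend rule and ML scores, or fall back to rules alone if ML failed.
  const mlFailed = signals.mlProbability === null;
  const score = mlFailed
    ? signals.ruleScore
    : ML_WEIGHT * (signals.mlProbability as number) + (1 - ML_WEIGHT) * signals.ruleScore;
  const isSpam = score >= THRESHOLDS[preference];
  return {
    isSpam,
    confidence: isSpam ? score : 1 - score,
    source: mlFailed ? "rules-fallback" : "ensemble",
  };
}
```

A lower threshold flags more messages, which is why "strict" maps to the smallest value.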

7. Implement the feedback loop:
- Users can report false positives and false negatives
- Store feedback in the existing `spamFeedback` table
- Run a weekly retraining batch on the accumulated feedback
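
Turning accumulated feedback into a retraining batch can be sketched as follows. The sketch assumes each `spamFeedback` row carries the message text, the user's corrected label, and a timestamp. The `FeedbackRow` fields are hypothetical, since the table's actual schema is not shown here, and `buildRetrainingBatch` is a hypothetical helper. The helper deduplicates by normalized text and lets the newest report win. That way, repeated or reversed reports on the same message do not skew the batch.

```typescript
// Hypothetical row shape; the real `spamFeedback` columns may differ.
interface FeedbackRow {
  messageText: string;
  reportedLabel: "spam" | "ham"; // the user's correction ("not spam" -> "ham")
  createdAt: Date;
}

interface TrainingExample {
  text: string;
  label: 0 | 1; // 0 = ham, 1 = spam; must match the label ids used in fine-tuning
}

// Deduplicate by normalized text, keeping the most recent report per message.
function buildRetrainingBatch(rows: FeedbackRow[]): TrainingExample[] {
  const latest = new Map<string, FeedbackRow>();
  for (const row of rows) {
    const key = row.messageText.trim().replace(/\s+/g, " ").toLowerCase();
    const seen = latest.get(key);
    if (!seen || row.createdAt.getTime() > seen.createdAt.getTime()) {
      latest.set(key, row);
    }
  }
  return Array.from(latest.values(), (row): TrainingExample => ({
    text: row.messageText,
    label: row.reportedLabel === "spam" ? 1 : 0,
  }));
}
```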

8. Add model versioning:
- Store model artifacts in S3-compatible storage
- A/B test new models on a subset of traffic
- Support rollback if accuracy degrades
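
A/B routing and the rollback check can be sketched as below. The sketch assumes users are assigned to the candidate model by a stable hash of their user ID (32-bit FNV-1a here), so each user sees the same model on every request. It also assumes rollback compares the candidate's measured accuracy and false positive rate against the current model. `assignModel`, `shouldRollBack`, the 10% rollout share, and the 0.5-point tolerance are hypothetical choices, not existing project code.

```typescript
// 32-bit FNV-1a over the string's UTF-16 code units; deterministic across
// processes, so a user keeps the same model assignment on every request.
function fnv1a32(input: string): number {
  let hash = 0x811c9dc5; // FNV offset basis
  for (let i = 0; i < input.length; i++) {
    hash ^= input.charCodeAt(i);
    hash = Math.imul(hash, 0x01000193) >>> 0; // multiply by the FNV prime mod 2^32
  }
  return hash >>> 0;
}

interface ModelVersion {
  id: string; // e.g. the object-storage key of the ONNX artifact (hypothetical)
}

// Send roughly `candidateShare` (0..1) of users to the candidate model.
function assignModel(
  userId: string,
  current: ModelVersion,
  candidate: ModelVersion,
  candidateShare = 0.1,
): ModelVersion {
  const bucket = fnv1a32(userId) % 10000; // 0..9999
  return bucket < candidateShare * 10000 ? candidate : current;
}

interface EvalStats {
  accuracy: number;
  falsePositiveRate: number;
}

// Roll back when the candidate is worse than the current model beyond a small tolerance.
function shouldRollBack(current: EvalStats, candidate: EvalStats, tolerance = 0.005): boolean {
  return (
    candidate.accuracy < current.accuracy - tolerance ||
    candidate.falsePositiveRate > current.falsePositiveRate + tolerance
  );
}
```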

tests:
- Unit: Verify ONNX inference produces correct labels for known spam/ham test cases
- Integration: End-to-end classification flow with real model loading
- E2E: Submit SMS text → receive classification with confidence score

acceptance_criteria:
- [ ] `classifyTextBERT()` runs real ONNX inference (not returning hardcoded `{ isSpam: false }`)
- [ ] Model accuracy > 95% on a held-out test set from the SMS Spam Collection
- [ ] Inference latency < 50ms per message on CPU (measured in production)
- [ ] Model file is versioned and loadable from external storage (S3/local path)
- [ ] False positive rate < 2% (legitimate messages incorrectly flagged as spam)
- [ ] User feedback ("not spam" / "spam") is stored and used for model improvement
- [ ] Classification threshold is configurable per user (strict/moderate/lenient)
- [ ] ONNX model loads once at server startup, not per request
- [ ] Graceful fallback to the rule engine if ONNX Runtime fails
- [ ] Model size < 100MB for reasonable cold-start time
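
The accuracy and false-positive criteria can be checked on the held-out set with a small confusion-matrix helper. The sketch below assumes predictions and gold labels are booleans (`true` = spam). `evaluate` and `meetsAcceptance` are hypothetical names. The false positive rate is measured over legitimate (ham) messages only, FP / (FP + TN), so it is not the same as 1 minus accuracy.

```typescript
interface Metrics {
  accuracy: number;
  falsePositiveRate: number; // ham flagged as spam / all ham
  falseNegativeRate: number; // spam missed / all spam
  precision: number;
  recall: number;
}

function evaluate(predicted: boolean[], actual: boolean[]): Metrics {
  if (predicted.length !== actual.length || actual.length === 0) {
    throw new Error("predicted and actual must be non-empty and the same length");
  }
  let tp = 0, fp = 0, tn = 0, fn = 0;
  for (let i = 0; i < actual.length; i++) {
    if (predicted[i] && actual[i]) tp++;
    else if (predicted[i] && !actual[i]) fp++;
    else if (!predicted[i] && !actual[i]) tn++;
    else fn++;
  }
  // Report 0 instead of NaN when a class is absent from the test set.
  const safeDiv = (a: number, b: number) => (b === 0 ? 0 : a / b);
  return {
    accuracy: (tp + tn) / actual.length,
    falsePositiveRate: safeDiv(fp, fp + tn),
    falseNegativeRate: safeDiv(fn, fn + tp),
    precision: safeDiv(tp, tp + fp),
    recall: safeDiv(tp, tp + fn),
  };
}

// The two model-quality bars from the acceptance criteria.
function meetsAcceptance(m: Metrics): boolean {
  return m.accuracy > 0.95 && m.falsePositiveRate < 0.02;
}
```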

validation:
- Run `vitest run spamshield.service.test.ts` (the tests use the real ONNX model)
- Benchmark: `bun run benchmark:spamshield` runs 1000 inferences and reports p50/p95/p99 latency
- Manual: classify a known spam message ("Congratulations! You've won $1000...") and verify `isSpam: true, confidence > 0.9`
- Check feedback: the `spamFeedback` table accumulates user corrections
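
The p50/p95/p99 report from the benchmark step can be sketched as below. It assumes `classify` is whatever async function wraps ONNX inference (hypothetical here). It uses the nearest-rank percentile definition; interpolating definitions give slightly different values on small samples. A few warm-up calls are excluded so that one-time setup cost does not inflate the tail.

```typescript
// Nearest-rank percentile over ascending samples: the smallest sample with at
// least p% of all samples at or below it.
function percentile(sortedMs: number[], p: number): number {
  if (sortedMs.length === 0) throw new Error("no samples");
  const rank = Math.ceil((p * sortedMs.length) / 100);
  const clamped = Math.min(Math.max(rank, 1), sortedMs.length);
  return sortedMs[clamped - 1];
}

// Time `iterations` sequential calls after a short warm-up and report latency percentiles.
async function benchmark(
  classify: (text: string) => Promise<unknown>,
  messages: string[],
  iterations = 1000,
  warmup = 20,
): Promise<{ p50: number; p95: number; p99: number }> {
  if (messages.length === 0) throw new Error("need at least one message");
  for (let i = 0; i < warmup; i++) await classify(messages[i % messages.length]);

  const samples: number[] = [];
  for (let i = 0; i < iterations; i++) {
    const start = performance.now();
    await classify(messages[i % messages.length]);
    samples.push(performance.now() - start);
  }
  samples.sort((a, b) => a - b);
  return {
    p50: percentile(samples, 50),
    p95: percentile(samples, 95),
    p99: percentile(samples, 99),
  };
}
```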

notes:
- DistilBERT is chosen over BERT because it is 40% smaller and 60% faster at inference while retaining about 97% of BERT's language-understanding performance
- ONNX Runtime for Node.js has limited platform support; test on your deployment target (Linux x64, macOS ARM)
- Training can happen in CI (GitHub Actions with a GPU runner) or locally; inference happens in production
- Consider TensorFlow Lite or ONNX Runtime Web for on-device mobile inference later
- The SMS Spam Collection is small (5,574 messages); augment it with synthetic spam variants for robustness
- For non-English messages, such as other European languages, consider a multilingual model like `distilbert-base-multilingual-cased`