Files
Kordant/ml/spam-classifier/train.py

445 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Fine-tune DistilBERT on SMS Spam Collection Dataset and export to ONNX.
Usage:
python train.py [--output-dir ./output] [--epochs 3] [--batch-size 32] [--lr 2e-5] [--augment] [--skip-training] [--quantize]
The SMS Spam Collection dataset is downloaded from the HuggingFace Hub (`sms_spam`).
The fine-tuned model is exported to ONNX format for fast CPU inference.
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from torch.optim.adamw import AdamW
from datasets import Dataset, DatasetDict, concatenate_datasets
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
DistilBertForSequenceClassification,
get_linear_schedule_with_warmup,
)
# ── Constants ──────────────────────────────────────────────────────────────
# Base checkpoint that is fine-tuned, and the token length every message is
# padded/truncated to (training and ONNX export use the same value).
MODEL_NAME = "distilbert-base-uncased"
MAX_LENGTH = 128
# Class-name <-> class-index maps; the reverse map is derived from the forward
# one so the two cannot drift apart.
LABEL2ID = {"ham": 0, "spam": 1}
ID2LABEL = {index: name for name, index in LABEL2ID.items()}
# Defaults for the --batch-size, --epochs and --lr command-line options.
BATCH_SIZE = 32
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
# Fraction of all optimizer steps spent in linear learning-rate warmup.
WARMUP_RATIO = 0.1
EVAL_SPLIT = 0.15  # 15% for validation
# ── Data Loading ───────────────────────────────────────────────────────────
HF_DATASET_NAME = "sms_spam"
def load_dataset(data_dir: Optional[Path] = None) -> DatasetDict:
    """Fetch the SMS spam corpus from the HuggingFace Hub and split it.

    Args:
        data_dir: Not used by this implementation; accepted so existing
            callers that pass it keep working.

    Returns:
        A ``DatasetDict`` with a ``"train"`` split and a ``"val"`` split. The
        ``"val"`` split holds ``EVAL_SPLIT`` of the messages and the split is
        seeded (42), so it is reproducible.
    """
    print(f" Loading {HF_DATASET_NAME} from HuggingFace Hub...")
    # Imported under an alias because this function shares the library's name.
    from datasets import load_dataset as hf_load

    # The Hub dataset calls its text column "sms"; the rest of the pipeline
    # expects "text".
    messages = hf_load(HF_DATASET_NAME, split="train").rename_column("sms", "text")

    # Labels are already 0=ham / 1=spam, so summing them counts the spam.
    message_count = len(messages)
    spam_count = sum(messages["label"])
    print(f" Total messages: {message_count}")
    print(f" Spam: {spam_count}, Ham: {message_count - spam_count}")

    parts = messages.train_test_split(test_size=EVAL_SPLIT, seed=42)
    return DatasetDict({"train": parts["train"], "val": parts["test"]})
# ── Synthetic Data Augmentation ────────────────────────────────────────────
# Spam message skeletons. ``{phone}``, ``{url}``, ``{product}``, ``{name}``,
# ``{number}`` and ``{amount}`` are placeholders filled in by
# generate_synthetic_data(); the ``$`` in front of ``{amount}`` is a literal
# currency sign.
SPAM_TEMPLATES = [
    "Congratulations! You've won ${amount}! Call {phone} to claim.",
    "FREE {product}! Text YES to {number} now!",
    "You have been selected for a ${amount} loan. Reply NOW.",
    "URGENT: Your account will be closed. Call {phone} immediately.",
    "Win a FREE iPhone! Visit {url} to claim your prize.",
    "Dear customer, your ${amount} refund is ready. Click {url}.",
    "Hot singles in your area! Text {number} to meet them.",
    "You've been pre-approved for a credit line of ${amount}.",
    "ACT NOW! Limited time offer: {product} for just ${amount}.",
    "Your package delivery failed. Call {phone} to reschedule.",
    "BREAKING: You won the lottery! ${amount} prize. Call {phone}.",
    "FREE entry to win ${amount} cash prize! Text WIN to {number}.",
    "Your bank account has been compromised. Call {phone} now.",
    "Make ${amount} per week working from home. Visit {url}.",
    "You have a new message from {name}. Reply to view.",
]
# Legitimate ("ham") messages; used verbatim, no placeholders.
HAM_TEMPLATES = [
    "Hey, are we still on for dinner tonight?",
    "Can you pick up milk on your way home?",
    "Meeting at 3pm in conference room B.",
    "Thanks for the help yesterday!",
    "I'll be 10 minutes late to the meeting.",
    "Let me know when you get home.",
    "Happy birthday! Hope you have a great day.",
    "The weather is nice today, want to go for a walk?",
    "I sent you the files, let me know if you got them.",
    "See you at the game tomorrow!",
    "Can we reschedule our appointment to next week?",
    "I forgot my keys, can you let me in?",
    "The presentation went well, thanks for the feedback.",
    "Do you have the recipe for that cake?",
    "Running late, be there in 5 minutes.",
]
# Candidate values for each placeholder.
PHONE_PLACEHOLDERS = ["555-0123", "800-555-1234", "555-9876"]
URL_PLACEHOLDERS = [
    "http://bit.ly/abc123",
    "http://short.url/xyz",
    "http://link.co/prize",
]
AMOUNT_PLACEHOLDERS = ["1000", "5000", "10000", "500", "25000"]
PRODUCT_PLACEHOLDERS = ["iPad", "TV", "Gift Card", "Headphones", "Watch"]
NAME_PLACEHOLDERS = ["John", "Sarah", "Mike", "Jessica", "David"]


def generate_synthetic_data(num_samples: int = 1000) -> dict:
    """Generate synthetic spam and ham messages for data augmentation.

    Each sample is spam or ham with equal probability. Spam is built from a
    random entry of ``SPAM_TEMPLATES`` with its placeholders filled in; ham is
    a random entry of ``HAM_TEMPLATES`` used as-is. The output is fully
    deterministic (fixed seed 42) and the process-wide ``random`` state is
    left untouched.

    Args:
        num_samples: Number of messages to generate. Zero or a negative
            value yields empty lists.

    Returns:
        ``{"text": [...], "label": [...]}`` with parallel lists, where label
        ``1`` marks spam and ``0`` marks ham.
    """
    import random

    # A private seeded generator produces the same draw sequence as seeding
    # the module-level RNG, but does not reset global random state for the
    # rest of the program.
    rng = random.Random(42)

    # (token, candidate values), in the order values are drawn. Only the
    # braces of "{amount}" are replaced so the leading "$" in the templates
    # survives ("You've won $1000!" rather than "You've won 1000!").
    substitutions = (
        ("{amount}", AMOUNT_PLACEHOLDERS),
        ("{phone}", PHONE_PLACEHOLDERS),
        ("{url}", URL_PLACEHOLDERS),
        ("{product}", PRODUCT_PLACEHOLDERS),
        ("{name}", NAME_PLACEHOLDERS),
        # "{number}" is a short-code slot; it reuses the phone-number pool.
        ("{number}", PHONE_PLACEHOLDERS),
    )

    texts = []
    labels = []
    for _ in range(num_samples):
        if rng.random() < 0.5:
            # Generate spam
            msg = rng.choice(SPAM_TEMPLATES)
            for token, candidates in substitutions:
                # A value is drawn for every token, even ones this template
                # does not contain, so the sequence of labels and templates
                # does not depend on which placeholders a template uses.
                msg = msg.replace(token, rng.choice(candidates))
            texts.append(msg)
            labels.append(1)
        else:
            # Generate ham
            texts.append(rng.choice(HAM_TEMPLATES))
            labels.append(0)
    return {"text": texts, "label": labels}
# ── Training ───────────────────────────────────────────────────────────────
def _evaluate(model, data_loader, device):
    """Return ``(mean batch loss, accuracy)`` of ``model`` over ``data_loader``.

    Switches the model to eval mode and leaves it there; the caller is
    responsible for switching back to train mode.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in data_loader:
            labels = batch["label"].to(device)
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=labels,
            )
            loss_sum += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)
    return loss_sum / len(data_loader), correct / seen


def train_model(
    train_dataset: Dataset,
    val_dataset: Dataset,
    output_dir: Path,
    epochs: int = NUM_EPOCHS,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
):
    """Fine-tune DistilBERT for SMS spam classification.

    After every epoch the model is evaluated on ``val_dataset``; whenever the
    validation loss improves, the model and tokenizer are written to
    ``output_dir / "best_model"``. After the last epoch the model and
    tokenizer are written to ``output_dir / "final_model"``.

    Args:
        train_dataset: Training split with ``"text"`` and ``"label"`` columns.
        val_dataset: Validation split with the same columns.
        output_dir: Directory under which the checkpoints are saved.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Peak AdamW learning rate (linear warmup, then linear
            decay over all training steps).

    Returns:
        ``(model, tokenizer, val_acc)``. The model carries the weights of the
        last epoch (not necessarily the best checkpoint) and ``val_acc`` is
        the validation accuracy of the last epoch, or ``0.0`` if no epoch ran.

    Raises:
        ValueError: If either dataset is empty.
    """
    # Fail before downloading/training instead of dividing by zero when the
    # first epoch's averages are computed.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("train_dataset and val_dataset must both be non-empty")

    print("\n--- Training DistilBERT ---")
    print(f" Model: {MODEL_NAME}")
    print(f" Epochs: {epochs}, Batch size: {batch_size}, LR: {learning_rate}")
    device = torch.device(
        "mps"
        if torch.backends.mps.is_available()
        else "cuda"
        if torch.cuda.is_available()
        else "cpu"
    )
    print(f" Device: {device}")

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = DistilBertForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2
    )
    model.to(device)

    # Tokenize datasets
    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )

    tensor_columns = ["input_ids", "attention_mask", "label"]
    train_dataset = train_dataset.map(tokenize_fn, batched=True)
    val_dataset = val_dataset.map(tokenize_fn, batched=True)
    train_dataset.set_format("torch", columns=tensor_columns)
    val_dataset.set_format("torch", columns=tensor_columns)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    # Training loop
    model.train()
    best_val_loss = float("inf")
    best_model_path = output_dir / "best_model"
    # Defined up front so the summary and return value below are valid even
    # when the epoch loop does not run (epochs == 0).
    val_acc = 0.0
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["label"].to(device),
            )
            loss = outputs.loss
            loss.backward()
            # Cap the gradient norm at 1.0 before the parameter update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)

        # Validation
        avg_val_loss, val_acc = _evaluate(model, val_loader, device)
        model.train()
        print(
            f" Epoch {epoch + 1}/{epochs}: "
            f"Train Loss={avg_loss:.4f}, Val Loss={avg_val_loss:.4f}, "
            f"Val Acc={val_acc:.4f}"
        )

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(best_model_path)
            # Keep the tokenizer next to the weights so the best checkpoint
            # can be loaded on its own, like the final one.
            tokenizer.save_pretrained(best_model_path)
            print(f" -> Saved best model (val_loss={avg_val_loss:.4f})")

    # Save final model too
    final_model_path = output_dir / "final_model"
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print("\n--- Training complete ---")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Final val accuracy: {val_acc:.4f}")
    print(f" Model saved to: {final_model_path}")
    return model, tokenizer, val_acc
# ── ONNX Export ────────────────────────────────────────────────────────────
def export_onnx(
    model,
    tokenizer,
    output_dir: Path,
    quantize: bool = False,
):
    """Export the model to ONNX format.

    Writes into ``output_dir / "onnx_model"``:

    * ``model.onnx`` - the exported graph with inputs ``input_ids`` and
      ``attention_mask`` and output ``logits``. Only the batch dimension is
      declared dynamic; the example input is padded to ``MAX_LENGTH``.
    * ``model_quantized.onnx`` - only when ``quantize`` is true and the
      onnxruntime quantization tools can be imported.
    * the tokenizer files and ``model_metadata.json``.

    The model is switched to eval mode and moved to CPU for the export, then
    moved back to the device it was on (also when the export fails).

    Args:
        model: Sequence-classification model to export.
        tokenizer: Tokenizer used to build the example input; saved alongside
            the ONNX file.
        output_dir: Parent directory of the ``onnx_model`` folder.
        quantize: Also write a dynamically quantized copy of the model.

    Returns:
        Path of the ``onnx_model`` directory.
    """
    print("\n--- Exporting to ONNX ---")
    onnx_dir = output_dir / "onnx_model"
    onnx_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = onnx_dir / "model.onnx"

    model.eval()
    original_device = next(model.parameters()).device
    model.to("cpu")  # ONNX export needs CPU
    try:
        # Simple approach: use torch.onnx.export with dynamic axes
        print(" Using torch.onnx.export for ONNX conversion...")
        # Create dummy inputs with batch size 1
        dummy_input = tokenizer(
            "This is a test message for spam detection.",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            # Plain string path: accepted by every torch.onnx.export version.
            str(onnx_path),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=True,
        )
    finally:
        # Do not leave the caller's model stranded on CPU.
        model.to(original_device)
    print(f" ONNX model exported to: {onnx_path}")
    print(f" Model size: {onnx_path.stat().st_size / 1024 / 1024:.1f} MB")

    # Optional: INT8 quantization
    if quantize:
        try:
            from onnxruntime.quantization import quantize_dynamic
            from onnxruntime.quantization.quantize import QuantType

            print("\n --- Quantizing to INT8 ---")
            quantized_path = onnx_dir / "model_quantized.onnx"
            quantize_dynamic(
                model_input=onnx_path,
                model_output=quantized_path,
                weight_type=QuantType.QUINT8,
            )
            print(f" Quantized model: {quantized_path}")
            print(f" Quantized size: {quantized_path.stat().st_size / 1024 / 1024:.1f} MB")
        except ImportError:
            # Quantization is best-effort: the plain ONNX model is still usable.
            print(" Skipping quantization (onnxruntime quantization not available)")

    # Save tokenizer files needed by Node.js tokenizer
    tokenizer.save_pretrained(str(onnx_dir))

    # Save model metadata
    metadata = {
        "version": "1.0.0",
        "model_name": MODEL_NAME,
        "task": "sms-spam-classification",
        "max_length": MAX_LENGTH,
        "num_labels": 2,
        "label2id": LABEL2ID,
        # Integer keys are written as strings by json.dump.
        "id2label": ID2LABEL,
        "framework": "pytorch",
        "export_format": "onnx",
    }
    metadata_path = onnx_dir / "model_metadata.json"
    # Explicit encoding so the file does not depend on the platform default.
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print(f" Metadata saved to: {metadata_path}")
    return onnx_dir
# ── Main ───────────────────────────────────────────────────────────────────
def _load_training_data(augment: bool):
    """Load the train/val splits, optionally extending train with synthetic samples."""
    # Step 1: Load dataset
    print("\n--- Step 1: Loading Dataset ---")
    dataset_dict = load_dataset()

    # Step 2: Optional synthetic augmentation
    if augment:
        print("\n--- Step 2: Data Augmentation ---")
        synthetic_ds = Dataset.from_dict(generate_synthetic_data(num_samples=1000))
        train_ds = dataset_dict["train"]
        # Cast label to match train dataset's label type
        synthetic_ds = synthetic_ds.cast_column("label", train_ds.features["label"])
        augmented = concatenate_datasets([train_ds, synthetic_ds])
        dataset_dict["train"] = augmented.shuffle(seed=42)
        print(f" Augmented train set: {len(dataset_dict['train'])} samples")
    return dataset_dict


def _load_model_for_export(output_dir: Path):
    """Load a previously saved final model, falling back to the base checkpoint."""
    final_model_path = output_dir / "final_model"
    if final_model_path.exists():
        print(f" Loading model from {final_model_path}")
        source = str(final_model_path)
    else:
        print(f" No existing model found at {final_model_path}")
        print(f" Loading base {MODEL_NAME} for export...")
        source = MODEL_NAME
    model = DistilBertForSequenceClassification.from_pretrained(source, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(source)
    return model, tokenizer


def main():
    """Command-line entry point: train (or load) the classifier, then export it.

    Without ``--skip-training`` the dataset is loaded (and optionally
    augmented), the model is fine-tuned and the result is exported to ONNX.
    With ``--skip-training`` no dataset is loaded; the model saved under
    ``<output-dir>/final_model`` is exported, or the base checkpoint if no
    such model exists.
    """
    parser = argparse.ArgumentParser(description="Train DistilBERT SMS Spam Classifier")
    parser.add_argument("--output-dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE, help="Learning rate")
    parser.add_argument("--quantize", action="store_true", help="Quantize to INT8")
    parser.add_argument("--skip-training", action="store_true", help="Skip training, only export")
    parser.add_argument("--augment", action="store_true", help="Include synthetic data augmentation")
    args = parser.parse_args()
    # Reject values that could only fail later, after the dataset download.
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print("=" * 60)
    print("DistilBERT SMS Spam Classifier Training Pipeline")
    print("=" * 60)

    # None means "nothing was evaluated", which is different from 0% accuracy.
    val_acc: Optional[float] = None
    if not args.skip_training:
        # The dataset is only needed for training, so it is not downloaded
        # (or augmented) in export-only runs.
        dataset_dict = _load_training_data(args.augment)

        # Step 3: Train model
        print("\n--- Step 3: Training ---")
        model, tokenizer, val_acc = train_model(
            train_dataset=dataset_dict["train"],
            val_dataset=dataset_dict["val"],
            output_dir=output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
        )
        if val_acc < 0.95:
            print(f"\n WARNING: Validation accuracy {val_acc:.4f} is below 0.95 target")
            print(" Consider increasing epochs or adjusting hyperparameters")
    else:
        print("\n--- Step 3: Skipping training (--skip-training) ---")
        # Load existing model for export
        model, tokenizer = _load_model_for_export(output_dir)

    # Step 4: Export to ONNX
    print("\n--- Step 4: ONNX Export ---")
    onnx_dir = export_onnx(model, tokenizer, output_dir, quantize=args.quantize)
    print("\n" + "=" * 60)
    print("Pipeline complete!")
    print(f" ONNX model: {onnx_dir / 'model.onnx'}")
    print(f" Tokenizer: {onnx_dir}")
    if val_acc is None:
        print(" Validation accuracy: n/a (training skipped)")
    else:
        print(f" Validation accuracy: {val_acc:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()