#!/usr/bin/env python3
"""
Fine-tune DistilBERT on the SMS Spam Collection dataset and export to ONNX.

Usage:
    python train.py [--output-dir ./output] [--epochs 3] [--batch-size 32]
                    [--lr 2e-5] [--quantize] [--skip-training] [--augment]

The SMS Spam Collection dataset is loaded from the HuggingFace Hub
(dataset name: see HF_DATASET_NAME). The fine-tuned model is exported to
ONNX format for fast CPU inference.
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from torch.optim.adamw import AdamW
from datasets import Dataset, DatasetDict, concatenate_datasets
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    DistilBertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

# ── Constants ──────────────────────────────────────────────────────────────
MODEL_NAME = "distilbert-base-uncased"
MAX_LENGTH = 128
LABEL2ID = {"ham": 0, "spam": 1}
ID2LABEL = {0: "ham", 1: "spam"}
BATCH_SIZE = 32
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
WARMUP_RATIO = 0.1
EVAL_SPLIT = 0.15  # 15% for validation

# ── Data Loading ───────────────────────────────────────────────────────────
HF_DATASET_NAME = "sms_spam"


def load_dataset(data_dir: Optional[Path] = None) -> DatasetDict:
    """Load the SMS Spam Collection dataset from HuggingFace Hub.

    Args:
        data_dir: Currently unused; kept so existing callers that pass it
            keep working.

    Returns:
        A DatasetDict with "train" and "val" splits, where "val" holds
        EVAL_SPLIT of the rows (fixed split seed 42). The "sms" column is
        renamed to "text".
    """
    print(f" Loading {HF_DATASET_NAME} from HuggingFace Hub...")
    # Imported locally under an alias because this function shadows the name.
    from datasets import load_dataset as hf_load

    hf_dataset = hf_load(HF_DATASET_NAME, split="train")

    # Rename columns to match our schema
    hf_dataset = hf_dataset.rename_column("sms", "text")

    # HF dataset uses 0=ham, 1=spam already, so the label sum is the spam count.
    total_count = len(hf_dataset)
    spam_count = sum(hf_dataset["label"])
    print(f" Total messages: {total_count}")
    print(f" Spam: {spam_count}, Ham: {total_count - spam_count}")

    # Split into train/val
    split = hf_dataset.train_test_split(test_size=EVAL_SPLIT, seed=42)
    return DatasetDict({"train": split["train"], "val": split["test"]})


# ── Synthetic Data Augmentation ────────────────────────────────────────────
SPAM_TEMPLATES = [
    "Congratulations! You've won ${amount}! Call {phone} to claim.",
    "FREE {product}! Text YES to {number} now!",
    "You have been selected for a ${amount} loan. Reply NOW.",
    "URGENT: Your account will be closed. Call {phone} immediately.",
    "Win a FREE iPhone! Visit {url} to claim your prize.",
    "Dear customer, your ${amount} refund is ready. Click {url}.",
    "Hot singles in your area! Text {number} to meet them.",
    "You've been pre-approved for a credit line of ${amount}.",
    "ACT NOW! Limited time offer: {product} for just ${amount}.",
    "Your package delivery failed. Call {phone} to reschedule.",
    "BREAKING: You won the lottery! ${amount} prize. Call {phone}.",
    "FREE entry to win ${amount} cash prize! Text WIN to {number}.",
    "Your bank account has been compromised. Call {phone} now.",
    "Make ${amount} per week working from home. Visit {url}.",
    "You have a new message from {name}. Reply to view.",
]

HAM_TEMPLATES = [
    "Hey, are we still on for dinner tonight?",
    "Can you pick up milk on your way home?",
    "Meeting at 3pm in conference room B.",
    "Thanks for the help yesterday!",
    "I'll be 10 minutes late to the meeting.",
    "Let me know when you get home.",
    "Happy birthday! Hope you have a great day.",
    "The weather is nice today, want to go for a walk?",
    "I sent you the files, let me know if you got them.",
    "See you at the game tomorrow!",
    "Can we reschedule our appointment to next week?",
    "I forgot my keys, can you let me in?",
    "The presentation went well, thanks for the feedback.",
    "Do you have the recipe for that cake?",
    "Running late, be there in 5 minutes.",
]

PHONE_PLACEHOLDERS = ["555-0123", "800-555-1234", "555-9876"]
URL_PLACEHOLDERS = ["http://bit.ly/abc123", "http://short.url/xyz", "http://link.co/prize"]
AMOUNT_PLACEHOLDERS = ["1000", "5000", "10000", "500", "25000"]
PRODUCT_PLACEHOLDERS = ["iPad", "TV", "Gift Card", "Headphones", "Watch"]
NAME_PLACEHOLDERS = ["John", "Sarah", "Mike", "Jessica", "David"]


def generate_synthetic_data(num_samples: int = 1000) -> dict:
    """Generate synthetic spam and ham messages for data augmentation.

    Each sample is spam or ham with probability 0.5. Spam messages are built
    by filling a random SPAM_TEMPLATES entry; ham messages are taken verbatim
    from HAM_TEMPLATES. Output is deterministic (fixed seed 42).

    Args:
        num_samples: Number of messages to generate.

    Returns:
        A dict with parallel lists "text" (str) and "label" (1=spam, 0=ham).
    """
    import random

    # Private generator: same draw sequence as seeding the module-level RNG
    # with 42, but without clobbering global random state for the caller.
    rng = random.Random(42)

    texts = []
    labels = []
    for _ in range(num_samples):
        if rng.random() < 0.5:
            # Generate spam
            template = rng.choice(SPAM_TEMPLATES)
            # Replace only "{amount}" so the literal "$" in the template is
            # kept ("$1000"); replacing "${amount}" dropped the dollar sign.
            msg = template.replace("{amount}", rng.choice(AMOUNT_PLACEHOLDERS))
            msg = msg.replace("{phone}", rng.choice(PHONE_PLACEHOLDERS))
            msg = msg.replace("{url}", rng.choice(URL_PLACEHOLDERS))
            msg = msg.replace("{product}", rng.choice(PRODUCT_PLACEHOLDERS))
            msg = msg.replace("{name}", rng.choice(NAME_PLACEHOLDERS))
            msg = msg.replace("{number}", rng.choice(PHONE_PLACEHOLDERS))
            texts.append(msg)
            labels.append(1)
        else:
            # Generate ham
            texts.append(rng.choice(HAM_TEMPLATES))
            labels.append(0)

    return {"text": texts, "label": labels}


# ── Training ───────────────────────────────────────────────────────────────
def _evaluate(model, data_loader, device):
    """Return (average loss, accuracy) of *model* over *data_loader*.

    Puts the model in eval mode and leaves it there; the caller is
    responsible for switching back to train mode. With an empty loader the
    loss is NaN (so it never counts as a new best) and the accuracy is 0.0.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss_sum += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    num_batches = len(data_loader)
    avg_loss = loss_sum / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else 0.0
    return avg_loss, accuracy


def train_model(
    train_dataset: Dataset,
    val_dataset: Dataset,
    output_dir: Path,
    epochs: int = NUM_EPOCHS,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
):
    """Fine-tune DistilBERT for SMS spam classification.

    Args:
        train_dataset: Training split with "text" and "label" columns.
        val_dataset: Validation split with the same columns.
        output_dir: Directory that receives the "best_model" (lowest
            validation loss) and "final_model" (last epoch) checkpoints.
        epochs: Number of passes over the training data.
        batch_size: Batch size for both training and validation.
        learning_rate: Peak AdamW learning rate (linear warmup, then decay).

    Returns:
        A tuple (model, tokenizer, val_acc) where model holds the weights
        after the final epoch and val_acc is the validation accuracy of the
        last epoch (0.0 if no epoch ran).
    """
    print("\n--- Training DistilBERT ---")
    print(f" Model: {MODEL_NAME}")
    print(f" Epochs: {epochs}, Batch size: {batch_size}, LR: {learning_rate}")

    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
    print(f" Device: {device}")

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = DistilBertForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2
    )
    model.to(device)

    # Tokenize datasets
    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )

    train_dataset = train_dataset.map(tokenize_fn, batched=True)
    val_dataset = val_dataset.map(tokenize_fn, batched=True)

    train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    val_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    # Training loop
    model.train()
    best_val_loss = float("inf")
    best_model_path = output_dir / "best_model"
    # Defined up front so the summary and return value are valid even when
    # the loop body never runs (epochs == 0).
    val_acc = 0.0
    num_train_batches = len(train_loader)

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_train_batches if num_train_batches else 0.0

        # Validation
        avg_val_loss, val_acc = _evaluate(model, val_loader, device)
        model.train()

        print(
            f" Epoch {epoch + 1}/{epochs}: "
            f"Train Loss={avg_loss:.4f}, Val Loss={avg_val_loss:.4f}, "
            f"Val Acc={val_acc:.4f}"
        )

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(best_model_path)
            # Keep the tokenizer next to the weights so the best checkpoint
            # can be loaded on its own, like final_model.
            tokenizer.save_pretrained(best_model_path)
            print(f" -> Saved best model (val_loss={avg_val_loss:.4f})")

    # Save final model too
    final_model_path = output_dir / "final_model"
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    print("\n--- Training complete ---")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Final val accuracy: {val_acc:.4f}")
    print(f" Model saved to: {final_model_path}")

    return model, tokenizer, val_acc


# ── ONNX Export ────────────────────────────────────────────────────────────
def export_onnx(
    model,
    tokenizer,
    output_dir: Path,
    quantize: bool = False,
):
    """Export the model to ONNX format.

    Writes "model.onnx", the tokenizer files and "model_metadata.json" into
    ``output_dir / "onnx_model"``; with ``quantize=True`` also writes
    "model_quantized.onnx" when onnxruntime's quantization tools import.

    Args:
        model: The sequence-classification model to export. It is switched
            to eval mode; it is moved to CPU for the export and moved back
            to its original device afterwards.
        tokenizer: Tokenizer used to build the dummy input and saved next to
            the ONNX model.
        output_dir: Parent directory for the "onnx_model" folder.
        quantize: Whether to additionally produce a dynamically quantized
            model.

    Returns:
        Path of the directory containing the exported artifacts.
    """
    print("\n--- Exporting to ONNX ---")

    onnx_dir = output_dir / "onnx_model"
    onnx_dir.mkdir(parents=True, exist_ok=True)

    model.eval()
    original_device = next(model.parameters()).device

    # Simple approach: use torch.onnx.export with dynamic axes
    print(" Using torch.onnx.export for ONNX conversion...")

    onnx_path = onnx_dir / "model.onnx"

    # Create dummy inputs with batch size 1
    dummy_input = tokenizer(
        "This is a test message for spam detection.",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    model.to("cpu")  # ONNX export needs CPU
    try:
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            # str() keeps this working on torch versions that only accept a
            # string path or file object here.
            str(onnx_path),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=True,
        )
    finally:
        # Do not leave the caller's model stranded on CPU.
        model.to(original_device)

    print(f" ONNX model exported to: {onnx_path}")
    print(f" Model size: {onnx_path.stat().st_size / 1024 / 1024:.1f} MB")

    # Optional: INT8 quantization
    if quantize:
        try:
            from onnxruntime.quantization import quantize_dynamic
            from onnxruntime.quantization.quantize import QuantType
        except ImportError:
            print(" Skipping quantization (onnxruntime quantization not available)")
        else:
            print("\n --- Quantizing to INT8 ---")
            quantized_path = onnx_dir / "model_quantized.onnx"
            quantize_dynamic(
                model_input=onnx_path,
                model_output=quantized_path,
                weight_type=QuantType.QUINT8,
            )
            print(f" Quantized model: {quantized_path}")
            print(f" Quantized size: {quantized_path.stat().st_size / 1024 / 1024:.1f} MB")

    # Save tokenizer files needed by Node.js tokenizer
    tokenizer.save_pretrained(str(onnx_dir))

    # Save model metadata
    metadata = {
        "version": "1.0.0",
        "model_name": MODEL_NAME,
        "task": "sms-spam-classification",
        "max_length": MAX_LENGTH,
        "num_labels": 2,
        "label2id": LABEL2ID,
        "id2label": ID2LABEL,
        "framework": "pytorch",
        "export_format": "onnx",
    }
    metadata_path = onnx_dir / "model_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print(f" Metadata saved to: {metadata_path}")

    return onnx_dir


# ── Main ───────────────────────────────────────────────────────────────────
def main():
    """Run the pipeline: load data, optionally augment, train, export to ONNX.

    With --skip-training the dataset is not loaded at all; the model is taken
    from ``<output-dir>/final_model`` if it exists, otherwise the untrained
    base model is exported.
    """
    parser = argparse.ArgumentParser(description="Train DistilBERT SMS Spam Classifier")
    parser.add_argument("--output-dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE, help="Learning rate")
    parser.add_argument("--quantize", action="store_true", help="Quantize to INT8")
    parser.add_argument("--skip-training", action="store_true", help="Skip training, only export")
    parser.add_argument("--augment", action="store_true", help="Include synthetic data augmentation")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("DistilBERT SMS Spam Classifier Training Pipeline")
    print("=" * 60)

    # None means "not measured" (training skipped), as opposed to a real 0.0.
    val_acc = None

    if not args.skip_training:
        # Step 1: Load dataset (only needed when we actually train)
        print("\n--- Step 1: Loading Dataset ---")
        dataset_dict = load_dataset()

        # Step 2: Optional synthetic augmentation
        if args.augment:
            print("\n--- Step 2: Data Augmentation ---")
            synthetic = generate_synthetic_data(num_samples=1000)
            synthetic_ds = Dataset.from_dict(synthetic)
            # Cast label to match train dataset's label type
            train_ds = dataset_dict["train"]
            synthetic_ds = synthetic_ds.cast_column("label", train_ds.features["label"])
            augmented = concatenate_datasets([train_ds, synthetic_ds])
            dataset_dict["train"] = augmented.shuffle(seed=42)
            print(f" Augmented train set: {len(dataset_dict['train'])} samples")

        # Step 3: Train model
        print("\n--- Step 3: Training ---")
        model, tokenizer, val_acc = train_model(
            train_dataset=dataset_dict["train"],
            val_dataset=dataset_dict["val"],
            output_dir=output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
        )

        if val_acc < 0.95:
            print(f"\n WARNING: Validation accuracy {val_acc:.4f} is below 0.95 target")
            print(" Consider increasing epochs or adjusting hyperparameters")
    else:
        print("\n--- Step 3: Skipping training (--skip-training) ---")
        # Load existing model for export
        final_model_path = output_dir / "final_model"
        if final_model_path.exists():
            print(f" Loading model from {final_model_path}")
            model = DistilBertForSequenceClassification.from_pretrained(
                str(final_model_path), num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(str(final_model_path))
        else:
            print(f" No existing model found at {final_model_path}")
            print(f" Loading base {MODEL_NAME} for export...")
            model = DistilBertForSequenceClassification.from_pretrained(
                MODEL_NAME, num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Step 4: Export to ONNX
    print("\n--- Step 4: ONNX Export ---")
    onnx_dir = export_onnx(model, tokenizer, output_dir, quantize=args.quantize)

    print("\n" + "=" * 60)
    print("Pipeline complete!")
    print(f" ONNX model: {onnx_dir / 'model.onnx'}")
    print(f" Tokenizer: {onnx_dir}")
    if val_acc is None:
        # Printing 0.0000 here would look like a measured (terrible) result.
        print(" Validation accuracy: n/a (training skipped)")
    else:
        print(f" Validation accuracy: {val_acc:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()