445 lines
16 KiB
Python
445 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
Fine-tune DistilBERT on the SMS Spam Collection dataset and export it to ONNX.

Usage:
    python train.py [--output-dir ./output] [--epochs 3] [--batch-size 32]
                    [--lr 2e-5] [--quantize] [--skip-training] [--augment]

The SMS Spam Collection dataset (originally from the UCI ML Repository) is
downloaded from the Hugging Face Hub as ``sms_spam``.
The fine-tuned model is exported to ONNX format for fast CPU inference.
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.optim.adamw import AdamW
|
|
from datasets import Dataset, DatasetDict, concatenate_datasets
|
|
from torch.utils.data import DataLoader
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
DistilBertForSequenceClassification,
|
|
get_linear_schedule_with_warmup,
|
|
)
|
|
|
|
# ── Constants ──────────────────────────────────────────────────────────────

# Base checkpoint and tokenizer settings.
MODEL_NAME = "distilbert-base-uncased"
MAX_LENGTH = 128  # every message is padded/truncated to this many tokens

# Class-index mapping for the binary classifier (ham = 0, spam = 1).
# ID2LABEL is the inverse of LABEL2ID, built from it so the two cannot drift.
LABEL2ID = {"ham": 0, "spam": 1}
ID2LABEL = {index: label for label, index in LABEL2ID.items()}

# Defaults for the --batch-size / --epochs / --lr command-line options.
BATCH_SIZE = 32
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5

WARMUP_RATIO = 0.1  # fraction of all optimizer steps used for LR warmup
EVAL_SPLIT = 0.15  # 15% for validation

# ── Data Loading ───────────────────────────────────────────────────────────

HF_DATASET_NAME = "sms_spam"
|
|
|
|
|
|
def load_dataset(data_dir: Optional[Path] = None) -> DatasetDict:
    """Load the SMS Spam Collection dataset from the Hugging Face Hub.

    The Hub's single ``train`` split is renamed to our schema (``sms`` ->
    ``text``) and split into train/validation subsets with a fixed seed.

    Args:
        data_dir: Optional directory used as the Hugging Face ``datasets``
            cache. ``None`` (default) uses the library's default cache.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"val"`` splits; ``"val"``
        holds an ``EVAL_SPLIT`` fraction of the messages.
    """
    # Local import: this function's own name shadows ``datasets.load_dataset``.
    from datasets import load_dataset as hf_load

    print(f" Loading {HF_DATASET_NAME} from HuggingFace Hub...")
    cache_dir = str(data_dir) if data_dir is not None else None
    hf_dataset = hf_load(HF_DATASET_NAME, split="train", cache_dir=cache_dir)

    # Rename columns to match our schema
    hf_dataset = hf_dataset.rename_column("sms", "text")
    # HF dataset uses 0=ham, 1=spam already (same as LABEL2ID)

    num_total = len(hf_dataset)
    num_spam = sum(hf_dataset["label"])  # read the label column once
    print(f" Total messages: {num_total}")
    print(f" Spam: {num_spam}, Ham: {num_total - num_spam}")

    # Split into train/val with a fixed seed for reproducibility.
    split = hf_dataset.train_test_split(test_size=EVAL_SPLIT, seed=42)
    return DatasetDict({"train": split["train"], "val": split["test"]})
|
|
|
|
|
|
# ── Synthetic Data Augmentation ────────────────────────────────────────────

# "{...}" tokens are placeholders filled in by generate_synthetic_data().
# A "$" directly before "{amount}" is a literal dollar sign, not part of the token.
SPAM_TEMPLATES = [
    "Congratulations! You've won ${amount}! Call {phone} to claim.",
    "FREE {product}! Text YES to {number} now!",
    "You have been selected for a ${amount} loan. Reply NOW.",
    "URGENT: Your account will be closed. Call {phone} immediately.",
    "Win a FREE iPhone! Visit {url} to claim your prize.",
    "Dear customer, your ${amount} refund is ready. Click {url}.",
    "Hot singles in your area! Text {number} to meet them.",
    "You've been pre-approved for a credit line of ${amount}.",
    "ACT NOW! Limited time offer: {product} for just ${amount}.",
    "Your package delivery failed. Call {phone} to reschedule.",
    "BREAKING: You won the lottery! ${amount} prize. Call {phone}.",
    "FREE entry to win ${amount} cash prize! Text WIN to {number}.",
    "Your bank account has been compromised. Call {phone} now.",
    "Make ${amount} per week working from home. Visit {url}.",
    "You have a new message from {name}. Reply to view.",
]

# Ham messages are used verbatim (no placeholders).
HAM_TEMPLATES = [
    "Hey, are we still on for dinner tonight?",
    "Can you pick up milk on your way home?",
    "Meeting at 3pm in conference room B.",
    "Thanks for the help yesterday!",
    "I'll be 10 minutes late to the meeting.",
    "Let me know when you get home.",
    "Happy birthday! Hope you have a great day.",
    "The weather is nice today, want to go for a walk?",
    "I sent you the files, let me know if you got them.",
    "See you at the game tomorrow!",
    "Can we reschedule our appointment to next week?",
    "I forgot my keys, can you let me in?",
    "The presentation went well, thanks for the feedback.",
    "Do you have the recipe for that cake?",
    "Running late, be there in 5 minutes.",
]

PHONE_PLACEHOLDERS = ["555-0123", "800-555-1234", "555-9876"]
URL_PLACEHOLDERS = ["http://bit.ly/abc123", "http://short.url/xyz", "http://link.co/prize"]
AMOUNT_PLACEHOLDERS = ["1000", "5000", "10000", "500", "25000"]
PRODUCT_PLACEHOLDERS = ["iPad", "TV", "Gift Card", "Headphones", "Watch"]
NAME_PLACEHOLDERS = ["John", "Sarah", "Mike", "Jessica", "David"]


def generate_synthetic_data(num_samples: int = 1000, seed: int = 42) -> dict:
    """Generate synthetic spam and ham messages for data augmentation.

    Each sample is spam or ham with probability 0.5. A spam message is a
    random ``SPAM_TEMPLATES`` entry with every ``{placeholder}`` token
    substituted by a random value. A ham message is a random
    ``HAM_TEMPLATES`` entry.

    Args:
        num_samples: Number of messages to generate (``<= 0`` yields none).
        seed: Seed for a private random generator, making the output
            reproducible. Defaults to 42.

    Returns:
        ``{"text": [...], "label": [...]}`` with label 1 = spam, 0 = ham, in
        the shape accepted by ``datasets.Dataset.from_dict``.
    """
    import random

    # A private generator gives reproducible output without reseeding the
    # process-wide ``random`` module as a side effect.
    rng = random.Random(seed)
    texts = []
    labels = []

    for _ in range(num_samples):
        if rng.random() < 0.5:
            msg = rng.choice(SPAM_TEMPLATES)
            # Replace "{amount}" (not "${amount}") so the literal "$" survives.
            # Every placeholder draws a value even when the template lacks it,
            # so the random sequence does not depend on the chosen template.
            msg = msg.replace("{amount}", rng.choice(AMOUNT_PLACEHOLDERS))
            msg = msg.replace("{phone}", rng.choice(PHONE_PLACEHOLDERS))
            msg = msg.replace("{url}", rng.choice(URL_PLACEHOLDERS))
            msg = msg.replace("{product}", rng.choice(PRODUCT_PLACEHOLDERS))
            msg = msg.replace("{name}", rng.choice(NAME_PLACEHOLDERS))
            msg = msg.replace("{number}", rng.choice(PHONE_PLACEHOLDERS))
            texts.append(msg)
            labels.append(1)
        else:
            texts.append(rng.choice(HAM_TEMPLATES))
            labels.append(0)

    return {"text": texts, "label": labels}
|
|
|
|
|
|
# ── Training ───────────────────────────────────────────────────────────────


def _evaluate(model, loader, device):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` over ``loader``.

    Runs under ``torch.no_grad()`` and leaves the model in eval mode; the
    caller switches back to train mode. An empty loader yields
    ``(float("inf"), 0.0)`` instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss_sum += outputs.loss.item()

            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        return float("inf"), 0.0
    return loss_sum / len(loader), correct / seen


def train_model(
    train_dataset: Dataset,
    val_dataset: Dataset,
    output_dir: Path,
    epochs: int = NUM_EPOCHS,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
):
    """Fine-tune DistilBERT for SMS spam classification.

    Args:
        train_dataset: Dataset with ``text`` and ``label`` columns to train on.
        val_dataset: Dataset with the same columns, evaluated after each epoch.
        output_dir: Receives ``best_model/`` (lowest validation loss seen) and
            ``final_model/`` (after the last epoch). Each holds model weights
            and tokenizer files.
        epochs: Number of passes over the training data.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Peak AdamW learning rate. It warms up linearly over
            ``WARMUP_RATIO`` of all steps, then decays linearly.

    Returns:
        ``(model, tokenizer, val_acc)``: the model after the final epoch (not
        necessarily the best checkpoint), its tokenizer, and the last epoch's
        validation accuracy (0.0 if no epoch ran).
    """
    print("\n--- Training DistilBERT ---")
    print(f" Model: {MODEL_NAME}")
    print(f" Epochs: {epochs}, Batch size: {batch_size}, LR: {learning_rate}")

    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
    print(f" Device: {device}")

    # Load tokenizer and model. Passing the label maps stores "ham"/"spam" in
    # the model config, so saved checkpoints do not fall back to generic
    # LABEL_0/LABEL_1 names.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = DistilBertForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2, id2label=ID2LABEL, label2id=LABEL2ID
    )
    model.to(device)

    def tokenize_fn(examples):
        # Pad/truncate to MAX_LENGTH, the same fixed length the ONNX export uses.
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )

    train_dataset = train_dataset.map(tokenize_fn, batched=True)
    val_dataset = val_dataset.map(tokenize_fn, batched=True)
    train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    val_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer and linear warmup/decay scheduler (stepped once per batch).
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    model.train()
    best_val_loss = float("inf")
    best_model_path = output_dir / "best_model"
    val_acc = 0.0  # defined even when epochs == 0 so the summary below works

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / max(len(train_loader), 1)

        avg_val_loss, val_acc = _evaluate(model, val_loader, device)
        model.train()  # _evaluate leaves the model in eval mode

        print(
            f" Epoch {epoch + 1}/{epochs}: "
            f"Train Loss={avg_loss:.4f}, Val Loss={avg_val_loss:.4f}, "
            f"Val Acc={val_acc:.4f}"
        )

        # Keep a checkpoint of the lowest-validation-loss epoch, saved with
        # its tokenizer so the directory can be loaded on its own.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(best_model_path)
            tokenizer.save_pretrained(best_model_path)
            print(f" -> Saved best model (val_loss={avg_val_loss:.4f})")

    # Save final model too
    final_model_path = output_dir / "final_model"
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    print("\n--- Training complete ---")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Final val accuracy: {val_acc:.4f}")
    print(f" Model saved to: {final_model_path}")

    return model, tokenizer, val_acc
|
|
|
|
|
|
# ── ONNX Export ────────────────────────────────────────────────────────────


def export_onnx(
    model,
    tokenizer,
    output_dir: Path,
    quantize: bool = False,
):
    """Export the model to ONNX, plus tokenizer files and a metadata JSON.

    Writes into ``output_dir / "onnx_model"``:
      * ``model.onnx``: inputs ``input_ids`` / ``attention_mask`` and output
        ``logits``. Only the batch axis is dynamic; the sequence length is
        fixed at ``MAX_LENGTH`` by the padded dummy input.
      * ``model_quantized.onnx``: only when ``quantize`` is true and
        ``onnxruntime.quantization`` can be imported.
      * the tokenizer files and ``model_metadata.json``.

    The model is exported from CPU and moved back to its original device
    afterwards, even if the export raises.

    Returns:
        The ``onnx_model`` directory.
    """
    print("\n--- Exporting to ONNX ---")

    onnx_dir = output_dir / "onnx_model"
    onnx_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = onnx_dir / "model.onnx"

    model.eval()
    original_device = next(model.parameters()).device
    model.to("cpu")  # export from CPU; the dummy tensors below are on CPU too

    print(" Using torch.onnx.export for ONNX conversion...")

    # Dummy batch of one message padded to MAX_LENGTH for tracing.
    dummy_input = tokenizer(
        "This is a test message for spam detection.",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    try:
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            onnx_path,
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=True,
        )
    finally:
        # Hand the model back to the caller on the device it arrived on.
        model.to(original_device)

    print(f" ONNX model exported to: {onnx_path}")
    print(f" Model size: {onnx_path.stat().st_size / 1024 / 1024:.1f} MB")

    # Optional: INT8 dynamic quantization. Only the import is guarded;
    # real quantization errors still surface.
    if quantize:
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic
        except ImportError:
            print(" Skipping quantization (onnxruntime quantization not available)")
        else:
            print("\n --- Quantizing to INT8 ---")
            quantized_path = onnx_dir / "model_quantized.onnx"

            quantize_dynamic(
                model_input=str(onnx_path),
                model_output=str(quantized_path),
                # The enum member is QUInt8; QuantType has no QUINT8 attribute.
                weight_type=QuantType.QUInt8,
            )

            print(f" Quantized model: {quantized_path}")
            print(f" Quantized size: {quantized_path.stat().st_size / 1024 / 1024:.1f} MB")

    # Save tokenizer files needed by Node.js tokenizer
    tokenizer.save_pretrained(str(onnx_dir))

    # Save model metadata (json turns the integer ID2LABEL keys into strings).
    metadata = {
        "version": "1.0.0",
        "model_name": MODEL_NAME,
        "task": "sms-spam-classification",
        "max_length": MAX_LENGTH,
        "num_labels": 2,
        "label2id": LABEL2ID,
        "id2label": ID2LABEL,
        "framework": "pytorch",
        "export_format": "onnx",
    }
    metadata_path = onnx_dir / "model_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f" Metadata saved to: {metadata_path}")
    return onnx_dir
|
|
|
|
|
|
# ── Main ───────────────────────────────────────────────────────────────────


def _load_model_for_export(output_dir: Path):
    """Return ``(model, tokenizer)`` to export when training is skipped.

    Prefers a fine-tuned checkpoint at ``output_dir / "final_model"``. If none
    exists, falls back to the base ``MODEL_NAME`` checkpoint, whose
    classification head has not been fine-tuned.
    """
    final_model_path = output_dir / "final_model"
    if final_model_path.exists():
        print(f" Loading model from {final_model_path}")
        model = DistilBertForSequenceClassification.from_pretrained(
            str(final_model_path), num_labels=2
        )
        tokenizer = AutoTokenizer.from_pretrained(str(final_model_path))
        return model, tokenizer

    print(f" No existing model found at {final_model_path}")
    print(f" Loading base {MODEL_NAME} for export...")
    print(" WARNING: the classification head is not fine-tuned; its predictions are not meaningful")
    model = DistilBertForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2, id2label=ID2LABEL, label2id=LABEL2ID
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer


def main():
    """Command-line entry point: load data, optionally augment, train, export to ONNX.

    With ``--skip-training`` no dataset is downloaded. An existing
    ``final_model`` checkpoint is exported, or the base model if there is none.
    """
    parser = argparse.ArgumentParser(description="Train DistilBERT SMS Spam Classifier")
    parser.add_argument("--output-dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE, help="Learning rate")
    parser.add_argument("--quantize", action="store_true", help="Quantize to INT8")
    parser.add_argument("--skip-training", action="store_true", help="Skip training, only export")
    parser.add_argument("--augment", action="store_true", help="Include synthetic data augmentation")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("DistilBERT SMS Spam Classifier Training Pipeline")
    print("=" * 60)

    val_acc = None  # stays None when training is skipped

    if args.skip_training:
        # The dataset only feeds training, so don't download or augment it.
        print("\n--- Step 3: Skipping training (--skip-training) ---")
        if args.augment:
            print(" Note: --augment is ignored together with --skip-training")
        model, tokenizer = _load_model_for_export(output_dir)
    else:
        # Step 1: Load dataset
        print("\n--- Step 1: Loading Dataset ---")
        dataset_dict = load_dataset()

        # Step 2: Optional synthetic augmentation
        if args.augment:
            print("\n--- Step 2: Data Augmentation ---")
            synthetic_ds = Dataset.from_dict(generate_synthetic_data(num_samples=1000))
            train_ds = dataset_dict["train"]
            # Cast label to the train split's feature type so the datasets concatenate.
            synthetic_ds = synthetic_ds.cast_column("label", train_ds.features["label"])
            augmented = concatenate_datasets([train_ds, synthetic_ds])
            dataset_dict["train"] = augmented.shuffle(seed=42)
            print(f" Augmented train set: {len(dataset_dict['train'])} samples")

        # Step 3: Train model
        print("\n--- Step 3: Training ---")
        model, tokenizer, val_acc = train_model(
            train_dataset=dataset_dict["train"],
            val_dataset=dataset_dict["val"],
            output_dir=output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
        )

        target_acc = 0.95
        if val_acc < target_acc:
            print(f"\n WARNING: Validation accuracy {val_acc:.4f} is below 0.95 target")
            print(" Consider increasing epochs or adjusting hyperparameters")

    # Step 4: Export to ONNX
    print("\n--- Step 4: ONNX Export ---")
    onnx_dir = export_onnx(model, tokenizer, output_dir, quantize=args.quantize)

    print("\n" + "=" * 60)
    print("Pipeline complete!")
    print(f" ONNX model: {onnx_dir / 'model.onnx'}")
    print(f" Tokenizer: {onnx_dir}")
    if val_acc is None:
        print(" Validation accuracy: n/a (training skipped)")
    else:
        print(f" Validation accuracy: {val_acc:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|