Files
plant-disease-id/scripts/fine-tune-model.py
2026-06-08 16:42:04 -04:00

538 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
fine-tune-model.py
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
(93 diseases + healthy + unknown).
Pipeline:
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
3. Freeze backbone, train only the new classification head
4. Unfreeze the last ~25 backbone layers, fine-tune at lower learning rate
5. Export to TF.js GraphModel format
6. Export to .keras for future retraining
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
"""
import json
import os
import sys
import shutil
from pathlib import Path
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
import numpy as np
import tensorflow as tf
import keras
from keras import layers, optimizers, regularizers
# ─── Constants ───────────────────────────────────────────────────────────────
# Repository root: two directory levels above this script's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Pretrained checkpoint that build_model() loads as the starting point.
MODEL_PATH = (
    PROJECT_ROOT
    / "public"
    / "models"
    / "plant-disease-classifier"
    / "best_mnv2_pv_original.keras"
)
# Disease definitions; their order in the file determines the class indices.
DISEASES_JSON = PROJECT_ROOT / "src" / "data" / "diseases.json"
# Training images, one sub-directory per class id.
DATASET_DIR = PROJECT_ROOT / "data" / "dataset"
# Destination for the fine-tuned .keras file and the class mapping JSON.
OUTPUT_DIR = PROJECT_ROOT / "public" / "models" / "plant-disease-classifier"
# Destination directory for the TF.js export.
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"
IMG_SIZE = 160  # Model input size
BATCH_SIZE = 32
EPOCHS_HEAD = 15  # Train just the new head
EPOCHS_FINETUNE = 10  # Unfreeze and fine-tune
# Stage 1 (head only) uses a much larger step than stage 2 (backbone).
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 1e-5
# Fraction of the shuffled images held out for validation.
VALIDATION_SPLIT = 0.15
NUM_CLASSES = 95  # healthy(0) + 93 diseases + unknown(94)
# ─── Class Mapping ───────────────────────────────────────────────────────────
def build_class_mapping():
    """
    Build the lookup tables between dataset directory names and class indices.

    The ordering matches labels.ts / diseases.json:
        Index 0      = "healthy"
        Index 1..N   = disease IDs, in diseases.json order
        Index N + 1  = "unknown" (no images — skipped during training)

    Returns:
        A ``(mapping, index_to_class)`` tuple. ``mapping`` maps a class name
        to its index; ``index_to_class`` is the inverse, used for predictions.

    Raises:
        ValueError: if a disease ID is duplicated or clashes with a reserved
            name, or if the class count disagrees with NUM_CLASSES.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(DISEASES_JSON, encoding="utf-8") as f:
        diseases = json.load(f)

    mapping = {"healthy": 0}
    for index, disease in enumerate(diseases, start=1):
        mapping[disease["id"]] = index
    # Reserved slot right after the last disease; never trained on.
    mapping["unknown"] = len(diseases) + 1

    expected_classes = len(diseases) + 2
    if len(mapping) != expected_classes:
        # A dict silently overwrites repeated keys, which would leave gaps
        # in the index range and mislabel images.
        raise ValueError(
            f"{DISEASES_JSON} contains duplicate disease IDs or an ID that "
            'clashes with the reserved names "healthy" / "unknown"'
        )
    if expected_classes != NUM_CLASSES:
        # The model head is sized by NUM_CLASSES; a mismatch would produce
        # labels outside the output range.
        raise ValueError(
            f"{DISEASES_JSON} defines {len(diseases)} diseases "
            f"({expected_classes} classes in total) but NUM_CLASSES is "
            f"{NUM_CLASSES}"
        )

    # Reverse mapping for predictions
    index_to_class = {index: name for name, index in mapping.items()}
    return mapping, index_to_class
def verify_dataset(mapping):
    """
    Scan DATASET_DIR and report which classes actually have images.

    Args:
        mapping: class name → class index.

    Returns:
        ``(available, total)``: ``available`` maps every class that has at
        least one image to ``{"index": ..., "count": ...}``; ``total`` is the
        number of images across all of those classes.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")
    found = {}
    for name, index in mapping.items():
        folder = DATASET_DIR / name
        if folder.exists():
            # Only the number of matching files matters here.
            n_images = sum(
                1
                for entry in folder.glob("*")
                if entry.suffix.lower() in image_suffixes
            )
            if n_images > 0:
                found[name] = {"index": index, "count": n_images}
    grand_total = sum(entry["count"] for entry in found.values())
    return found, grand_total
def print_dataset_summary(available, total):
    """
    Print a human-readable overview of the dataset.

    Args:
        available: class id → {"index": int, "count": int}, as produced by
            verify_dataset().
        total: total number of images across all available classes.
    """
    # The separator used to be an empty string repeated 60 times, which
    # printed nothing; use the same rule as the other sections of the script.
    rule = "=" * 60
    # Read the class list once instead of re-parsing the JSON per print.
    total_classes = len(build_class_mapping()[0])

    print(f"\n{rule}")
    print("DATASET SUMMARY")
    print(rule)
    print(f" Total images: {total}")
    print(f" Classes found: {len(available)} / {total_classes}")
    print(f" Missing classes with no images: {total_classes - len(available)}")

    # One row per class, listed alphabetically by class id.
    rows = sorted(
        (class_id, info["index"], info["count"])
        for class_id, info in available.items()
    )
    print("\n Images per class:")
    for class_id, idx, count in rows:
        label = f" {idx:3d}. {class_id:<35s} {count:>4d} images"
        if class_id == "healthy":
            label += " ← 2× target"
        print(label)

    # Stats
    class_counts = [info["count"] for info in available.values()]
    if class_counts:
        print(
            f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
        )
    print(f"{rule}\n")
# ─── Data Loading ────────────────────────────────────────────────────────────
def load_dataset(mapping, available):
    """
    Build the training and validation ``tf.data`` pipelines.

    Images are shuffled with a fixed seed, split into train/validation by
    VALIDATION_SPLIT, decoded, resized to IMG_SIZE and normalized with the
    ImageNet mean/std. Only the training pipeline is augmented.

    Args:
        mapping: class name → index. Not used here; kept so the call
            signature stays unchanged.
        available: class id → {"index": int, "count": int}, as produced by
            verify_dataset().

    Returns:
        ``(train_ds, val_ds)``, both batched and prefetched.

    Raises:
        ValueError: if ``available`` yields no image files.
    """
    # NOTE(review): whether tf.image.decode_image can decode .webp depends on
    # the installed TensorFlow version — confirm before relying on it.
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")

    # Build file paths and labels
    file_paths = []
    labels = []
    for class_id, info in available.items():
        class_dir = DATASET_DIR / class_id
        images = sorted(
            p for p in class_dir.glob("*") if p.suffix.lower() in image_suffixes
        )
        file_paths.extend(str(p) for p in images)
        labels.extend([info["index"]] * len(images))

    if not file_paths:
        # Fail with a clear message instead of an opaque tf.data dtype error.
        raise ValueError(f"No training images found under {DATASET_DIR}")

    file_paths = np.array(file_paths)
    labels = np.array(labels)

    # Deterministic shuffle so the train/validation split is reproducible.
    indices = np.random.RandomState(42).permutation(len(file_paths))
    file_paths = file_paths[indices]
    labels = labels[indices]

    # Split train/validation
    split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
    train_paths, val_paths = file_paths[:split], file_paths[split:]
    train_labels, val_labels = labels[:split], labels[split:]
    print(f" Train: {len(train_paths)} images")
    print(f" Val: {len(val_paths)} images")

    def decode_image(image_path, label):
        """Read one file and return an RGB float image scaled to [0, 1]."""
        img = tf.io.read_file(image_path)
        # expand_animations=False keeps the result 3-D so resize() accepts it.
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
        img = tf.cast(img, tf.float32) / 255.0
        return img, label

    def normalize(image, label):
        """Apply ImageNet mean/std normalization to a [0, 1] image."""
        mean = tf.constant([0.485, 0.456, 0.406])
        std = tf.constant([0.229, 0.224, 0.225])
        return (image - mean) / std, label

    def augment(image, label):
        """
        Data augmentation for the training set.

        Runs on [0, 1] images, i.e. BEFORE normalization: the saturation and
        hue ops go through an RGB→HSV conversion that is only well defined
        for values in [0, 1].
        """
        # Random horizontal / vertical flips
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # Random brightness and contrast; delta is now relative to [0, 1].
        image = tf.image.random_brightness(image, 0.15)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        # Keep the HSV-based ops below inside their valid input range.
        image = tf.clip_by_value(image, 0.0, 1.0)
        # Random saturation and hue
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.05)
        image = tf.clip_by_value(image, 0.0, 1.0)
        # Random crop: scale up slightly, then cut a random IMG_SIZE window.
        image = tf.image.resize(image, [IMG_SIZE + 12, IMG_SIZE + 12])
        image = tf.image.random_crop(image, [IMG_SIZE, IMG_SIZE, 3])
        return image, label

    # Create tf.data datasets
    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    train_ds = train_ds.map(decode_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(1000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    val_ds = val_ds.map(decode_image, num_parallel_calls=tf.data.AUTOTUNE)
    val_ds = val_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds
# ─── Model Building ──────────────────────────────────────────────────────────
def build_model():
    """
    Load the PlantVillage model and replace its classification head with a
    NUM_CLASSES-way softmax output.

    Every layer of the loaded model is frozen; only the new head is left
    trainable. Exits the process with status 1 if MODEL_PATH does not exist.

    Returns:
        The new, uncompiled ``keras.Model``.
    """
    print(f"\nLoading base model from: {MODEL_PATH}")
    if not MODEL_PATH.exists():
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)
    base_model = keras.models.load_model(str(MODEL_PATH))
    print(f" Base model loaded: {type(base_model).__name__}")
    print(f" Input shape: {base_model.input_shape}")
    print(f" Output shape: {base_model.output_shape}")

    # Extract backbone — everything up to the GlobalAveragePooling2D
    # The model structure is:
    # input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
    backbone_output = base_model.get_layer("global_average_pooling2d").output
    print(" Using backbone output: global_average_pooling2d")

    # Freeze all backbone layers initially
    # (we'll unfreeze later for fine-tuning)
    for layer in base_model.layers:
        if layer.name != "dense":  # We'll replace this anyway
            layer.trainable = False

    # Build new classification head
    x = layers.Dropout(0.3, name="dropout_new")(backbone_output)
    x = layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        name="dense_new",
        kernel_regularizer=regularizers.l2(1e-4),
    )(x)

    # Create new model
    model = keras.Model(
        inputs=base_model.input, outputs=x, name="plant-disease-classifier-v2"
    )
    print(f" New model input: {model.input_shape}")
    print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")

    # Report parameter counts. The frozen figure must count ALL weights of
    # the frozen layers: their trainable_weights are empty by definition, so
    # summing those always reported 0. count_params() also avoids depending
    # on the type of a weight's .shape attribute.
    backbone_frozen = sum(
        layer.count_params() for layer in base_model.layers if layer.name != "dense"
    )
    head_trainable = model.get_layer("dense_new").count_params()
    print(f" Backbone frozen: {backbone_frozen:,} params (not training)")
    print(f" New head: {head_trainable:,} params (training)")
    return model
# ─── Training ────────────────────────────────────────────────────────────────
def train_head(model, train_ds, val_ds):
    """
    Stage 1: train only the new classification head.

    Freezes every layer except "dense_new", compiles the model with Adam at
    LEARNING_RATE_HEAD and fits for up to EPOCHS_HEAD epochs with early
    stopping and learning-rate reduction on plateau.

    Args:
        model: model returned by build_model().
        train_ds: batched training dataset.
        val_ds: batched validation dataset.

    Returns:
        The Keras ``History`` object from ``model.fit``.
    """
    print(f"\n{'=' * 60}")
    print("STAGE 1: Training classification head")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_HEAD}")
    print(f" Learning rate: {LEARNING_RATE_HEAD}")
    print(f" Batch size: {BATCH_SIZE}")

    # Freeze all backbone layers; only the new dense layer learns.
    for layer in model.layers:
        layer.trainable = layer.name == "dense_new"

    # Verify. tuple(w.shape) works whether the weight exposes its shape as a
    # plain tuple or as a TensorShape, unlike w.shape.num_elements().
    trainable = sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
    total = sum(int(np.prod(tuple(w.shape))) for w in model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_HEAD,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-6,
            ),
        ],
    )
    # Accuracy of the last epoch that ran.
    final_val_acc = history.history["val_accuracy"][-1]
    print(f"\n Stage 1 complete! Val accuracy: {final_val_acc:.4f}")
    return history
def train_finetune(model, train_ds, val_ds):
    """
    Stage 2: unfreeze the last ~25 backbone layers and fine-tune.

    Re-enables training for the tail of the nested MobileNetV2 model plus
    the new head, recompiles with Adam at LEARNING_RATE_FINETUNE and fits
    for up to EPOCHS_FINETUNE epochs. BatchNormalization layers stay frozen
    so they keep running in inference mode.

    Args:
        model: model previously trained by train_head().
        train_ds: batched training dataset.
        val_ds: batched validation dataset.

    Returns:
        The Keras ``History`` object from ``model.fit``.
    """
    print(f"\n{'=' * 60}")
    print("STAGE 2: Fine-tuning backbone (last ~25 layers)")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_FINETUNE}")
    print(f" Learning rate: {LEARNING_RATE_FINETUNE}")

    # Find the MobileNetV2 functional module
    # The backbone is a Functional model inside the base model
    mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")

    # Unfreeze the last ~25 layers of the backbone
    total_backbone_layers = len(mobilenet_layer.layers)
    unfreeze_from = max(0, total_backbone_layers - 25)
    print(
        f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
    )

    # The container itself was frozen in stage 1. While it is non-trainable
    # it can hide its sub-layers' weights from model.trainable_weights, so
    # re-enable it first and then freeze the individual layers we keep fixed.
    mobilenet_layer.trainable = True
    for i, layer in enumerate(mobilenet_layer.layers):
        # Keep BatchNormalization frozen: updating its statistics while
        # fine-tuning can undo what the pretrained backbone learned.
        is_batch_norm = isinstance(layer, layers.BatchNormalization)
        layer.trainable = i >= unfreeze_from and not is_batch_norm

    # Also unfreeze the new head
    model.get_layer("dense_new").trainable = True
    model.get_layer("dropout_new").trainable = True

    # tuple(w.shape) works whether the weight exposes its shape as a plain
    # tuple or as a TensorShape, unlike w.shape.num_elements().
    trainable = sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
    total = sum(int(np.prod(tuple(w.shape))) for w in model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    # Recompile so the changed trainable flags take effect.
    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_FINETUNE,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-7,
            ),
        ],
    )
    # Accuracy of the last epoch that ran.
    final_val_acc = history.history["val_accuracy"][-1]
    print(f"\n Stage 2 complete! Val accuracy: {final_val_acc:.4f}")
    return history
# ─── Export ──────────────────────────────────────────────────────────────────
def export_models(model, class_mapping, index_to_class):
    """
    Write the trained model and its metadata under OUTPUT_DIR.

    Produces three artifacts: the .keras file (for future retraining), a
    class_mapping.json describing the label layout, and a TF.js export in
    TFJS_OUTPUT. The TF.js step is best-effort: a failure is reported along
    with a manual conversion command, and does not raise.

    Args:
        model: the trained Keras model.
        class_mapping: class name → index.
        index_to_class: index → class name.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("EXPORTING")
    print(banner)

    # 1. Native Keras format
    saved_keras = OUTPUT_DIR / "model-finetuned.keras"
    model.save(str(saved_keras))
    print(f" ✓ Saved .keras: {saved_keras}")

    # 2. Label layout, stored next to the model
    metadata = {
        "index_to_class": index_to_class,
        "class_to_index": class_mapping,
        "num_classes": NUM_CLASSES,
        "input_size": IMG_SIZE,
    }
    mapping_file = OUTPUT_DIR / "class_mapping.json"
    with open(mapping_file, "w") as handle:
        json.dump(metadata, handle, indent=2)
    print(f" ✓ Saved class mapping: {mapping_file}")

    # 3. TF.js export into a freshly emptied directory
    tfjs_dir = str(TFJS_OUTPUT)
    if TFJS_OUTPUT.exists():
        shutil.rmtree(tfjs_dir)
    try:
        # Imported lazily so a missing package only disables this step.
        import tensorflowjs as tfjs

        tfjs.converters.save_keras_model(model, tfjs_dir)
        print(f" ✓ Saved TF.js: {tfjs_dir}/")
        for artifact in sorted(TFJS_OUTPUT.iterdir()):
            n_bytes = artifact.stat().st_size
            print(f" {artifact.name:<30s} {n_bytes:>10,} bytes")
    except Exception as err:
        print(f" ⚠ TF.js export failed: {err}")
        print(
            f" Run later: tensorflowjs_converter --input_format=keras {saved_keras} {tfjs_dir}"
        )
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
def cleanup_old_model():
    """Delete a previous model.json and its weight shards from OUTPUT_DIR."""
    # model.json first, then the shard files, announcing each deletion.
    for pattern in ("model.json", "group1-shard*"):
        for stale in OUTPUT_DIR.glob(pattern):
            print(f" Removing old: {stale.name}")
            stale.unlink()
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
    """
    Run the whole pipeline: class mapping → dataset check → data loading →
    model building → two-stage training → export.

    Exits with status 1 if the dataset directory is missing or contains no
    usable images.
    """
    print("=" * 60)
    print("PLANT DISEASE MODEL FINE-TUNER")
    print("=" * 60)

    # 1. Build class mapping
    print("\n[1/5] Building class mapping...")
    class_mapping, index_to_class = build_class_mapping()
    print(
        f" {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
    )

    # 2. Verify dataset
    print("\n[2/5] Verifying dataset...")
    if not DATASET_DIR.exists():
        print(f" ERROR: Dataset not found at {DATASET_DIR}")
        print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
        sys.exit(1)
    available, total = verify_dataset(class_mapping)
    print_dataset_summary(available, total)

    # Stop before touching TensorFlow when there is nothing to train on;
    # loading an empty dataset would fail with a far less helpful error.
    if total == 0:
        print("\n Skipping training — no dataset available.")
        sys.exit(1)
    if total < 100:
        print(f" WARNING: Only {total} images. Consider scraping more data.")
        # The script is non-interactive, so state what happens rather than
        # printing a question that is never read.
        print(" Continuing anyway; expect a weak model.")

    # 3. Load dataset
    print("\n[3/5] Loading and augmenting dataset...")
    train_ds, val_ds = load_dataset(class_mapping, available)

    # 4. Build and train model
    print("\n[4/5] Building model...")
    model = build_model()
    model.summary()
    train_head(model, train_ds, val_ds)
    train_finetune(model, train_ds, val_ds)

    # 5. Export
    print("\n[5/5] Exporting models...")
    cleanup_old_model()
    export_models(model, class_mapping, index_to_class)

    # ── Final Summary ────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print("DONE! Model fine-tuned and exported.")
    print(f"{'=' * 60}")
    print("\nFiles created:")
    print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
    print(f" {OUTPUT_DIR / 'class_mapping.json'}")
    print(f" {TFJS_OUTPUT / 'model.json'}")
    print("\nTo update your app:")
    print(" 1. Replace model files:")
    print(f" cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
    # Joining a Path with "/" yields the filesystem root, so build the
    # destination directory string explicitly.
    print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR}/")
    print(" 2. Restart the dev server")
    print(" 3. Test with: POST /api/identify")
    print("\nNote: Update labels.ts if the class order changed.")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()