#!/usr/bin/env python3
"""
fine-tune-model.py

Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
(93 diseases + healthy + unknown).

Pipeline:
  1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
  2. Replace the 38-class head with 95 classes
     (order matches diseases.json + healthy + unknown)
  3. Freeze backbone, train only the new classification head
  4. Unfreeze the last ~25 backbone layers (BatchNormalization layers stay
     frozen) and fine-tune at a lower learning rate
  5. Export to .keras for future retraining
  6. Export to TF.js (Layers-model format, via
     tensorflowjs.converters.save_keras_model)

Usage:
  .tfjs-venv/bin/python scripts/fine-tune-model.py
"""

import json
import math
import os
import shutil
import sys
from pathlib import Path

# Must be set before TensorFlow is imported to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TF info/warnings

import numpy as np
import tensorflow as tf
import keras
from keras import layers, optimizers, regularizers

# ─── Constants ───────────────────────────────────────────────────────────────

PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_PATH = (
    PROJECT_ROOT
    / "public"
    / "models"
    / "plant-disease-classifier"
    / "best_mnv2_pv_original.keras"
)
DISEASES_JSON = PROJECT_ROOT / "src" / "data" / "diseases.json"
DATASET_DIR = PROJECT_ROOT / "data" / "dataset"
OUTPUT_DIR = PROJECT_ROOT / "public" / "models" / "plant-disease-classifier"
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"

IMG_SIZE = 160  # Model input size
BATCH_SIZE = 32
EPOCHS_HEAD = 15  # Train just the new head
EPOCHS_FINETUNE = 10  # Unfreeze and fine-tune
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 1e-5
VALIDATION_SPLIT = 0.15
NUM_CLASSES = 95  # healthy(0) + 93 diseases + unknown(94)
FINETUNE_LAYERS = 25  # How many trailing backbone layers to unfreeze in stage 2

# File suffixes (lower-cased) treated as training images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

# ImageNet per-channel statistics applied to [0, 1] RGB images.
# NOTE(review): assumes the PlantVillage base model was trained on inputs
# normalized this way, and that the app's TF.js preprocessing applies the
# same transform — confirm both.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


# ─── Helpers ─────────────────────────────────────────────────────────────────


def _list_images(class_dir):
    """Return the sorted image files (by IMAGE_EXTENSIONS) in *class_dir*.

    Returns an empty list when the directory does not exist.
    """
    if not class_dir.is_dir():
        return []
    return sorted(
        p
        for p in class_dir.glob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def _num_params(weights):
    """Return the total number of scalar elements across *weights*.

    Uses ``math.prod(w.shape)`` so it works whether ``w.shape`` is a
    ``tf.TensorShape`` (tf.keras 2) or a plain tuple (Keras 3, where
    ``.num_elements()`` does not exist).
    """
    return sum(math.prod(w.shape) for w in weights)


# ─── Class Mapping ───────────────────────────────────────────────────────────


def build_class_mapping():
    """
    Build a dict mapping dataset directory names → model class indices.

    Matches the ordering in labels.ts / diseases.json:
      Index 0     = "healthy"
      Index 1-93  = disease IDs (in diseases.json order)
      Index 94    = "unknown" (no images — skipped during training)

    Returns:
        (mapping, index_to_class): class name → index, and index → class name.

    Exits the process if diseases.json does not hold exactly NUM_CLASSES - 2
    entries; otherwise label indices would collide with "unknown" or fall
    outside the NUM_CLASSES-wide output layer.
    """
    with open(DISEASES_JSON, encoding="utf-8") as f:
        diseases = json.load(f)

    expected = NUM_CLASSES - 2  # minus "healthy" and "unknown"
    if len(diseases) != expected:
        sys.exit(
            f"ERROR: {DISEASES_JSON} lists {len(diseases)} diseases, "
            f"expected {expected} (NUM_CLASSES={NUM_CLASSES})."
        )

    mapping = {"healthy": 0}
    for i, disease in enumerate(diseases, start=1):
        mapping[disease["id"]] = i  # Index 1-93
    mapping["unknown"] = NUM_CLASSES - 1  # Not trained, but reserved

    # Reverse mapping for predictions
    index_to_class = {v: k for k, v in mapping.items()}

    return mapping, index_to_class


def verify_dataset(mapping):
    """Find which classes have images and how many.

    Args:
        mapping: class name → model index (from build_class_mapping()).

    Returns:
        (available, total): ``available`` maps each class that has at least
        one image to ``{"index": int, "count": int}``; ``total`` is the
        overall image count.
    """
    available = {}
    total = 0

    for class_id, class_idx in mapping.items():
        images = _list_images(DATASET_DIR / class_id)
        if images:
            available[class_id] = {"index": class_idx, "count": len(images)}
            total += len(images)

    return available, total


def print_dataset_summary(available, total):
    """Print a summary of what's available (counts per class plus min/max/avg)."""
    num_defined = len(build_class_mapping()[0])

    print(f"\n{'─' * 60}")
    print("DATASET SUMMARY")
    print(f"{'─' * 60}")
    print(f"  Total images: {total}")
    print(f"  Classes found: {len(available)} / {num_defined}")
    print(f"  Missing classes with no images: {num_defined - len(available)}")

    # One row per class, ordered by class name
    counts = sorted(
        ((v["index"], k, v["count"]) for k, v in available.items()),
        key=lambda row: row[1],
    )
    print("\n  Images per class:")
    for idx, class_id, count in counts:
        label = f"    {idx:3d}. {class_id:<35s} {count:>4d} images"
        if class_id == "healthy":
            label += "  ← 2× target"
        print(label)

    class_counts = [v["count"] for v in available.values()]
    if class_counts:
        print(
            f"\n  Min: {min(class_counts)}  Max: {max(class_counts)}  "
            f"Avg: {sum(class_counts) / len(class_counts):.0f}"
        )
    print(f"{'─' * 60}\n")


# ─── Data Loading ────────────────────────────────────────────────────────────


def load_dataset(mapping, available):
    """
    Load images from the dataset directory into tf.data pipelines.

    Args:
        mapping: class name → model index. Currently unused (the index is
            taken from ``available``); kept for API compatibility.
        available: output of verify_dataset(): ``{class_id: {"index", "count"}}``.

    Returns:
        (train_ds, val_ds): batched, prefetched datasets of (image, label).
        Training images are augmented; both are ImageNet-normalized.
    """
    file_paths = []
    labels = []
    for class_id, info in available.items():
        for img_path in _list_images(DATASET_DIR / class_id):
            file_paths.append(str(img_path))
            labels.append(info["index"])

    # Explicit dtypes so an empty list still yields string/int tensors.
    file_paths = np.array(file_paths, dtype=str)
    labels = np.array(labels, dtype=np.int64)

    # Fixed-seed shuffle so the train/val split is reproducible across runs
    indices = np.random.RandomState(42).permutation(len(file_paths))
    file_paths = file_paths[indices]
    labels = labels[indices]

    # Split train/validation
    split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
    train_paths, val_paths = file_paths[:split], file_paths[split:]
    train_labels, val_labels = labels[:split], labels[split:]

    print(f"  Train: {len(train_paths)} images")
    print(f"  Val:   {len(val_paths)} images")

    def decode(image_path, label):
        """Read an image and return it resized to IMG_SIZE, float32 in [0, 1]."""
        img = tf.io.read_file(image_path)
        # NOTE(review): tf.io.decode_image is documented for BMP/GIF/JPEG/PNG;
        # confirm the installed TF version also decodes the .webp files
        # accepted by IMAGE_EXTENSIONS.
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
        img = tf.cast(img, tf.float32) / 255.0
        return img, label

    def augment(image, label):
        """Randomly augment a [0, 1] RGB image (must run before normalize)."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # Color jitter: the saturation/hue ops expect RGB values in [0, 1],
        # which is why augmentation happens before ImageNet normalization.
        image = tf.image.random_brightness(image, 0.15)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.05)
        # Random crop: scale up slightly, then cut a random IMG_SIZE window
        image = tf.image.resize(image, [IMG_SIZE + 12, IMG_SIZE + 12])
        image = tf.image.random_crop(image, [IMG_SIZE, IMG_SIZE, 3])
        # Brightness/contrast can push values out of range
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image, label

    def normalize(image, label):
        """Apply ImageNet mean/std normalization to a [0, 1] RGB image."""
        mean = tf.constant(IMAGENET_MEAN)
        std = tf.constant(IMAGENET_STD)
        return (image - mean) / std, label

    autotune = tf.data.AUTOTUNE

    train_ds = (
        tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
        # Shuffle the (cheap) path strings rather than decoded float images,
        # with a full-size buffer so each epoch is a complete reshuffle.
        .shuffle(max(1, len(train_paths)), reshuffle_each_iteration=True)
        .map(decode, num_parallel_calls=autotune)
        .map(augment, num_parallel_calls=autotune)
        .map(normalize, num_parallel_calls=autotune)
        .batch(BATCH_SIZE)
        .prefetch(autotune)
    )

    val_ds = (
        tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
        .map(decode, num_parallel_calls=autotune)
        .map(normalize, num_parallel_calls=autotune)
        .batch(BATCH_SIZE)
        .prefetch(autotune)
    )

    return train_ds, val_ds


# ─── Model Building ──────────────────────────────────────────────────────────


def build_model():
    """
    Load the PlantVillage model and replace its classification head with a
    NUM_CLASSES-way softmax.

    Returns:
        A new keras.Model sharing the base model's (frozen) backbone layers,
        ending in Dropout("dropout_new") → Dense("dense_new").

    Exits the process if MODEL_PATH does not exist.
    """
    print(f"\nLoading base model from: {MODEL_PATH}")

    if not MODEL_PATH.exists():
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    base_model = keras.models.load_model(str(MODEL_PATH))
    print(f"  Base model loaded: {type(base_model).__name__}")
    print(f"  Input shape: {base_model.input_shape}")
    print(f"  Output shape: {base_model.output_shape}")

    # Extract backbone — everything up to the GlobalAveragePooling2D.
    # The expected structure is:
    #   input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
    backbone_output = base_model.get_layer("global_average_pooling2d").output
    print("  Using backbone output: global_average_pooling2d")

    # Freeze all backbone layers initially (stage 2 unfreezes some of them).
    # The old "dense" head is skipped because the new model does not use it.
    for layer in base_model.layers:
        if layer.name != "dense":
            layer.trainable = False

    # Build new classification head
    x = layers.Dropout(0.3, name="dropout_new")(backbone_output)
    outputs = layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        name="dense_new",
        kernel_regularizer=regularizers.l2(1e-4),
    )(x)

    model = keras.Model(
        inputs=base_model.input, outputs=outputs, name="plant-disease-classifier-v2"
    )

    print(f"  New model input: {model.input_shape}")
    print(f"  New model output: {model.output_shape} ({NUM_CLASSES} classes)")

    # Count ALL weights of the frozen layers: after freezing, their
    # `trainable_weights` is empty, so counting those would always report 0.
    backbone_frozen = _num_params(
        w
        for layer in base_model.layers
        if layer.name != "dense"
        for w in layer.weights
    )
    head_trainable = _num_params(model.get_layer("dense_new").trainable_weights)
    print(f"  Backbone frozen: {backbone_frozen:,} params (not training)")
    print(f"  New head: {head_trainable:,} params (training)")

    return model


# ─── Training ────────────────────────────────────────────────────────────────


def _compile_and_fit(model, train_ds, val_ds, *, learning_rate, epochs, min_lr):
    """Compile *model* (Adam + sparse CE) and fit it with early stopping and LR decay.

    Must be called after any change to ``trainable`` flags so the change takes
    effect. Returns the Keras History object.
    """
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    return model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=min_lr,
            ),
        ],
    )


def _report_stage(history, stage):
    """Print last-epoch and best validation accuracy for a finished stage.

    Both are shown because EarlyStopping(restore_best_weights=True) may roll
    the model back to the best epoch, so the last value need not describe the
    weights the model ends up with.
    """
    val_acc = history.history.get("val_accuracy", [])
    if not val_acc:
        print(f"\n  Stage {stage} complete! (no validation accuracy recorded)")
        return
    print(
        f"\n  Stage {stage} complete! Val accuracy: {val_acc[-1]:.4f} "
        f"(last epoch), best: {max(val_acc):.4f}"
    )


def train_head(model, train_ds, val_ds):
    """Stage 1: Train only the new classification head. Returns the History."""
    print(f"\n{'=' * 60}")
    print("STAGE 1: Training classification head")
    print(f"{'=' * 60}")
    print(f"  Epochs: {EPOCHS_HEAD}")
    print(f"  Learning rate: {LEARNING_RATE_HEAD}")
    print(f"  Batch size: {BATCH_SIZE}")

    # Freeze everything except the new Dense head
    for layer in model.layers:
        layer.trainable = layer.name == "dense_new"

    trainable = _num_params(model.trainable_weights)
    total = _num_params(model.weights)
    print(f"  Trainable params: {trainable:,} / {total:,} total")

    history = _compile_and_fit(
        model,
        train_ds,
        val_ds,
        learning_rate=LEARNING_RATE_HEAD,
        epochs=EPOCHS_HEAD,
        min_lr=1e-6,
    )
    _report_stage(history, 1)
    return history


def train_finetune(model, train_ds, val_ds):
    """Stage 2: Unfreeze the last ~25 backbone layers and fine-tune.

    BatchNormalization layers are kept frozen (which also keeps them in
    inference mode), following the standard Keras fine-tuning recipe; updating
    their moving statistics on a small dataset tends to hurt accuracy.
    Returns the History.
    """
    print(f"\n{'=' * 60}")
    print(f"STAGE 2: Fine-tuning backbone (last ~{FINETUNE_LAYERS} layers)")
    print(f"{'=' * 60}")
    print(f"  Epochs: {EPOCHS_FINETUNE}")
    print(f"  Learning rate: {LEARNING_RATE_FINETUNE}")

    # The backbone is a nested Functional model inside the full model.
    mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")

    # train_head() froze the nested model itself. Re-enable it first:
    # in tf.keras 2 a frozen outer model hides all inner weights regardless
    # of the inner layers' flags. The setter also propagates to sublayers,
    # which the loop below then refines.
    mobilenet_layer.trainable = True

    total_backbone_layers = len(mobilenet_layer.layers)
    unfreeze_from = max(0, total_backbone_layers - FINETUNE_LAYERS)
    print(
        f"  Backbone has {total_backbone_layers} layers, unfreezing from layer "
        f"{unfreeze_from} (BatchNormalization layers stay frozen)"
    )

    for i, layer in enumerate(mobilenet_layer.layers):
        layer.trainable = i >= unfreeze_from and not isinstance(
            layer, layers.BatchNormalization
        )

    # Also unfreeze the new head
    model.get_layer("dense_new").trainable = True
    model.get_layer("dropout_new").trainable = True

    trainable = _num_params(model.trainable_weights)
    total = _num_params(model.weights)
    print(f"  Trainable params: {trainable:,} / {total:,} total")

    history = _compile_and_fit(
        model,
        train_ds,
        val_ds,
        learning_rate=LEARNING_RATE_FINETUNE,
        epochs=EPOCHS_FINETUNE,
        min_lr=1e-7,
    )
    _report_stage(history, 2)
    return history


# ─── Export ──────────────────────────────────────────────────────────────────


def export_models(model, class_mapping, index_to_class):
    """Export the trained model to .keras and TF.js formats.

    Writes model-finetuned.keras and class_mapping.json into OUTPUT_DIR, then
    (re)creates TFJS_OUTPUT. The TF.js step is best-effort: on failure a
    warning with a manual command is printed instead of raising.

    Returns:
        True if the TF.js export succeeded, False otherwise.
    """
    print(f"\n{'=' * 60}")
    print("EXPORTING")
    print(f"{'=' * 60}")

    # 1. Save as .keras (for future retraining)
    keras_path = OUTPUT_DIR / "model-finetuned.keras"
    model.save(str(keras_path))
    print(f"  ✓ Saved .keras: {keras_path}")

    # 2. Save class mapping alongside the model (JSON turns int keys into strings)
    mapping_path = OUTPUT_DIR / "class_mapping.json"
    with open(mapping_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "index_to_class": index_to_class,
                "class_to_index": class_mapping,
                "num_classes": NUM_CLASSES,
                "input_size": IMG_SIZE,
            },
            f,
            indent=2,
        )
    print(f"  ✓ Saved class mapping: {mapping_path}")

    # 3. Export to TF.js format (Layers model)
    tfjs_path = str(TFJS_OUTPUT)
    if TFJS_OUTPUT.exists():
        shutil.rmtree(tfjs_path)

    try:
        import tensorflowjs as tfjs

        tfjs.converters.save_keras_model(model, tfjs_path)
    except Exception as e:  # best-effort: the .keras file is already saved
        print(f"  ⚠ TF.js export failed: {e}")
        # NOTE(review): the converter's `keras` input format has historically
        # expected an HDF5 file; confirm it accepts a `.keras` archive with
        # the installed tensorflowjs version.
        print(
            f"    Run later: tensorflowjs_converter --input_format=keras {keras_path} {tfjs_path}"
        )
        return False

    print(f"  ✓ Saved TF.js: {tfjs_path}/")
    for artifact in sorted(TFJS_OUTPUT.iterdir()):
        size = artifact.stat().st_size
        print(f"      {artifact.name:<30s} {size:>10,} bytes")
    return True


# ─── Cleanup Old Model Files ─────────────────────────────────────────────────


def cleanup_old_model():
    """Remove old model.json and group1-shard* files directly in OUTPUT_DIR."""
    for pattern in ("model.json", "group1-shard*"):
        for old_file in OUTPUT_DIR.glob(pattern):
            print(f"  Removing old: {old_file.name}")
            old_file.unlink()


# ─── Main ────────────────────────────────────────────────────────────────────


def main():
    """Run the full pipeline: mapping → dataset → train (2 stages) → export."""
    print("=" * 60)
    print("PLANT DISEASE MODEL FINE-TUNER")
    print("=" * 60)

    # 1. Build class mapping
    print("\n[1/5] Building class mapping...")
    class_mapping, index_to_class = build_class_mapping()
    print(
        f"  {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
    )

    # 2. Verify dataset
    print("\n[2/5] Verifying dataset...")
    if not DATASET_DIR.exists():
        print(f"  ERROR: Dataset not found at {DATASET_DIR}")
        print("  Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
        sys.exit(1)

    available, total = verify_dataset(class_mapping)
    print_dataset_summary(available, total)

    # Stop before building any pipeline or model when there is nothing to train on.
    if total == 0:
        print("\n  Skipping training — no dataset available.")
        sys.exit(1)

    if total < 100:
        print(f"  WARNING: Only {total} images. Consider scraping more data.")
        print("  Continuing anyway.")

    # 3. Load dataset
    print("\n[3/5] Loading and augmenting dataset...")
    train_ds, val_ds = load_dataset(class_mapping, available)

    # 4. Build and train model
    print("\n[4/5] Building model...")
    model = build_model()
    model.summary()

    train_head(model, train_ds, val_ds)
    train_finetune(model, train_ds, val_ds)

    # 5. Export. Only remove the app's current TF.js model once a replacement
    # actually exists, so a failed export doesn't leave the app without a model.
    print("\n[5/5] Exporting models...")
    tfjs_ok = export_models(model, class_mapping, index_to_class)
    if tfjs_ok:
        cleanup_old_model()
    else:
        print("  Keeping existing model.json/shards because the TF.js export failed.")

    # ── Final Summary ────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print("DONE! Model fine-tuned and exported.")
    print(f"{'=' * 60}")
    print("\nFiles created:")
    print(f"  {OUTPUT_DIR / 'model-finetuned.keras'}")
    print(f"  {OUTPUT_DIR / 'class_mapping.json'}")
    if tfjs_ok:
        print(f"  {TFJS_OUTPUT / 'model.json'}")
    print("\nTo update your app:")
    print("  1. Replace model files:")
    print(f"     cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
    # Plain string join: `OUTPUT_DIR / "/"` would resolve to the filesystem root.
    print(f"     cp {TFJS_OUTPUT}/group1-shard* {OUTPUT_DIR}/")
    print("  2. Restart the dev server")
    print("  3. Test with: POST /api/identify")
    print("\nNote: Update labels.ts if the class order changed.")


if __name__ == "__main__":
    main()