538 lines
18 KiB
Python
538 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
fine-tune-model.py
|
||
|
||
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
|
||
(93 diseases + healthy + unknown).
|
||
|
||
Pipeline:
|
||
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
|
||
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
|
||
3. Freeze backbone, train only the new classification head
|
||
4. Unfreeze the last ~25 backbone layers and fine-tune at a lower learning rate
|
||
5. Export to TF.js (Layers model format, via tensorflowjs.converters.save_keras_model)
|
||
6. Export to .keras for future retraining
|
||
|
||
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import shutil
|
||
from pathlib import Path
|
||
|
||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
|
||
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import keras
|
||
from keras import layers, optimizers, regularizers
|
||
|
||
# ─── Constants ───────────────────────────────────────────────────────────────
|
||
|
||
# Repository root: this script lives one directory below it (scripts/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Directory that holds the pretrained base model and receives every export.
OUTPUT_DIR = PROJECT_ROOT.joinpath("public", "models", "plant-disease-classifier")
MODEL_PATH = OUTPUT_DIR / "best_mnv2_pv_original.keras"
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"

DISEASES_JSON = PROJECT_ROOT.joinpath("src", "data", "diseases.json")
DATASET_DIR = PROJECT_ROOT.joinpath("data", "dataset")

# Image / batch geometry.
IMG_SIZE = 160  # Model input size (square, pixels)
BATCH_SIZE = 32

# Stage 1 trains only the new head; stage 2 unfreezes part of the backbone.
EPOCHS_HEAD = 15
EPOCHS_FINETUNE = 10
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 1e-5

# Fraction of the shuffled images held out for validation.
VALIDATION_SPLIT = 0.15

NUM_CLASSES = 95  # healthy(0) + 93 diseases + unknown(94)
|
||
|
||
# ─── Class Mapping ───────────────────────────────────────────────────────────
|
||
|
||
|
||
def build_class_mapping(diseases_json=None, num_classes=None):
    """
    Build a dict mapping dataset directory names → model class indices.

    Matches the ordering in labels.ts / diseases.json:

        Index 0        = "healthy"
        Index 1..N     = disease IDs (in diseases.json order)
        Index N + 1    = "unknown" (reserved; normally no images)

    Args:
        diseases_json: Path to diseases.json. Defaults to DISEASES_JSON.
        num_classes: Expected total number of classes (the width of the new
            classification head). Defaults to NUM_CLASSES.

    Returns:
        (mapping, index_to_class): ``mapping`` maps class id -> index and
        ``index_to_class`` is its inverse.

    Raises:
        ValueError: if diseases.json contains a duplicate id (or reuses
            "healthy"/"unknown"), or if the resulting class count differs from
            ``num_classes``. Either would silently misalign dataset labels
            with the model's output units.
    """
    path = DISEASES_JSON if diseases_json is None else Path(diseases_json)
    expected = NUM_CLASSES if num_classes is None else num_classes

    with open(path, encoding="utf-8") as f:
        diseases = json.load(f)

    mapping = {"healthy": 0}
    for index, disease in enumerate(diseases, start=1):
        disease_id = disease["id"]
        if disease_id in mapping:
            raise ValueError(f"Duplicate class id {disease_id!r} in {path}")
        mapping[disease_id] = index

    if "unknown" in mapping:
        raise ValueError(f"Class id 'unknown' is reserved but appears in {path}")
    # "unknown" sits right after the last disease (94 for 93 diseases). Deriving
    # it instead of hard-coding 94 avoids colliding with a disease index.
    mapping["unknown"] = len(diseases) + 1

    if len(mapping) != expected:
        raise ValueError(
            f"{path} yields {len(mapping)} classes (healthy + {len(diseases)} "
            f"diseases + unknown) but the model head expects {expected}"
        )

    # Reverse mapping for predictions
    index_to_class = {idx: class_id for class_id, idx in mapping.items()}

    return mapping, index_to_class
|
||
|
||
|
||
def verify_dataset(mapping, dataset_dir=None):
    """
    Find which classes have images on disk and how many.

    Each class is looked up as the sub-directory ``<dataset_dir>/<class_id>``.
    Only regular files with a .jpg/.jpeg/.png/.webp suffix (case-insensitive)
    are counted.

    Args:
        mapping: dict of class id -> model class index.
        dataset_dir: Dataset root directory. Defaults to DATASET_DIR.

    Returns:
        (available, total): ``available`` maps class id ->
        {"index": class index, "count": image count} for every class with at
        least one image; ``total`` is the sum of all counts.
    """
    root = DATASET_DIR if dataset_dir is None else Path(dataset_dir)
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}

    available = {}
    total = 0

    for class_id, class_idx in mapping.items():
        class_dir = root / class_id
        if not class_dir.is_dir():
            continue

        # Skip sub-directories that merely look like images (e.g. "x.jpg/"):
        # they would be handed to the image decoder later and fail there.
        count = sum(
            1
            for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )

        if count:
            available[class_id] = {"index": class_idx, "count": count}
            total += count

    return available, total
|
||
|
||
|
||
def print_dataset_summary(available, total, num_defined=None):
    """
    Print a summary of the images found by verify_dataset.

    Args:
        available: dict of class id -> {"index": int, "count": int}.
        total: Total number of images across all classes.
        num_defined: Number of classes the model defines. When omitted, it is
            taken from build_class_mapping() (which reads diseases.json once).
    """
    if num_defined is None:
        # Read the class list once; it used to be parsed twice just for a length.
        num_defined = len(build_class_mapping()[0])

    rule = "─" * 60
    print(f"\n{rule}")
    print("DATASET SUMMARY")
    print(rule)
    print(f" Total images: {total}")
    print(f" Classes found: {len(available)} / {num_defined}")
    print(f" Missing classes with no images: {num_defined - len(available)}")

    # Images per class, listed alphabetically by class id
    print("\n Images per class:")
    for class_id, info in sorted(available.items()):
        label = f" {info['index']:3d}. {class_id:<35s} {info['count']:>4d} images"
        if class_id == "healthy":
            label += " ← 2× target"
        print(label)

    # Stats
    class_counts = [info["count"] for info in available.values()]
    if class_counts:
        print(
            f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
        )
    print(f"{rule}\n")
|
||
|
||
|
||
# ─── Data Loading ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _collect_image_files(available):
    """Return (paths, labels) numpy arrays for every image file of every available class."""
    file_paths = []
    labels = []
    for class_id, info in available.items():
        class_dir = DATASET_DIR / class_id
        for img_path in sorted(class_dir.glob("*")):
            if img_path.is_file() and img_path.suffix.lower() in (
                ".jpg",
                ".jpeg",
                ".png",
                ".webp",
            ):
                file_paths.append(str(img_path))
                labels.append(info["index"])
    # Explicit dtypes: with no images, np.array([]) would be float64, and
    # tf.io.read_file rejects a non-string path when the pipeline is traced.
    return np.array(file_paths, dtype=str), np.array(labels, dtype=np.int64)


def _load_image(image_path):
    """Read one image file, decode it as RGB, resize to IMG_SIZE and scale to [0, 1]."""
    raw = tf.io.read_file(image_path)
    # NOTE(review): whether tf.image.decode_image accepts .webp depends on the
    # installed TensorFlow version (BMP/GIF/JPEG/PNG are always supported) — confirm.
    img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    return tf.cast(img, tf.float32) / 255.0


def _normalize_imagenet(img):
    """Apply ImageNet mean/std normalization to an RGB image in [0, 1]."""
    # NOTE(review): Keras' stock MobileNetV2 expects inputs in [-1, 1]; this
    # assumes the PlantVillage base model was trained with ImageNet mean/std
    # normalization instead — confirm against its training code.
    mean = tf.constant([0.485, 0.456, 0.406])
    std = tf.constant([0.229, 0.224, 0.225])
    return (img - mean) / std


def _augment_image(img):
    """Randomly augment one RGB image in [0, 1]; the result stays in [0, 1]."""
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    img = tf.image.random_brightness(img, 0.15)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.clip_by_value(img, 0.0, 1.0)
    # Saturation/hue go through an RGB->HSV conversion that expects
    # non-negative RGB. This is why augmentation runs before normalization.
    img = tf.image.random_saturation(img, 0.8, 1.2)
    img = tf.image.random_hue(img, 0.05)
    # Zoom in slightly, then take a random IMG_SIZE x IMG_SIZE crop.
    img = tf.image.resize(img, [IMG_SIZE + 12, IMG_SIZE + 12])
    img = tf.image.random_crop(img, [IMG_SIZE, IMG_SIZE, 3])
    return tf.clip_by_value(img, 0.0, 1.0)


def load_dataset(mapping, available):
    """
    Load images from the dataset directory into train/validation tf.data pipelines.

    Every image under ``DATASET_DIR/<class_id>/`` for the classes in
    ``available`` is labelled with that class's index. A fixed-seed (42)
    shuffle followed by a VALIDATION_SPLIT cut makes the split reproducible.
    Training images are augmented (in [0, 1] space) before ImageNet
    normalization; validation images are only normalized.

    Args:
        mapping: class id -> index mapping. Not used here (labels come from
            ``available``); kept for interface compatibility.
        available: class id -> {"index": int, "count": int}, as returned by
            verify_dataset.

    Returns:
        (train_ds, val_ds): batched, prefetched datasets of (image, label).
    """
    file_paths, labels = _collect_image_files(available)

    # Deterministic shuffle so the train/val split is stable between runs.
    order = np.random.RandomState(42).permutation(len(file_paths))
    file_paths = file_paths[order]
    labels = labels[order]

    # Split train/validation
    split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
    train_paths, val_paths = file_paths[:split], file_paths[split:]
    train_labels, val_labels = labels[:split], labels[split:]

    print(f" Train: {len(train_paths)} images")
    print(f" Val: {len(val_paths)} images")

    def parse_train(image_path, label):
        img = _augment_image(_load_image(image_path))
        return _normalize_imagenet(img), label

    def parse_val(image_path, label):
        return _normalize_imagenet(_load_image(image_path)), label

    train_ds = (
        tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
        # Shuffle file paths rather than decoded images. This gives a full
        # shuffle for almost no memory; a buffer of 1000 decoded 160x160x3
        # float32 images would cost roughly 300 MB.
        .shuffle(max(len(train_paths), 1), reshuffle_each_iteration=True)
        .map(parse_train, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    val_ds = (
        tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
        .map(parse_val, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    return train_ds, val_ds
|
||
|
||
|
||
# ─── Model Building ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def build_model():
    """
    Load the PlantVillage model and replace its classification head with a
    NUM_CLASSES-way output.

    The backbone (everything up to ``global_average_pooling2d``) is reused with
    all its layers frozen. A new Dropout(0.3) + Dense(NUM_CLASSES, softmax,
    L2 1e-4) head is attached on top.

    Returns:
        A new ``keras.Model`` sharing the base model's input and backbone layers.

    Exits the process with status 1 if MODEL_PATH does not exist.
    """
    print(f"\nLoading base model from: {MODEL_PATH}")
    if not MODEL_PATH.exists():
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    base_model = keras.models.load_model(str(MODEL_PATH))
    print(f" Base model loaded: {type(base_model).__name__}")
    print(f" Input shape: {base_model.input_shape}")
    print(f" Output shape: {base_model.output_shape}")

    # Extract backbone — everything up to the GlobalAveragePooling2D.
    # The model structure is:
    # input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
    backbone_output = base_model.get_layer("global_average_pooling2d").output
    print(" Using backbone output: global_average_pooling2d")

    # Freeze all backbone layers initially (train_finetune unfreezes some later).
    # The old "dense" head is not part of the new graph, so it is left alone.
    for layer in base_model.layers:
        if layer.name != "dense":
            layer.trainable = False

    # Build new classification head
    x = layers.Dropout(0.3, name="dropout_new")(backbone_output)
    outputs = layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        name="dense_new",
        kernel_regularizer=regularizers.l2(1e-4),
    )(x)

    model = keras.Model(
        inputs=base_model.input, outputs=outputs, name="plant-disease-classifier-v2"
    )

    print(f" New model input: {model.input_shape}")
    print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")

    # Count the parameters that were just frozen. Summing trainable_weights
    # *after* freezing (as before) is always empty and always printed 0.
    # np.prod works whether w.shape is a tuple (Keras 3) or a TensorShape.
    backbone_frozen = sum(
        int(np.prod(w.shape))
        for layer in base_model.layers
        if layer.name != "dense"
        for w in layer.weights
    )
    head_trainable = sum(
        int(np.prod(w.shape)) for w in model.get_layer("dense_new").trainable_weights
    )

    print(f" Backbone frozen: {backbone_frozen:,} params (not training)")
    print(f" New head: {head_trainable:,} params (training)")

    return model
|
||
|
||
|
||
# ─── Training ────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def train_head(model, train_ds, val_ds):
    """
    Stage 1: train only the new classification head.

    Every top-level layer except ``dense_new`` is set non-trainable. The model
    is compiled with Adam(LEARNING_RATE_HEAD) and sparse categorical
    cross-entropy, then trained for up to EPOCHS_HEAD epochs. Callbacks:
    EarlyStopping on val_accuracy (patience 3, restore_best_weights) and
    ReduceLROnPlateau on val_loss (factor 0.5, patience 2, min_lr 1e-6).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    print(f"\n{'=' * 60}")
    print("STAGE 1: Training classification head")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_HEAD}")
    print(f" Learning rate: {LEARNING_RATE_HEAD}")
    print(f" Batch size: {BATCH_SIZE}")

    # Freeze everything except the new dense head
    for layer in model.layers:
        layer.trainable = layer.name == "dense_new"

    # Verify. np.prod handles both tuple (Keras 3) and TensorShape shapes.
    trainable = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
    total = sum(int(np.prod(w.shape)) for w in model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_HEAD,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-6,
            ),
        ],
    )

    # Report last and best epoch. Depending on whether early stopping fired,
    # the weights kept may be from either one. The key is absent if no
    # validation metrics were produced.
    val_acc = history.history.get("val_accuracy", [])
    if val_acc:
        print(
            f"\n Stage 1 complete! Val accuracy: last {val_acc[-1]:.4f}, best {max(val_acc):.4f}"
        )
    else:
        print("\n Stage 1 complete! (no validation accuracy recorded)")
    return history
|
||
|
||
|
||
def train_finetune(model, train_ds, val_ds, unfreeze_last=25):
    """
    Stage 2: unfreeze the last ``unfreeze_last`` backbone layers and fine-tune.

    The nested ``mobilenetv2_1.00_160`` sub-model is made trainable again.
    Then every sublayer before the last ``unfreeze_last`` is re-frozen, and
    BatchNormalization layers are kept frozen. The new head stays trainable.
    The model is recompiled with Adam(LEARNING_RATE_FINETUNE) and trained for
    up to EPOCHS_FINETUNE epochs. Callbacks: EarlyStopping on val_accuracy
    (patience 3, restore_best_weights) and ReduceLROnPlateau on val_loss
    (factor 0.5, patience 2, min_lr 1e-7).

    Args:
        model: Model returned by build_model (and trained by train_head).
        train_ds, val_ds: Datasets from load_dataset.
        unfreeze_last: Number of trailing backbone layers to consider for
            unfreezing (default 25).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    print(f"\n{'=' * 60}")
    print(f"STAGE 2: Fine-tuning backbone (last ~{unfreeze_last} layers)")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_FINETUNE}")
    print(f" Learning rate: {LEARNING_RATE_FINETUNE}")

    # The backbone is a Functional model nested inside the main model
    mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")

    # train_head() set this whole container to trainable=False. Depending on
    # the Keras version, a non-trainable container hides its sublayers'
    # weights from the optimizer even when those sublayers are unfrozen. So
    # re-enable the container first (this resets every sublayer), then
    # re-freeze the early layers below.
    mobilenet_layer.trainable = True

    total_backbone_layers = len(mobilenet_layer.layers)
    unfreeze_from = max(0, total_backbone_layers - unfreeze_last)
    print(
        f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
    )

    for i, layer in enumerate(mobilenet_layer.layers):
        # BatchNormalization stays frozen (inference mode), as the Keras
        # fine-tuning guide recommends: updating its statistics on a small
        # dataset can wreck the pretrained features.
        layer.trainable = i >= unfreeze_from and not isinstance(
            layer, layers.BatchNormalization
        )

    # Also unfreeze the new head
    model.get_layer("dense_new").trainable = True
    model.get_layer("dropout_new").trainable = True

    # np.prod handles both tuple (Keras 3) and TensorShape shapes.
    trainable = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
    total = sum(int(np.prod(w.shape)) for w in model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    # Recompile so the new trainable set and lower learning rate take effect.
    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_FINETUNE,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-7,
            ),
        ],
    )

    # Report last and best epoch; the key is absent without validation metrics.
    val_acc = history.history.get("val_accuracy", [])
    if val_acc:
        print(
            f"\n Stage 2 complete! Val accuracy: last {val_acc[-1]:.4f}, best {max(val_acc):.4f}"
        )
    else:
        print("\n Stage 2 complete! (no validation accuracy recorded)")
    return history
|
||
|
||
|
||
# ─── Export ──────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def export_models(model, class_mapping, index_to_class):
    """
    Export the trained model to .keras and TF.js formats.

    Writes:
      * OUTPUT_DIR/model-finetuned.keras — full Keras model for future retraining
      * OUTPUT_DIR/class_mapping.json — index_to_class, class_to_index,
        num_classes and input_size (JSON turns the int keys of
        index_to_class into strings)
      * TFJS_OUTPUT/ — output of tensorflowjs.converters.save_keras_model,
        which is a TF.js *Layers* model (load it with tf.loadLayersModel)

    The TF.js step is best-effort. If ``tensorflowjs`` is missing or the
    conversion fails, a warning and the equivalent CLI command are printed
    instead of raising.
    """
    print(f"\n{'=' * 60}")
    print("EXPORTING")
    print(f"{'=' * 60}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 1. Save as .keras (for future retraining)
    keras_path = OUTPUT_DIR / "model-finetuned.keras"
    model.save(str(keras_path))
    print(f" ✓ Saved .keras: {keras_path}")

    # 2. Save class mapping alongside the model
    mapping_path = OUTPUT_DIR / "class_mapping.json"
    with open(mapping_path, "w", encoding="utf-8") as fh:
        json.dump(
            {
                "index_to_class": index_to_class,
                "class_to_index": class_mapping,
                "num_classes": NUM_CLASSES,
                "input_size": IMG_SIZE,
            },
            fh,
            indent=2,
        )
    print(f" ✓ Saved class mapping: {mapping_path}")

    # 3. Export to TF.js format
    tfjs_path = str(TFJS_OUTPUT)
    retry_hint = (
        f" Run later: tensorflowjs_converter --input_format=keras {keras_path} {tfjs_path}"
    )

    # Check the converter is importable *before* deleting the previous export,
    # so a missing package does not destroy the last good TF.js model.
    try:
        import tensorflowjs as tfjs
    except ImportError as e:
        print(f" ⚠ TF.js export failed: {e}")
        print(retry_hint)
        return

    if TFJS_OUTPUT.exists():
        shutil.rmtree(tfjs_path)

    try:
        tfjs.converters.save_keras_model(model, tfjs_path)
    except Exception as e:  # best-effort: the .keras export above stays usable
        print(f" ⚠ TF.js export failed: {e}")
        print(retry_hint)
        return

    print(f" ✓ Saved TF.js: {tfjs_path}/")
    for artifact in sorted(TFJS_OUTPUT.iterdir()):
        print(f" {artifact.name:<30s} {artifact.stat().st_size:>10,} bytes")
|
||
|
||
|
||
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
|
||
|
||
|
||
def cleanup_old_model():
    """
    Delete previously deployed TF.js artifacts from OUTPUT_DIR.

    Removes ``model.json`` and then every ``group1-shard*`` file located
    directly in OUTPUT_DIR (sub-directories are not searched). Each file name
    is printed before it is unlinked.
    """
    for pattern in ("model.json", "group1-shard*"):
        for stale in OUTPUT_DIR.glob(pattern):
            print(f" Removing old: {stale.name}")
            stale.unlink()
|
||
|
||
|
||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def main():
    """
    Run the full pipeline.

    Steps: class mapping → dataset check → data loading → two-stage training
    → export.

    Exits with status 1 if the dataset directory is missing or holds no usable
    images, or (inside build_model) if the base model file is missing.
    """
    print("=" * 60)
    print("PLANT DISEASE MODEL FINE-TUNER")
    print("=" * 60)

    # 1. Build class mapping
    print("\n[1/5] Building class mapping...")
    class_mapping, index_to_class = build_class_mapping()
    last_index = len(class_mapping) - 1
    print(
        f" {len(class_mapping)} classes defined (0=healthy, 1-{last_index - 1}=diseases, {last_index}=unknown)"
    )

    # 2. Verify dataset
    print("\n[2/5] Verifying dataset...")
    if not DATASET_DIR.exists():
        print(f" ERROR: Dataset not found at {DATASET_DIR}")
        print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
        sys.exit(1)

    available, total = verify_dataset(class_mapping)
    print_dataset_summary(available, total)

    # Bail out before building the input pipeline: with zero images,
    # load_dataset would fail while tracing the (empty) pipeline.
    if total == 0:
        print("\n Skipping training — no dataset available.")
        sys.exit(1)

    if total < 100:
        print(f" WARNING: Only {total} images. Consider scraping more data.")
        # No prompt is read: the script is meant to run non-interactively.
        print(" Continuing anyway.")

    # 3. Load dataset
    print("\n[3/5] Loading and augmenting dataset...")
    train_ds, val_ds = load_dataset(class_mapping, available)

    # 4. Build and train model
    print("\n[4/5] Building model...")
    model = build_model()
    model.summary()

    train_head(model, train_ds, val_ds)
    train_finetune(model, train_ds, val_ds)

    # 5. Export
    print("\n[5/5] Exporting models...")
    export_models(model, class_mapping, index_to_class)

    # Remove the currently deployed model files only once a replacement exists.
    # Cleaning up first meant a failed TF.js export left no model at all.
    if (TFJS_OUTPUT / "model.json").exists():
        cleanup_old_model()
    else:
        print(" Keeping existing model files (TF.js export produced no model.json).")

    # ── Final Summary ────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print("DONE! Model fine-tuned and exported.")
    print(f"{'=' * 60}")
    print("\nFiles created:")
    print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
    print(f" {OUTPUT_DIR / 'class_mapping.json'}")
    print(f" {TFJS_OUTPUT / 'model.json'}")
    print("\nTo update your app:")
    print(" 1. Replace model files:")
    print(f" cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
    # Note: OUTPUT_DIR / "/" evaluates to the filesystem root, so the trailing
    # slash is added as plain text.
    print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR}/")
    print(" 2. Restart the dev server")
    print(" 3. Test with: POST /api/identify")
    print("\nNote: Update labels.ts if the class order changed.")


if __name__ == "__main__":
    main()
|