re-init
This commit is contained in:
537
scripts/fine-tune-model.py
Normal file
537
scripts/fine-tune-model.py
Normal file
@@ -0,0 +1,537 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
fine-tune-model.py
|
||||
|
||||
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
|
||||
(93 diseases + healthy + unknown).
|
||||
|
||||
Pipeline:
|
||||
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
|
||||
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
|
||||
3. Freeze backbone, train only the new classification head
|
||||
4. Unfreeze the last ~20 layers, fine-tune at lower learning rate
|
||||
5. Export to TF.js GraphModel format
|
||||
6. Export to .keras for future retraining
|
||||
|
||||
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras import layers, optimizers, regularizers
|
||||
|
||||
# ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_PATH = (
|
||||
PROJECT_ROOT
|
||||
/ "public"
|
||||
/ "models"
|
||||
/ "plant-disease-classifier"
|
||||
/ "best_mnv2_pv_original.keras"
|
||||
)
|
||||
DISEASES_JSON = PROJECT_ROOT / "src" / "data" / "diseases.json"
|
||||
DATASET_DIR = PROJECT_ROOT / "data" / "dataset"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "public" / "models" / "plant-disease-classifier"
|
||||
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"
|
||||
|
||||
IMG_SIZE = 160 # Model input size
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_HEAD = 15 # Train just the new head
|
||||
EPOCHS_FINETUNE = 10 # Unfreeze and fine-tune
|
||||
LEARNING_RATE_HEAD = 1e-3
|
||||
LEARNING_RATE_FINETUNE = 1e-5
|
||||
VALIDATION_SPLIT = 0.15
|
||||
|
||||
NUM_CLASSES = 95 # healthy(0) + 93 diseases + unknown(94)
|
||||
|
||||
# ─── Class Mapping ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_class_mapping():
|
||||
"""
|
||||
Build a dict mapping dataset directory names → model class indices.
|
||||
Matches the ordering in labels.ts / diseases.json.
|
||||
|
||||
Index 0 = "healthy"
|
||||
Index 1-93 = disease IDs (in diseases.json order)
|
||||
Index 94 = "unknown" (no images — skip during training)
|
||||
"""
|
||||
with open(DISEASES_JSON) as f:
|
||||
diseases = json.load(f)
|
||||
|
||||
mapping = {"healthy": 0}
|
||||
for i, disease in enumerate(diseases):
|
||||
mapping[disease["id"]] = i + 1 # Index 1-93
|
||||
mapping["unknown"] = 94 # Not trained, but reserved
|
||||
|
||||
# Reverse mapping for predictions
|
||||
index_to_class = {v: k for k, v in mapping.items()}
|
||||
|
||||
return mapping, index_to_class
|
||||
|
||||
|
||||
def verify_dataset(mapping):
|
||||
"""Find which classes have images and how many."""
|
||||
available = {}
|
||||
total = 0
|
||||
|
||||
for class_id, class_idx in mapping.items():
|
||||
class_dir = DATASET_DIR / class_id
|
||||
if not class_dir.exists():
|
||||
continue
|
||||
|
||||
image_paths = sorted(class_dir.glob("*"))
|
||||
image_paths = [
|
||||
p
|
||||
for p in image_paths
|
||||
if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
|
||||
]
|
||||
|
||||
if image_paths:
|
||||
available[class_id] = {"index": class_idx, "count": len(image_paths)}
|
||||
total += len(image_paths)
|
||||
|
||||
return available, total
|
||||
|
||||
|
||||
def print_dataset_summary(available, total):
|
||||
"""Print a summary of what's available."""
|
||||
print(f"\n{'─' * 60}")
|
||||
print("DATASET SUMMARY")
|
||||
print(f"{'─' * 60}")
|
||||
print(f" Total images: {total}")
|
||||
print(f" Classes found: {len(available)} / {len(build_class_mapping()[0])}")
|
||||
print(
|
||||
f" Missing classes with no images: {len(build_class_mapping()[0]) - len(available)}"
|
||||
)
|
||||
|
||||
# Count images per class
|
||||
counts = [(v["index"], k, v["count"]) for k, v in available.items()]
|
||||
counts.sort(key=lambda x: x[1])
|
||||
|
||||
print("\n Images per class:")
|
||||
for idx, class_id, count in counts:
|
||||
label = f" {idx:3d}. {class_id:<35s} {count:>4d} images"
|
||||
if class_id == "healthy":
|
||||
label += " ← 2× target"
|
||||
print(label)
|
||||
|
||||
# Stats
|
||||
class_counts = [v["count"] for v in available.values()]
|
||||
if class_counts:
|
||||
print(
|
||||
f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
|
||||
)
|
||||
print(f"{'─' * 60}\n")
|
||||
|
||||
|
||||
# ─── Data Loading ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def load_dataset(mapping, available):
|
||||
"""
|
||||
Load images from the dataset directory.
|
||||
Returns train/validation datasets with augmentation.
|
||||
"""
|
||||
# Build file paths and labels
|
||||
file_paths = []
|
||||
labels = []
|
||||
|
||||
for class_id, info in available.items():
|
||||
class_dir = DATASET_DIR / class_id
|
||||
images = sorted(class_dir.glob("*"))
|
||||
images = [
|
||||
p for p in images if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
|
||||
]
|
||||
|
||||
for img_path in images:
|
||||
file_paths.append(str(img_path))
|
||||
labels.append(info["index"])
|
||||
|
||||
file_paths = np.array(file_paths)
|
||||
labels = np.array(labels)
|
||||
|
||||
# Shuffle
|
||||
indices = np.random.RandomState(42).permutation(len(file_paths))
|
||||
file_paths = file_paths[indices]
|
||||
labels = labels[indices]
|
||||
|
||||
# Split train/validation
|
||||
split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
|
||||
train_paths, val_paths = file_paths[:split], file_paths[split:]
|
||||
train_labels, val_labels = labels[:split], labels[split:]
|
||||
|
||||
print(f" Train: {len(train_paths)} images")
|
||||
print(f" Val: {len(val_paths)} images")
|
||||
|
||||
# Parse function
|
||||
def parse_image(image_path, label):
|
||||
img = tf.io.read_file(image_path)
|
||||
img = tf.image.decode_image(img, channels=3, expand_animations=False)
|
||||
img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
|
||||
img = tf.cast(img, tf.float32) / 255.0
|
||||
# ImageNet normalization (matching training-time preprocessing)
|
||||
mean = tf.constant([0.485, 0.456, 0.406])
|
||||
std = tf.constant([0.229, 0.224, 0.225])
|
||||
img = (img - mean) / std
|
||||
return img, label
|
||||
|
||||
def augment(image, label):
|
||||
"""Data augmentation for training set."""
|
||||
# Random horizontal flip
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
# Random rotation (±20°)
|
||||
image = tf.image.random_flip_up_down(image)
|
||||
# Random brightness
|
||||
image = tf.image.random_brightness(image, 0.15)
|
||||
# Random contrast
|
||||
image = tf.image.random_contrast(image, 0.8, 1.2)
|
||||
# Random saturation
|
||||
image = tf.image.random_saturation(image, 0.8, 1.2)
|
||||
# Random hue
|
||||
image = tf.image.random_hue(image, 0.05)
|
||||
# Random crop (after slightly scaling up)
|
||||
image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 12, IMG_SIZE + 12)
|
||||
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
|
||||
# Clip to valid range after augmentations
|
||||
image = tf.clip_by_value(image, -2.5, 2.5)
|
||||
return image, label
|
||||
|
||||
# Create tf.data datasets
|
||||
train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
|
||||
train_ds = train_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
train_ds = train_ds.shuffle(1000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
|
||||
val_ds = val_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
return train_ds, val_ds
|
||||
|
||||
|
||||
# ─── Model Building ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_model():
|
||||
"""
|
||||
Load the PlantVillage model and replace the classification head
|
||||
with a 95-class output.
|
||||
"""
|
||||
print(f"\nLoading base model from: {MODEL_PATH}")
|
||||
if not MODEL_PATH.exists():
|
||||
print(f"ERROR: Model not found at {MODEL_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
base_model = keras.models.load_model(str(MODEL_PATH))
|
||||
print(f" Base model loaded: {type(base_model).__name__}")
|
||||
print(f" Input shape: {base_model.input_shape}")
|
||||
print(f" Output shape: {base_model.output_shape}")
|
||||
|
||||
# Extract backbone — everything up to the GlobalAveragePooling2D
|
||||
# The model structure is:
|
||||
# input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
|
||||
backbone_output = base_model.get_layer("global_average_pooling2d").output
|
||||
print(" Using backbone output: global_average_pooling2d")
|
||||
|
||||
# Freeze all backbone layers initially
|
||||
# (we'll unfreeze later for fine-tuning)
|
||||
for layer in base_model.layers:
|
||||
if layer.name != "dense": # We'll replace this anyway
|
||||
layer.trainable = False
|
||||
|
||||
# Build new classification head
|
||||
x = backbone_output
|
||||
x = layers.Dropout(0.3, name="dropout_new")(x)
|
||||
x = layers.Dense(
|
||||
NUM_CLASSES,
|
||||
activation="softmax",
|
||||
name="dense_new",
|
||||
kernel_regularizer=regularizers.l2(1e-4),
|
||||
)(x)
|
||||
|
||||
# Create new model
|
||||
model = keras.Model(
|
||||
inputs=base_model.input, outputs=x, name="plant-disease-classifier-v2"
|
||||
)
|
||||
|
||||
print(f" New model input: {model.input_shape}")
|
||||
print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")
|
||||
|
||||
# Count trainable params
|
||||
backbone_trainable = sum(
|
||||
w.shape.num_elements()
|
||||
for layer in base_model.layers
|
||||
if layer.name != "dense"
|
||||
for w in layer.trainable_weights
|
||||
)
|
||||
head_trainable = sum(
|
||||
w.shape.num_elements() for w in model.get_layer("dense_new").trainable_weights
|
||||
)
|
||||
|
||||
print(f" Backbone frozen: {backbone_trainable:,} params (not training)")
|
||||
print(f" New head: {head_trainable:,} params (training)")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# ─── Training ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def train_head(model, train_ds, val_ds):
|
||||
"""Stage 1: Train only the new classification head."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STAGE 1: Training classification head")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Epochs: {EPOCHS_HEAD}")
|
||||
print(f" Learning rate: {LEARNING_RATE_HEAD}")
|
||||
print(f" Batch size: {BATCH_SIZE}")
|
||||
|
||||
# Freeze all backbone layers
|
||||
for layer in model.layers:
|
||||
if layer.name != "dense_new":
|
||||
layer.trainable = False
|
||||
else:
|
||||
layer.trainable = True
|
||||
|
||||
# Verify
|
||||
trainable = sum(w.shape.num_elements() for w in model.trainable_weights)
|
||||
total = sum(w.shape.num_elements() for w in model.weights)
|
||||
print(f" Trainable params: {trainable:,} / {total:,} total")
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
|
||||
loss="sparse_categorical_crossentropy",
|
||||
metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
|
||||
)
|
||||
|
||||
history = model.fit(
|
||||
train_ds,
|
||||
validation_data=val_ds,
|
||||
epochs=EPOCHS_HEAD,
|
||||
verbose=1,
|
||||
callbacks=[
|
||||
keras.callbacks.EarlyStopping(
|
||||
monitor="val_accuracy",
|
||||
patience=3,
|
||||
restore_best_weights=True,
|
||||
),
|
||||
keras.callbacks.ReduceLROnPlateau(
|
||||
monitor="val_loss",
|
||||
factor=0.5,
|
||||
patience=2,
|
||||
min_lr=1e-6,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
final_val_acc = history.history["val_accuracy"][-1]
|
||||
print(f"\n Stage 1 complete! Val accuracy: {final_val_acc:.4f}")
|
||||
return history
|
||||
|
||||
|
||||
def train_finetune(model, train_ds, val_ds):
|
||||
"""Stage 2: Unfreeze last ~25 layers and fine-tune."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STAGE 2: Fine-tuning backbone (last ~25 layers)")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Epochs: {EPOCHS_FINETUNE}")
|
||||
print(f" Learning rate: {LEARNING_RATE_FINETUNE}")
|
||||
|
||||
# Find the MobileNetV2 functional module
|
||||
# The backbone is a Functional model inside the base model
|
||||
mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")
|
||||
|
||||
# Unfreeze the last ~25 layers of the backbone
|
||||
total_backbone_layers = len(mobilenet_layer.layers)
|
||||
unfreeze_from = max(0, total_backbone_layers - 25)
|
||||
print(
|
||||
f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
|
||||
)
|
||||
|
||||
for i, layer in enumerate(mobilenet_layer.layers):
|
||||
layer.trainable = i >= unfreeze_from
|
||||
|
||||
# Also unfreeze the new head
|
||||
model.get_layer("dense_new").trainable = True
|
||||
model.get_layer("dropout_new").trainable = True
|
||||
|
||||
trainable = sum(w.shape.num_elements() for w in model.trainable_weights)
|
||||
total = sum(w.shape.num_elements() for w in model.weights)
|
||||
print(f" Trainable params: {trainable:,} / {total:,} total")
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
|
||||
loss="sparse_categorical_crossentropy",
|
||||
metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
|
||||
)
|
||||
|
||||
history = model.fit(
|
||||
train_ds,
|
||||
validation_data=val_ds,
|
||||
epochs=EPOCHS_FINETUNE,
|
||||
verbose=1,
|
||||
callbacks=[
|
||||
keras.callbacks.EarlyStopping(
|
||||
monitor="val_accuracy",
|
||||
patience=3,
|
||||
restore_best_weights=True,
|
||||
),
|
||||
keras.callbacks.ReduceLROnPlateau(
|
||||
monitor="val_loss",
|
||||
factor=0.5,
|
||||
patience=2,
|
||||
min_lr=1e-7,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
final_val_acc = history.history["val_accuracy"][-1]
|
||||
print(f"\n Stage 2 complete! Val accuracy: {final_val_acc:.4f}")
|
||||
return history
|
||||
|
||||
|
||||
# ─── Export ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def export_models(model, class_mapping, index_to_class):
|
||||
"""Export the trained model to .keras and TF.js formats."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("EXPORTING")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Save as .keras (for future retraining)
|
||||
keras_path = OUTPUT_DIR / "model-finetuned.keras"
|
||||
model.save(str(keras_path))
|
||||
print(f" ✓ Saved .keras: {keras_path}")
|
||||
|
||||
# 2. Save class mapping alongside the model
|
||||
mapping_path = OUTPUT_DIR / "class_mapping.json"
|
||||
with open(mapping_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"index_to_class": index_to_class,
|
||||
"class_to_index": class_mapping,
|
||||
"num_classes": NUM_CLASSES,
|
||||
"input_size": IMG_SIZE,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
print(f" ✓ Saved class mapping: {mapping_path}")
|
||||
|
||||
# 3. Export to TF.js format
|
||||
tfjs_path = str(TFJS_OUTPUT)
|
||||
if TFJS_OUTPUT.exists():
|
||||
shutil.rmtree(tfjs_path)
|
||||
|
||||
try:
|
||||
import tensorflowjs as tfjs
|
||||
|
||||
tfjs.converters.save_keras_model(model, tfjs_path)
|
||||
print(f" ✓ Saved TF.js: {tfjs_path}/")
|
||||
for f in sorted(TFJS_OUTPUT.iterdir()):
|
||||
size = f.stat().st_size
|
||||
print(f" {f.name:<30s} {size:>10,} bytes")
|
||||
except Exception as e:
|
||||
print(f" ⚠ TF.js export failed: {e}")
|
||||
print(
|
||||
f" Run later: tensorflowjs_converter --input_format=keras {keras_path} {tfjs_path}"
|
||||
)
|
||||
|
||||
|
||||
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def cleanup_old_model():
|
||||
"""Remove old model.json and shards from the directory."""
|
||||
for f in OUTPUT_DIR.glob("model.json"):
|
||||
print(f" Removing old: {f.name}")
|
||||
f.unlink()
|
||||
for f in OUTPUT_DIR.glob("group1-shard*"):
|
||||
print(f" Removing old: {f.name}")
|
||||
f.unlink()
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("PLANT DISEASE MODEL FINE-TUNER")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Build class mapping
|
||||
print("\n[1/5] Building class mapping...")
|
||||
class_mapping, index_to_class = build_class_mapping()
|
||||
print(
|
||||
f" {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
|
||||
)
|
||||
|
||||
# 2. Verify dataset
|
||||
print("\n[2/5] Verifying dataset...")
|
||||
if not DATASET_DIR.exists():
|
||||
print(f" ERROR: Dataset not found at {DATASET_DIR}")
|
||||
print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
|
||||
sys.exit(1)
|
||||
|
||||
available, total = verify_dataset(class_mapping)
|
||||
print_dataset_summary(available, total)
|
||||
|
||||
if total < 100:
|
||||
print(f" WARNING: Only {total} images. Consider scraping more data.")
|
||||
print(" Continue anyway? (y/n)")
|
||||
# Continue regardless — user can decide
|
||||
|
||||
# 3. Load dataset
|
||||
print("\n[3/5] Loading and augmenting dataset...")
|
||||
train_ds, val_ds = load_dataset(class_mapping, available)
|
||||
|
||||
# 4. Build and train model
|
||||
print("\n[4/5] Building model...")
|
||||
model = build_model()
|
||||
model.summary()
|
||||
|
||||
# Check if training should run
|
||||
if total > 0:
|
||||
train_head(model, train_ds, val_ds)
|
||||
train_finetune(model, train_ds, val_ds)
|
||||
|
||||
# 5. Export
|
||||
print("\n[5/5] Exporting models...")
|
||||
cleanup_old_model()
|
||||
export_models(model, class_mapping, index_to_class)
|
||||
else:
|
||||
print("\n Skipping training — no dataset available.")
|
||||
sys.exit(1)
|
||||
|
||||
# ── Final Summary ────────────────────────────────────────────────────────
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("DONE! Model fine-tuned and exported.")
|
||||
print(f"{'=' * 60}")
|
||||
print("\nFiles created:")
|
||||
print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
|
||||
print(f" {OUTPUT_DIR / 'class_mapping.json'}")
|
||||
print(f" {TFJS_OUTPUT / 'model.json'}")
|
||||
print("\nTo update your app:")
|
||||
print(" 1. Replace model files:")
|
||||
print(f" cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
|
||||
print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR / '/'}")
|
||||
print(" 2. Restart the dev server")
|
||||
print(" 3. Test with: POST /api/identify")
|
||||
print("\nNote: Update labels.ts if the class order changed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user