scripting

This commit is contained in:
2026-06-06 15:45:21 -04:00
parent 06295c83ca
commit 47609e5e42
11 changed files with 4411 additions and 205 deletions

View File

@@ -0,0 +1,537 @@
#!/usr/bin/env python3
"""
fine-tune-model.py
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
(93 diseases + healthy + unknown).
Pipeline:
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
3. Freeze backbone, train only the new classification head
4. Unfreeze the last ~25 backbone layers, fine-tune at lower learning rate
5. Export to TF.js Layers-model format (tfjs.converters.save_keras_model)
6. Export to .keras for future retraining
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
"""
import json
import os
import sys
import shutil
from pathlib import Path
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
import numpy as np
import tensorflow as tf
import keras
from keras import layers, optimizers, regularizers
# ─── Constants ───────────────────────────────────────────────────────────────
# Repository root: this script lives one directory below it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Directory that holds both the base model and everything this script exports.
OUTPUT_DIR = PROJECT_ROOT.joinpath("public", "models", "plant-disease-classifier")
MODEL_PATH = OUTPUT_DIR / "best_mnv2_pv_original.keras"
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"

# Inputs: class list and the image dataset (one sub-directory per class).
DISEASES_JSON = PROJECT_ROOT.joinpath("src", "data", "diseases.json")
DATASET_DIR = PROJECT_ROOT.joinpath("data", "dataset")

# Image / batching parameters.
IMG_SIZE = 160  # Model input size
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.15

# Stage 1 (new head only) and stage 2 (partial backbone unfreeze) schedules.
EPOCHS_HEAD = 15  # Train just the new head
EPOCHS_FINETUNE = 10  # Unfreeze and fine-tune
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 1e-5

NUM_CLASSES = 95  # healthy(0) + 93 diseases + unknown(94)
# ─── Class Mapping ───────────────────────────────────────────────────────────
def build_class_mapping():
    """
    Build the mapping between dataset directory names and model class indices.

    Layout of the index space:
        0                  -> "healthy"
        1 .. N             -> disease IDs, in the order they appear in diseases.json
        NUM_CLASSES - 1    -> "unknown" (reserved; no training images)

    The ordering is meant to match labels.ts (not visible from this file —
    TODO confirm the two stay in sync).

    Returns:
        (mapping, index_to_class): ``mapping`` maps class name -> index and
        ``index_to_class`` is the inverse, index -> class name.

    Raises:
        ValueError: if diseases.json does not describe exactly
            ``NUM_CLASSES - 2`` distinct diseases. Without this check a
            mismatched file would silently produce colliding or out-of-range
            labels for the NUM_CLASSES-wide output layer.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(DISEASES_JSON, encoding="utf-8") as f:
        diseases = json.load(f)

    expected_diseases = NUM_CLASSES - 2  # minus "healthy" and "unknown"
    if len(diseases) != expected_diseases:
        raise ValueError(
            f"{DISEASES_JSON} defines {len(diseases)} diseases, "
            f"but NUM_CLASSES={NUM_CLASSES} requires {expected_diseases}"
        )

    mapping = {"healthy": 0}
    for index, disease in enumerate(diseases, start=1):
        mapping[disease["id"]] = index
    mapping["unknown"] = NUM_CLASSES - 1  # Not trained, but reserved

    # A duplicated ID (or one named "healthy"/"unknown") overwrites an entry
    # above and would leave a hole in the index space.
    if len(mapping) != NUM_CLASSES:
        raise ValueError(
            f"{DISEASES_JSON} contains duplicate or reserved disease ids: "
            f"got {len(mapping)} distinct classes, expected {NUM_CLASSES}"
        )

    # Reverse mapping for predictions
    index_to_class = {index: name for name, index in mapping.items()}
    return mapping, index_to_class
def verify_dataset(mapping):
    """
    Scan DATASET_DIR and report which classes actually have images.

    For every class name in ``mapping`` the directory ``DATASET_DIR/<name>``
    is inspected; entries whose suffix (case-insensitive) is .jpg, .jpeg,
    .png or .webp are counted. Classes whose directory is missing or holds
    no such entries are left out of the result.

    Returns:
        (available, total): ``available`` maps class name ->
        {"index": class index, "count": number of images}, in the iteration
        order of ``mapping``; ``total`` is the sum of all counts.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")
    found = {}
    grand_total = 0

    for name, idx in mapping.items():
        folder = DATASET_DIR / name
        if folder.exists():
            n_images = sum(
                1 for entry in folder.glob("*")
                if entry.suffix.lower() in image_suffixes
            )
            if n_images > 0:
                found[name] = {"index": idx, "count": n_images}
                grand_total += n_images

    return found, grand_total
def print_dataset_summary(available, total):
    """
    Print a human-readable summary of the dataset found on disk.

    Args:
        available: mapping of class name -> {"index": int, "count": int},
            as returned by verify_dataset().
        total: total number of images across all classes.

    Output: overall totals, how many of the defined classes have images,
    a per-class table (sorted alphabetically by class name) and, when at
    least one class is present, min / max / average images per class.
    """
    # The original multiplied an empty string, which printed no rule at all;
    # use the same "=" rule as the other banners in this script.
    rule = "=" * 60

    # Read the class list once instead of re-parsing diseases.json per line.
    num_defined = len(build_class_mapping()[0])

    print(f"\n{rule}")
    print("DATASET SUMMARY")
    print(rule)
    print(f" Total images: {total}")
    print(f" Classes found: {len(available)} / {num_defined}")
    print(f" Missing classes with no images: {num_defined - len(available)}")

    # One row per class: (index, class name, image count), ordered by name.
    rows = sorted(
        ((info["index"], class_id, info["count"]) for class_id, info in available.items()),
        key=lambda row: row[1],
    )
    print("\n Images per class:")
    for idx, class_id, count in rows:
        label = f" {idx:3d}. {class_id:<35s} {count:>4d} images"
        if class_id == "healthy":
            label += " ← 2× target"
        print(label)

    # Stats (skipped when nothing was found, to avoid min()/max() on empty input)
    class_counts = [info["count"] for info in available.values()]
    if class_counts:
        print(
            f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
        )
    print(f"{rule}\n")
# ─── Data Loading ────────────────────────────────────────────────────────────
def load_dataset(mapping, available):
    """
    Build the training and validation tf.data pipelines.

    Args:
        mapping: class name -> index mapping. Unused here; kept so existing
            callers do not need to change.
        available: class name -> {"index": int, "count": int}, as returned by
            verify_dataset(). Only these classes are loaded.

    Returns:
        (train_ds, val_ds): batched, prefetched datasets yielding
        (image, label) pairs. Images are IMG_SIZE x IMG_SIZE x 3 float32,
        scaled to [0, 1] and then standardised with ImageNet mean/std.
        Only the training set is augmented.

    Notes:
        * The split is a seeded (42) random permutation, not stratified, so
          very small classes may end up entirely in one split.
        * NOTE(review): the ImageNet mean/std standardisation is carried over
          from the original code, which states it matches the base model's
          training-time preprocessing — confirm against the base model.
        * NOTE(review): .webp files are accepted by the extension filter;
          confirm tf.image.decode_image in the installed TF build decodes them.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")

    # Build file paths and labels
    file_paths = []
    labels = []
    for class_id, info in available.items():
        class_dir = DATASET_DIR / class_id
        class_images = [
            str(p) for p in sorted(class_dir.glob("*"))
            if p.suffix.lower() in image_suffixes
        ]
        file_paths.extend(class_images)
        labels.extend([info["index"]] * len(class_images))

    # Explicit dtypes: an empty list would otherwise become a float64 array,
    # which tf.io.read_file cannot consume.
    file_paths = np.array(file_paths, dtype=str)
    labels = np.array(labels, dtype=np.int64)

    # Shuffle (fixed seed so the train/validation split is reproducible)
    indices = np.random.RandomState(42).permutation(len(file_paths))
    file_paths = file_paths[indices]
    labels = labels[indices]

    # Split train/validation
    split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
    train_paths, val_paths = file_paths[:split], file_paths[split:]
    train_labels, val_labels = labels[:split], labels[split:]
    print(f" Train: {len(train_paths)} images")
    print(f" Val: {len(val_paths)} images")

    def load_image(image_path):
        """Read, decode and resize one image; result is float32 in [0, 1]."""
        raw = tf.io.read_file(image_path)
        img = tf.image.decode_image(raw, channels=3, expand_animations=False)
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
        return tf.cast(img, tf.float32) / 255.0

    def normalize(image):
        """Standardise a [0, 1] RGB image with ImageNet mean/std."""
        mean = tf.constant([0.485, 0.456, 0.406])
        std = tf.constant([0.229, 0.224, 0.225])
        return (image - mean) / std

    def augment(image):
        """
        Random augmentation of a [0, 1] RGB image.

        This runs BEFORE standardisation: the hue/saturation ops go through
        an HSV conversion that is only well defined for values in [0, 1].
        """
        # Random horizontal / vertical flips
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # HSV-based ops first, while the image is still guaranteed in range
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.05)
        # Random brightness / contrast (may leave [0, 1]; clipped below)
        image = tf.image.random_brightness(image, 0.15)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        # Random crop after slightly scaling up, i.e. a small random
        # zoom + translation. (The original padded with zeros and resized
        # back, which was deterministic and added a border to every image.)
        image = tf.image.resize(image, [IMG_SIZE + 12, IMG_SIZE + 12])
        image = tf.image.random_crop(image, [IMG_SIZE, IMG_SIZE, 3])
        # Clip to the valid pixel range after augmentations
        return tf.clip_by_value(image, 0.0, 1.0)

    def parse_image(image_path, label):
        return normalize(load_image(image_path)), label

    def parse_and_augment(image_path, label):
        return normalize(augment(load_image(image_path))), label

    # Create tf.data datasets
    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    # Shuffle the (cheap) path strings over the whole training set rather than
    # buffering decoded images; buffer size must be >= 1 even when empty.
    train_ds = train_ds.shuffle(max(1, len(train_paths)))
    train_ds = train_ds.map(parse_and_augment, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    val_ds = val_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds
# ─── Model Building ──────────────────────────────────────────────────────────
def build_model():
    """
    Load the PlantVillage base model and attach a fresh NUM_CLASSES-way head.

    The base model is loaded from MODEL_PATH; the process exits with status 1
    if that file is missing. The output of its "global_average_pooling2d"
    layer is used as the feature vector, every pre-existing layer except the
    old "dense" head is frozen, and a new Dropout(0.3) + softmax Dense layer
    ("dropout_new" / "dense_new") is stacked on top.

    Returns:
        A keras.Model named "plant-disease-classifier-v2" that shares the
        base model's input and backbone layers.
    """
    print(f"\nLoading base model from: {MODEL_PATH}")
    if not MODEL_PATH.exists():
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    base_model = keras.models.load_model(str(MODEL_PATH))
    print(f" Base model loaded: {type(base_model).__name__}")
    print(f" Input shape: {base_model.input_shape}")
    print(f" Output shape: {base_model.output_shape}")

    # Extract backbone — everything up to the GlobalAveragePooling2D
    # The model structure is:
    # input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
    backbone_output = base_model.get_layer("global_average_pooling2d").output
    print(" Using backbone output: global_average_pooling2d")

    # Freeze all backbone layers initially
    # (we'll unfreeze later for fine-tuning)
    for layer in base_model.layers:
        if layer.name != "dense":  # We'll replace this anyway
            layer.trainable = False

    # Build new classification head
    x = layers.Dropout(0.3, name="dropout_new")(backbone_output)
    x = layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        name="dense_new",
        kernel_regularizer=regularizers.l2(1e-4),
    )(x)

    # Create new model
    model = keras.Model(
        inputs=base_model.input, outputs=x, name="plant-disease-classifier-v2"
    )
    print(f" New model input: {model.input_shape}")
    print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")

    # Parameter counts. The frozen figure must come from ALL weights of the
    # backbone layers: their trainable_weights lists are empty once frozen, so
    # summing those (as before) always reported 0. count_params() also avoids
    # relying on the weight-shape object exposing num_elements().
    backbone_frozen = sum(
        layer.count_params() for layer in base_model.layers if layer.name != "dense"
    )
    head_trainable = model.get_layer("dense_new").count_params()
    print(f" Backbone frozen: {backbone_frozen:,} params (not training)")
    print(f" New head: {head_trainable:,} params (training)")
    return model
# ─── Training ────────────────────────────────────────────────────────────────
def train_head(model, train_ds, val_ds):
    """
    Stage 1: train only the new classification head.

    Every layer of ``model`` except "dense_new" is frozen, the model is
    compiled with Adam at LEARNING_RATE_HEAD and fitted for up to EPOCHS_HEAD
    epochs with early stopping (on val_accuracy) and LR reduction (on
    val_loss).

    Args:
        model: model returned by build_model(); modified in place.
        train_ds: training dataset of (image, label) batches.
        val_ds: validation dataset of (image, label) batches.

    Returns:
        The keras History object from model.fit().
    """
    print(f"\n{'=' * 60}")
    print("STAGE 1: Training classification head")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_HEAD}")
    print(f" Learning rate: {LEARNING_RATE_HEAD}")
    print(f" Batch size: {BATCH_SIZE}")

    # Freeze all backbone layers; only the new dense head learns.
    for layer in model.layers:
        layer.trainable = layer.name == "dense_new"

    def count_params(weights):
        # Works whether w.shape is a plain tuple or a TensorShape.
        return sum(int(np.prod(tuple(w.shape))) for w in weights)

    # Verify
    trainable = count_params(model.trainable_weights)
    total = count_params(model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_HEAD,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-6,
            ),
        ],
    )

    # Report the best epoch rather than the last one: with
    # restore_best_weights=True the last epoch is generally not the one whose
    # weights are kept. The history key is absent if validation produced no
    # metrics (e.g. an empty validation set).
    val_acc_history = history.history.get("val_accuracy")
    if val_acc_history:
        print(f"\n Stage 1 complete! Best val accuracy: {max(val_acc_history):.4f}")
    else:
        print("\n Stage 1 complete! (no validation accuracy recorded)")
    return history
def train_finetune(model, train_ds, val_ds):
    """
    Stage 2: unfreeze the last ~25 backbone layers and fine-tune.

    The nested "mobilenetv2_1.00_160" model is re-enabled, its last 25 layers
    (excluding BatchNormalization layers) are made trainable, the rest stay
    frozen, and the whole model is recompiled with Adam at
    LEARNING_RATE_FINETUNE and fitted for up to EPOCHS_FINETUNE epochs.

    Args:
        model: model already trained by train_head(); modified in place.
        train_ds: training dataset of (image, label) batches.
        val_ds: validation dataset of (image, label) batches.

    Returns:
        The keras History object from model.fit().
    """
    print(f"\n{'=' * 60}")
    print("STAGE 2: Fine-tuning backbone (last ~25 layers)")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_FINETUNE}")
    print(f" Learning rate: {LEARNING_RATE_FINETUNE}")

    # Find the MobileNetV2 functional module
    # The backbone is a Functional model inside the base model
    mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")

    # Stage 1 froze the backbone container itself. Re-enable it first: in
    # Keras versions where a frozen parent reports no trainable weights,
    # flipping only the inner layers would leave the backbone frozen.
    mobilenet_layer.trainable = True

    # Unfreeze the last ~25 layers of the backbone
    total_backbone_layers = len(mobilenet_layer.layers)
    unfreeze_from = max(0, total_backbone_layers - 25)
    print(
        f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
    )
    for i, layer in enumerate(mobilenet_layer.layers):
        # Keep BatchNormalization frozen (inference mode) so its moving
        # statistics are not overwritten during low-learning-rate fine-tuning.
        is_batch_norm = isinstance(layer, layers.BatchNormalization)
        layer.trainable = i >= unfreeze_from and not is_batch_norm

    # Also unfreeze the new head
    model.get_layer("dense_new").trainable = True
    model.get_layer("dropout_new").trainable = True

    def count_params(weights):
        # Works whether w.shape is a plain tuple or a TensorShape.
        return sum(int(np.prod(tuple(w.shape))) for w in weights)

    trainable = count_params(model.trainable_weights)
    total = count_params(model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    # Recompile so the changed trainable flags take effect.
    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_FINETUNE,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-7,
            ),
        ],
    )

    # Report the best epoch rather than the last one (see
    # restore_best_weights=True above); the key is absent if validation
    # produced no metrics.
    val_acc_history = history.history.get("val_accuracy")
    if val_acc_history:
        print(f"\n Stage 2 complete! Best val accuracy: {max(val_acc_history):.4f}")
    else:
        print("\n Stage 2 complete! (no validation accuracy recorded)")
    return history
# ─── Export ──────────────────────────────────────────────────────────────────
def export_models(model, class_mapping, index_to_class):
    """
    Export the trained model to .keras and TF.js formats.

    Writes three artefacts under OUTPUT_DIR:
        * model-finetuned.keras  — full Keras model, for future retraining
        * class_mapping.json     — both mapping directions plus NUM_CLASSES
                                   and IMG_SIZE
        * tfjs_finetuned/        — TF.js export via
                                   tfjs.converters.save_keras_model

    Any previous tfjs_finetuned/ directory is removed first. The TF.js step
    is best-effort: if tensorflowjs is missing or the conversion fails, a
    warning and a manual converter command are printed and the function
    returns normally.

    Args:
        model: trained keras model.
        class_mapping: class name -> index.
        index_to_class: index -> class name (integer keys become strings in
            the JSON file).
    """
    print(f"\n{'=' * 60}")
    print("EXPORTING")
    print(f"{'=' * 60}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 1. Save as .keras (for future retraining)
    keras_path = OUTPUT_DIR / "model-finetuned.keras"
    model.save(str(keras_path))
    print(f" ✓ Saved .keras: {keras_path}")

    # 2. Save class mapping alongside the model
    mapping_path = OUTPUT_DIR / "class_mapping.json"
    with open(mapping_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "index_to_class": index_to_class,
                "class_to_index": class_mapping,
                "num_classes": NUM_CLASSES,
                "input_size": IMG_SIZE,
            },
            f,
            indent=2,
        )
    print(f" ✓ Saved class mapping: {mapping_path}")

    # 3. Export to TF.js format
    tfjs_path = str(TFJS_OUTPUT)
    if TFJS_OUTPUT.exists():
        shutil.rmtree(tfjs_path)
    try:
        import tensorflowjs as tfjs

        tfjs.converters.save_keras_model(model, tfjs_path)
        print(f" ✓ Saved TF.js: {tfjs_path}/")
        for exported in sorted(TFJS_OUTPUT.iterdir()):
            size = exported.stat().st_size
            print(f" {exported.name:<30s} {size:>10,} bytes")
    except Exception as e:
        # Deliberately broad: the .keras export above already succeeded, so a
        # missing or incompatible tensorflowjs install should not abort the run.
        print(f" ⚠ TF.js export failed: {e}")
        # The saved file is in the .keras format, which the converter selects
        # with "keras_keras" ("keras" is its HDF5 input format).
        print(
            f" Run later: tensorflowjs_converter --input_format=keras_keras {keras_path} {tfjs_path}"
        )
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
def cleanup_old_model():
    """
    Delete the previous TF.js model files directly inside OUTPUT_DIR.

    Removes ``model.json`` and every ``group1-shard*`` file at the top level
    of OUTPUT_DIR (sub-directories are not searched), printing one line per
    removed file.
    """
    for pattern in ("model.json", "group1-shard*"):
        for stale in OUTPUT_DIR.glob(pattern):
            print(f" Removing old: {stale.name}")
            stale.unlink()
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
    """
    Run the full pipeline: mapping → dataset check → load → train → export.

    Exits with status 1 when the dataset directory is missing or contains no
    usable images. Training and export only run when at least one image was
    found.
    """
    print("=" * 60)
    print("PLANT DISEASE MODEL FINE-TUNER")
    print("=" * 60)

    # 1. Build class mapping
    print("\n[1/5] Building class mapping...")
    class_mapping, index_to_class = build_class_mapping()
    print(
        f" {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
    )

    # 2. Verify dataset
    print("\n[2/5] Verifying dataset...")
    if not DATASET_DIR.exists():
        print(f" ERROR: Dataset not found at {DATASET_DIR}")
        print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
        sys.exit(1)
    available, total = verify_dataset(class_mapping)
    print_dataset_summary(available, total)

    # Bail out before touching the data pipeline or the model: building a
    # dataset from zero files fails before this message could be shown.
    if total == 0:
        print("\n Skipping training — no dataset available.")
        sys.exit(1)
    if total < 100:
        print(f" WARNING: Only {total} images. Consider scraping more data.")
        # No input is read, so state what happens instead of asking.
        print(" Continuing anyway.")

    # 3. Load dataset
    print("\n[3/5] Loading and augmenting dataset...")
    train_ds, val_ds = load_dataset(class_mapping, available)

    # 4. Build and train model
    print("\n[4/5] Building model...")
    model = build_model()
    model.summary()
    train_head(model, train_ds, val_ds)
    train_finetune(model, train_ds, val_ds)

    # 5. Export
    print("\n[5/5] Exporting models...")
    export_models(model, class_mapping, index_to_class)

    # export_models() clears TFJS_OUTPUT before converting, so a model.json
    # there means this run's TF.js export succeeded. Only then remove the
    # previously deployed model files; otherwise the app would be left with
    # no model at all.
    tfjs_model_json = TFJS_OUTPUT / "model.json"
    tfjs_exported = tfjs_model_json.exists()
    if tfjs_exported:
        cleanup_old_model()
    else:
        print(" TF.js export missing — keeping existing model.json and shards.")

    # ── Final Summary ────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print("DONE! Model fine-tuned and exported.")
    print(f"{'=' * 60}")
    print("\nFiles created:")
    print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
    print(f" {OUTPUT_DIR / 'class_mapping.json'}")
    if tfjs_exported:
        print(f" {tfjs_model_json}")
        print("\nTo update your app:")
        print(" 1. Replace model files:")
        print(f" cp {tfjs_model_json} {OUTPUT_DIR / 'model.json'}")
        # Joining a Path with '/' yields the filesystem root, so format the
        # destination directory explicitly.
        print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR}/")
        print(" 2. Restart the dev server")
        print(" 3. Test with: POST /api/identify")
    print("\nNote: Update labels.ts if the class order changed.")


if __name__ == "__main__":
    main()