scripting
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
537
apps/web/scripts/fine-tune-model.py
Normal file
537
apps/web/scripts/fine-tune-model.py
Normal file
@@ -0,0 +1,537 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
fine-tune-model.py
|
||||
|
||||
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
|
||||
(93 diseases + healthy + unknown).
|
||||
|
||||
Pipeline:
|
||||
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
|
||||
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
|
||||
3. Freeze backbone, train only the new classification head
|
||||
4. Unfreeze the last ~20 layers, fine-tune at lower learning rate
|
||||
5. Export to TF.js GraphModel format
|
||||
6. Export to .keras for future retraining
|
||||
|
||||
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras import layers, optimizers, regularizers
|
||||
|
||||
# ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_PATH = (
|
||||
PROJECT_ROOT
|
||||
/ "public"
|
||||
/ "models"
|
||||
/ "plant-disease-classifier"
|
||||
/ "best_mnv2_pv_original.keras"
|
||||
)
|
||||
DISEASES_JSON = PROJECT_ROOT / "src" / "data" / "diseases.json"
|
||||
DATASET_DIR = PROJECT_ROOT / "data" / "dataset"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "public" / "models" / "plant-disease-classifier"
|
||||
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"
|
||||
|
||||
IMG_SIZE = 160 # Model input size
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_HEAD = 15 # Train just the new head
|
||||
EPOCHS_FINETUNE = 10 # Unfreeze and fine-tune
|
||||
LEARNING_RATE_HEAD = 1e-3
|
||||
LEARNING_RATE_FINETUNE = 1e-5
|
||||
VALIDATION_SPLIT = 0.15
|
||||
|
||||
NUM_CLASSES = 95 # healthy(0) + 93 diseases + unknown(94)
|
||||
|
||||
# ─── Class Mapping ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_class_mapping():
|
||||
"""
|
||||
Build a dict mapping dataset directory names → model class indices.
|
||||
Matches the ordering in labels.ts / diseases.json.
|
||||
|
||||
Index 0 = "healthy"
|
||||
Index 1-93 = disease IDs (in diseases.json order)
|
||||
Index 94 = "unknown" (no images — skip during training)
|
||||
"""
|
||||
with open(DISEASES_JSON) as f:
|
||||
diseases = json.load(f)
|
||||
|
||||
mapping = {"healthy": 0}
|
||||
for i, disease in enumerate(diseases):
|
||||
mapping[disease["id"]] = i + 1 # Index 1-93
|
||||
mapping["unknown"] = 94 # Not trained, but reserved
|
||||
|
||||
# Reverse mapping for predictions
|
||||
index_to_class = {v: k for k, v in mapping.items()}
|
||||
|
||||
return mapping, index_to_class
|
||||
|
||||
|
||||
def verify_dataset(mapping):
|
||||
"""Find which classes have images and how many."""
|
||||
available = {}
|
||||
total = 0
|
||||
|
||||
for class_id, class_idx in mapping.items():
|
||||
class_dir = DATASET_DIR / class_id
|
||||
if not class_dir.exists():
|
||||
continue
|
||||
|
||||
image_paths = sorted(class_dir.glob("*"))
|
||||
image_paths = [
|
||||
p
|
||||
for p in image_paths
|
||||
if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
|
||||
]
|
||||
|
||||
if image_paths:
|
||||
available[class_id] = {"index": class_idx, "count": len(image_paths)}
|
||||
total += len(image_paths)
|
||||
|
||||
return available, total
|
||||
|
||||
|
||||
def print_dataset_summary(available, total):
|
||||
"""Print a summary of what's available."""
|
||||
print(f"\n{'─' * 60}")
|
||||
print("DATASET SUMMARY")
|
||||
print(f"{'─' * 60}")
|
||||
print(f" Total images: {total}")
|
||||
print(f" Classes found: {len(available)} / {len(build_class_mapping()[0])}")
|
||||
print(
|
||||
f" Missing classes with no images: {len(build_class_mapping()[0]) - len(available)}"
|
||||
)
|
||||
|
||||
# Count images per class
|
||||
counts = [(v["index"], k, v["count"]) for k, v in available.items()]
|
||||
counts.sort(key=lambda x: x[1])
|
||||
|
||||
print("\n Images per class:")
|
||||
for idx, class_id, count in counts:
|
||||
label = f" {idx:3d}. {class_id:<35s} {count:>4d} images"
|
||||
if class_id == "healthy":
|
||||
label += " ← 2× target"
|
||||
print(label)
|
||||
|
||||
# Stats
|
||||
class_counts = [v["count"] for v in available.values()]
|
||||
if class_counts:
|
||||
print(
|
||||
f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
|
||||
)
|
||||
print(f"{'─' * 60}\n")
|
||||
|
||||
|
||||
# ─── Data Loading ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def load_dataset(mapping, available):
|
||||
"""
|
||||
Load images from the dataset directory.
|
||||
Returns train/validation datasets with augmentation.
|
||||
"""
|
||||
# Build file paths and labels
|
||||
file_paths = []
|
||||
labels = []
|
||||
|
||||
for class_id, info in available.items():
|
||||
class_dir = DATASET_DIR / class_id
|
||||
images = sorted(class_dir.glob("*"))
|
||||
images = [
|
||||
p for p in images if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
|
||||
]
|
||||
|
||||
for img_path in images:
|
||||
file_paths.append(str(img_path))
|
||||
labels.append(info["index"])
|
||||
|
||||
file_paths = np.array(file_paths)
|
||||
labels = np.array(labels)
|
||||
|
||||
# Shuffle
|
||||
indices = np.random.RandomState(42).permutation(len(file_paths))
|
||||
file_paths = file_paths[indices]
|
||||
labels = labels[indices]
|
||||
|
||||
# Split train/validation
|
||||
split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
|
||||
train_paths, val_paths = file_paths[:split], file_paths[split:]
|
||||
train_labels, val_labels = labels[:split], labels[split:]
|
||||
|
||||
print(f" Train: {len(train_paths)} images")
|
||||
print(f" Val: {len(val_paths)} images")
|
||||
|
||||
# Parse function
|
||||
def parse_image(image_path, label):
|
||||
img = tf.io.read_file(image_path)
|
||||
img = tf.image.decode_image(img, channels=3, expand_animations=False)
|
||||
img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
|
||||
img = tf.cast(img, tf.float32) / 255.0
|
||||
# ImageNet normalization (matching training-time preprocessing)
|
||||
mean = tf.constant([0.485, 0.456, 0.406])
|
||||
std = tf.constant([0.229, 0.224, 0.225])
|
||||
img = (img - mean) / std
|
||||
return img, label
|
||||
|
||||
def augment(image, label):
|
||||
"""Data augmentation for training set."""
|
||||
# Random horizontal flip
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
# Random rotation (±20°)
|
||||
image = tf.image.random_flip_up_down(image)
|
||||
# Random brightness
|
||||
image = tf.image.random_brightness(image, 0.15)
|
||||
# Random contrast
|
||||
image = tf.image.random_contrast(image, 0.8, 1.2)
|
||||
# Random saturation
|
||||
image = tf.image.random_saturation(image, 0.8, 1.2)
|
||||
# Random hue
|
||||
image = tf.image.random_hue(image, 0.05)
|
||||
# Random crop (after slightly scaling up)
|
||||
image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 12, IMG_SIZE + 12)
|
||||
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
|
||||
# Clip to valid range after augmentations
|
||||
image = tf.clip_by_value(image, -2.5, 2.5)
|
||||
return image, label
|
||||
|
||||
# Create tf.data datasets
|
||||
train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
|
||||
train_ds = train_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
train_ds = train_ds.shuffle(1000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
|
||||
val_ds = val_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
return train_ds, val_ds
|
||||
|
||||
|
||||
# ─── Model Building ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_model():
|
||||
"""
|
||||
Load the PlantVillage model and replace the classification head
|
||||
with a 95-class output.
|
||||
"""
|
||||
print(f"\nLoading base model from: {MODEL_PATH}")
|
||||
if not MODEL_PATH.exists():
|
||||
print(f"ERROR: Model not found at {MODEL_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
base_model = keras.models.load_model(str(MODEL_PATH))
|
||||
print(f" Base model loaded: {type(base_model).__name__}")
|
||||
print(f" Input shape: {base_model.input_shape}")
|
||||
print(f" Output shape: {base_model.output_shape}")
|
||||
|
||||
# Extract backbone — everything up to the GlobalAveragePooling2D
|
||||
# The model structure is:
|
||||
# input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
|
||||
backbone_output = base_model.get_layer("global_average_pooling2d").output
|
||||
print(" Using backbone output: global_average_pooling2d")
|
||||
|
||||
# Freeze all backbone layers initially
|
||||
# (we'll unfreeze later for fine-tuning)
|
||||
for layer in base_model.layers:
|
||||
if layer.name != "dense": # We'll replace this anyway
|
||||
layer.trainable = False
|
||||
|
||||
# Build new classification head
|
||||
x = backbone_output
|
||||
x = layers.Dropout(0.3, name="dropout_new")(x)
|
||||
x = layers.Dense(
|
||||
NUM_CLASSES,
|
||||
activation="softmax",
|
||||
name="dense_new",
|
||||
kernel_regularizer=regularizers.l2(1e-4),
|
||||
)(x)
|
||||
|
||||
# Create new model
|
||||
model = keras.Model(
|
||||
inputs=base_model.input, outputs=x, name="plant-disease-classifier-v2"
|
||||
)
|
||||
|
||||
print(f" New model input: {model.input_shape}")
|
||||
print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")
|
||||
|
||||
# Count trainable params
|
||||
backbone_trainable = sum(
|
||||
w.shape.num_elements()
|
||||
for layer in base_model.layers
|
||||
if layer.name != "dense"
|
||||
for w in layer.trainable_weights
|
||||
)
|
||||
head_trainable = sum(
|
||||
w.shape.num_elements() for w in model.get_layer("dense_new").trainable_weights
|
||||
)
|
||||
|
||||
print(f" Backbone frozen: {backbone_trainable:,} params (not training)")
|
||||
print(f" New head: {head_trainable:,} params (training)")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# ─── Training ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def train_head(model, train_ds, val_ds):
|
||||
"""Stage 1: Train only the new classification head."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STAGE 1: Training classification head")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Epochs: {EPOCHS_HEAD}")
|
||||
print(f" Learning rate: {LEARNING_RATE_HEAD}")
|
||||
print(f" Batch size: {BATCH_SIZE}")
|
||||
|
||||
# Freeze all backbone layers
|
||||
for layer in model.layers:
|
||||
if layer.name != "dense_new":
|
||||
layer.trainable = False
|
||||
else:
|
||||
layer.trainable = True
|
||||
|
||||
# Verify
|
||||
trainable = sum(w.shape.num_elements() for w in model.trainable_weights)
|
||||
total = sum(w.shape.num_elements() for w in model.weights)
|
||||
print(f" Trainable params: {trainable:,} / {total:,} total")
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
|
||||
loss="sparse_categorical_crossentropy",
|
||||
metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
|
||||
)
|
||||
|
||||
history = model.fit(
|
||||
train_ds,
|
||||
validation_data=val_ds,
|
||||
epochs=EPOCHS_HEAD,
|
||||
verbose=1,
|
||||
callbacks=[
|
||||
keras.callbacks.EarlyStopping(
|
||||
monitor="val_accuracy",
|
||||
patience=3,
|
||||
restore_best_weights=True,
|
||||
),
|
||||
keras.callbacks.ReduceLROnPlateau(
|
||||
monitor="val_loss",
|
||||
factor=0.5,
|
||||
patience=2,
|
||||
min_lr=1e-6,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
final_val_acc = history.history["val_accuracy"][-1]
|
||||
print(f"\n Stage 1 complete! Val accuracy: {final_val_acc:.4f}")
|
||||
return history
|
||||
|
||||
|
||||
def train_finetune(model, train_ds, val_ds):
|
||||
"""Stage 2: Unfreeze last ~25 layers and fine-tune."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STAGE 2: Fine-tuning backbone (last ~25 layers)")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Epochs: {EPOCHS_FINETUNE}")
|
||||
print(f" Learning rate: {LEARNING_RATE_FINETUNE}")
|
||||
|
||||
# Find the MobileNetV2 functional module
|
||||
# The backbone is a Functional model inside the base model
|
||||
mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")
|
||||
|
||||
# Unfreeze the last ~25 layers of the backbone
|
||||
total_backbone_layers = len(mobilenet_layer.layers)
|
||||
unfreeze_from = max(0, total_backbone_layers - 25)
|
||||
print(
|
||||
f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
|
||||
)
|
||||
|
||||
for i, layer in enumerate(mobilenet_layer.layers):
|
||||
layer.trainable = i >= unfreeze_from
|
||||
|
||||
# Also unfreeze the new head
|
||||
model.get_layer("dense_new").trainable = True
|
||||
model.get_layer("dropout_new").trainable = True
|
||||
|
||||
trainable = sum(w.shape.num_elements() for w in model.trainable_weights)
|
||||
total = sum(w.shape.num_elements() for w in model.weights)
|
||||
print(f" Trainable params: {trainable:,} / {total:,} total")
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
|
||||
loss="sparse_categorical_crossentropy",
|
||||
metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
|
||||
)
|
||||
|
||||
history = model.fit(
|
||||
train_ds,
|
||||
validation_data=val_ds,
|
||||
epochs=EPOCHS_FINETUNE,
|
||||
verbose=1,
|
||||
callbacks=[
|
||||
keras.callbacks.EarlyStopping(
|
||||
monitor="val_accuracy",
|
||||
patience=3,
|
||||
restore_best_weights=True,
|
||||
),
|
||||
keras.callbacks.ReduceLROnPlateau(
|
||||
monitor="val_loss",
|
||||
factor=0.5,
|
||||
patience=2,
|
||||
min_lr=1e-7,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
final_val_acc = history.history["val_accuracy"][-1]
|
||||
print(f"\n Stage 2 complete! Val accuracy: {final_val_acc:.4f}")
|
||||
return history
|
||||
|
||||
|
||||
# ─── Export ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def export_models(model, class_mapping, index_to_class):
|
||||
"""Export the trained model to .keras and TF.js formats."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("EXPORTING")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Save as .keras (for future retraining)
|
||||
keras_path = OUTPUT_DIR / "model-finetuned.keras"
|
||||
model.save(str(keras_path))
|
||||
print(f" ✓ Saved .keras: {keras_path}")
|
||||
|
||||
# 2. Save class mapping alongside the model
|
||||
mapping_path = OUTPUT_DIR / "class_mapping.json"
|
||||
with open(mapping_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"index_to_class": index_to_class,
|
||||
"class_to_index": class_mapping,
|
||||
"num_classes": NUM_CLASSES,
|
||||
"input_size": IMG_SIZE,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
print(f" ✓ Saved class mapping: {mapping_path}")
|
||||
|
||||
# 3. Export to TF.js format
|
||||
tfjs_path = str(TFJS_OUTPUT)
|
||||
if TFJS_OUTPUT.exists():
|
||||
shutil.rmtree(tfjs_path)
|
||||
|
||||
try:
|
||||
import tensorflowjs as tfjs
|
||||
|
||||
tfjs.converters.save_keras_model(model, tfjs_path)
|
||||
print(f" ✓ Saved TF.js: {tfjs_path}/")
|
||||
for f in sorted(TFJS_OUTPUT.iterdir()):
|
||||
size = f.stat().st_size
|
||||
print(f" {f.name:<30s} {size:>10,} bytes")
|
||||
except Exception as e:
|
||||
print(f" ⚠ TF.js export failed: {e}")
|
||||
print(
|
||||
f" Run later: tensorflowjs_converter --input_format=keras {keras_path} {tfjs_path}"
|
||||
)
|
||||
|
||||
|
||||
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def cleanup_old_model():
|
||||
"""Remove old model.json and shards from the directory."""
|
||||
for f in OUTPUT_DIR.glob("model.json"):
|
||||
print(f" Removing old: {f.name}")
|
||||
f.unlink()
|
||||
for f in OUTPUT_DIR.glob("group1-shard*"):
|
||||
print(f" Removing old: {f.name}")
|
||||
f.unlink()
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("PLANT DISEASE MODEL FINE-TUNER")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Build class mapping
|
||||
print("\n[1/5] Building class mapping...")
|
||||
class_mapping, index_to_class = build_class_mapping()
|
||||
print(
|
||||
f" {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
|
||||
)
|
||||
|
||||
# 2. Verify dataset
|
||||
print("\n[2/5] Verifying dataset...")
|
||||
if not DATASET_DIR.exists():
|
||||
print(f" ERROR: Dataset not found at {DATASET_DIR}")
|
||||
print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
|
||||
sys.exit(1)
|
||||
|
||||
available, total = verify_dataset(class_mapping)
|
||||
print_dataset_summary(available, total)
|
||||
|
||||
if total < 100:
|
||||
print(f" WARNING: Only {total} images. Consider scraping more data.")
|
||||
print(" Continue anyway? (y/n)")
|
||||
# Continue regardless — user can decide
|
||||
|
||||
# 3. Load dataset
|
||||
print("\n[3/5] Loading and augmenting dataset...")
|
||||
train_ds, val_ds = load_dataset(class_mapping, available)
|
||||
|
||||
# 4. Build and train model
|
||||
print("\n[4/5] Building model...")
|
||||
model = build_model()
|
||||
model.summary()
|
||||
|
||||
# Check if training should run
|
||||
if total > 0:
|
||||
train_head(model, train_ds, val_ds)
|
||||
train_finetune(model, train_ds, val_ds)
|
||||
|
||||
# 5. Export
|
||||
print("\n[5/5] Exporting models...")
|
||||
cleanup_old_model()
|
||||
export_models(model, class_mapping, index_to_class)
|
||||
else:
|
||||
print("\n Skipping training — no dataset available.")
|
||||
sys.exit(1)
|
||||
|
||||
# ── Final Summary ────────────────────────────────────────────────────────
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("DONE! Model fine-tuned and exported.")
|
||||
print(f"{'=' * 60}")
|
||||
print("\nFiles created:")
|
||||
print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
|
||||
print(f" {OUTPUT_DIR / 'class_mapping.json'}")
|
||||
print(f" {TFJS_OUTPUT / 'model.json'}")
|
||||
print("\nTo update your app:")
|
||||
print(" 1. Replace model files:")
|
||||
print(f" cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
|
||||
print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR / '/'}")
|
||||
print(" 2. Restart the dev server")
|
||||
print(" 3. Test with: POST /api/identify")
|
||||
print("\nNote: Update labels.ts if the class order changed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
660
apps/web/scripts/scrape-training-dataset.ts
Normal file
660
apps/web/scripts/scrape-training-dataset.ts
Normal file
@@ -0,0 +1,660 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* scrape-training-dataset.ts
|
||||
*
|
||||
* Collects a training dataset for fine-tuning by scraping DuckDuckGo image search.
|
||||
*
|
||||
* Targets:
|
||||
* - 200 images per disease class (93 diseases)
|
||||
* - 400 images for the "healthy" class
|
||||
* - Full resolution images stored in data/dataset/{class_id}/
|
||||
*
|
||||
* DuckDuckGo approach (no API key needed):
|
||||
* 1. Fetch the main search page to extract a vqd (query) token
|
||||
* 2. Use the vqd token to paginate through image results
|
||||
* 3. Download each image to the dataset directory
|
||||
*
|
||||
* Usage: cd apps/web && npx tsx scripts/scrape-training-dataset.ts
|
||||
*
|
||||
* Progress is tracked in data/dataset/.progress.json — interrupt and resume safely.
|
||||
*/
|
||||
|
||||
import "dotenv/config";
|
||||
import { readFileSync, writeFileSync, existsSync, mkdirSync, readdirSync, statSync } from "fs";
|
||||
import { resolve, extname, join } from "path";
|
||||
|
||||
// ─── Config ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const DISEASES_JSON = resolve(__dirname, "../src/data/diseases.json");
|
||||
const PLANTS_JSON = resolve(__dirname, "../src/data/plants.json");
|
||||
|
||||
const DATASET_DIR = resolve(__dirname, "../data/dataset");
|
||||
const PROGRESS_FILE = resolve(DATASET_DIR, ".progress.json");
|
||||
|
||||
/** Target images per disease class */
|
||||
const TARGET_PER_DISEASE = 200;
|
||||
|
||||
/** Target images for the "healthy" class (2× normal) */
|
||||
const TARGET_HEALTHY = 400;
|
||||
|
||||
/** Delay between DuckDuckGo search API calls (ms) */
|
||||
const SEARCH_DELAY = 1500;
|
||||
|
||||
/** Delay between image downloads (ms) */
|
||||
const DOWNLOAD_DELAY = 300;
|
||||
|
||||
/** Max concurrent downloads */
|
||||
const CONCURRENT_DOWNLOADS = 5;
|
||||
|
||||
/** Minimum image size in bytes to accept (reject tiny placeholders) */
|
||||
const MIN_IMAGE_SIZE = 10_000; // 10KB
|
||||
|
||||
/** Maximum image size in bytes */
|
||||
const MAX_IMAGE_SIZE = 10 * 1024 * 1024; // 10MB
|
||||
|
||||
/** Allowed image content types */
|
||||
const ALLOWED_CONTENT_TYPES = ["image/jpeg", "image/jpg", "image/png", "image/webp", "image/gif"];
|
||||
|
||||
/** Allowed file extensions */
|
||||
const ALLOWED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"];
|
||||
|
||||
/** User agent for requests */
|
||||
const UA =
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36";
|
||||
|
||||
// ─── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
interface DiseaseSeed {
|
||||
id: string;
|
||||
plantId: string;
|
||||
name: string;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
interface PlantSeed {
|
||||
id: string;
|
||||
commonName: string;
|
||||
scientificName: string;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
interface DuckDuckGoImageResult {
|
||||
image: string;
|
||||
title: string;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
height: number;
|
||||
width: number;
|
||||
}
|
||||
|
||||
interface ClassProgress {
|
||||
count: number;
|
||||
downloaded: number;
|
||||
failed: number;
|
||||
skipped: number;
|
||||
/** URLs we've already seen (to avoid duplicates) */
|
||||
seenUrls: string[];
|
||||
/** Whether we've exhausted search results */
|
||||
exhausted: boolean;
|
||||
}
|
||||
|
||||
interface Progress {
|
||||
lastUpdated: string;
|
||||
classes: Record<string, ClassProgress>;
|
||||
}
|
||||
|
||||
/** Class ID for healthy plants */
|
||||
const HEALTHY_CLASS = "healthy";
|
||||
|
||||
// ─── DuckDuckGo API ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Extract the vqd token from DuckDuckGo's search page.
|
||||
* Required for paginating image results.
|
||||
*/
|
||||
async function getVqdToken(query: string): Promise<string> {
|
||||
const url = `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`;
|
||||
|
||||
const res = await fetch(url, {
|
||||
headers: { "User-Agent": UA, Accept: "text/html" },
|
||||
signal: AbortSignal.timeout(15_000),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(`Failed to get vqd token: ${res.status}`);
|
||||
}
|
||||
|
||||
const html = await res.text();
|
||||
|
||||
// Extract vqd token from the HTML
|
||||
// Format: vqd='<token>' or vqd="<token>"
|
||||
const match = html.match(/vqd['"]?\s*[:=]\s*['"]([a-f0-9-]+)['"]/);
|
||||
if (!match) {
|
||||
throw new Error(`Could not extract vqd token from DuckDuckGo response for "${query}"`);
|
||||
}
|
||||
|
||||
return match[1];
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch a page of DuckDuckGo image results.
|
||||
*/
|
||||
async function searchImagesDuckDuckGo(
|
||||
query: string,
|
||||
vqd: string,
|
||||
page: number,
|
||||
): Promise<DuckDuckGoImageResult[]> {
|
||||
const url = `https://duckduckgo.com/i.js?q=${encodeURIComponent(query)}&vqd=${vqd}&o=json&p=${page}&f=,,,`;
|
||||
|
||||
const res = await fetch(url, {
|
||||
headers: {
|
||||
"User-Agent": UA,
|
||||
Accept: "application/json",
|
||||
Referer: `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`,
|
||||
},
|
||||
signal: AbortSignal.timeout(15_000),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
if (res.status === 429) {
|
||||
console.warn(" ⚠ Rate limited (429). Waiting 10s...");
|
||||
await sleep(10_000);
|
||||
return searchImagesDuckDuckGo(query, vqd, page); // Retry
|
||||
}
|
||||
if (res.status === 403) {
|
||||
console.warn(" ⚠ Forbidden (403). Token may have expired.");
|
||||
return []; // Token expired — no more pages
|
||||
}
|
||||
throw new Error(`DuckDuckGo search failed: ${res.status}`);
|
||||
}
|
||||
|
||||
const data = (await res.json()) as { results: DuckDuckGoImageResult[] };
|
||||
return data.results ?? [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Search DuckDuckGo images, automatically paginating to collect up to `target` results.
|
||||
* Returns unique image URLs.
|
||||
*/
|
||||
async function collectImages(
|
||||
query: string,
|
||||
target: number,
|
||||
seenUrls: Set<string>,
|
||||
): Promise<{ urls: string[]; exhausted: boolean }> {
|
||||
const results: string[] = [];
|
||||
let page = 1;
|
||||
let exhausted = false;
|
||||
let consecutiveEmpty = 0;
|
||||
|
||||
// Get vqd token
|
||||
let vqd: string;
|
||||
try {
|
||||
vqd = await getVqdToken(query);
|
||||
} catch (err) {
|
||||
console.warn(` ⚠ Failed to get vqd token: ${err instanceof Error ? err.message : "unknown"}`);
|
||||
return { urls: [], exhausted: true };
|
||||
}
|
||||
|
||||
while (results.length < target) {
|
||||
await sleep(SEARCH_DELAY);
|
||||
|
||||
let pageResults: DuckDuckGoImageResult[];
|
||||
try {
|
||||
pageResults = await searchImagesDuckDuckGo(query, vqd, page);
|
||||
} catch (err) {
|
||||
console.warn(` ⚠ Search error: ${err instanceof Error ? err.message : "unknown"}`);
|
||||
break;
|
||||
}
|
||||
|
||||
if (pageResults.length === 0) {
|
||||
consecutiveEmpty++;
|
||||
if (consecutiveEmpty >= 3) {
|
||||
exhausted = true;
|
||||
break;
|
||||
}
|
||||
page++;
|
||||
continue;
|
||||
}
|
||||
|
||||
consecutiveEmpty = 0;
|
||||
let newCount = 0;
|
||||
|
||||
for (const r of pageResults) {
|
||||
if (results.length >= target) break;
|
||||
|
||||
const imgUrl = r.image || r.url;
|
||||
|
||||
// Skip if we've already seen this URL
|
||||
if (seenUrls.has(imgUrl)) continue;
|
||||
|
||||
// Validate URL looks like an image
|
||||
const ext = extname(new URL(imgUrl).pathname).toLowerCase();
|
||||
if (!ALLOWED_EXTENSIONS.includes(ext) && !ext) {
|
||||
// No extension - still try, could be a CDN URL
|
||||
}
|
||||
|
||||
seenUrls.add(imgUrl);
|
||||
results.push(imgUrl);
|
||||
newCount++;
|
||||
}
|
||||
|
||||
if (newCount === 0 && pageResults.every((r) => seenUrls.has(r.image || r.url))) {
|
||||
// All results on this page were already seen
|
||||
page++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (results.length < target) {
|
||||
page++;
|
||||
}
|
||||
}
|
||||
|
||||
return { urls: results.slice(0, target), exhausted };
|
||||
}
|
||||
|
||||
// ─── Image Download ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Download a single image from a URL to the target path.
|
||||
* Returns true if successful, false otherwise.
|
||||
*/
|
||||
async function downloadImage(url: string, destPath: string): Promise<boolean> {
|
||||
try {
|
||||
const res = await fetch(url, {
|
||||
headers: { "User-Agent": UA, Accept: "image/webp,image/png,image/jpeg" },
|
||||
signal: AbortSignal.timeout(15_000),
|
||||
});
|
||||
|
||||
if (!res.ok) return false;
|
||||
|
||||
const contentType = res.headers.get("content-type") || "";
|
||||
const contentLength = parseInt(res.headers.get("content-length") || "0", 10);
|
||||
|
||||
// Validate content type
|
||||
if (!ALLOWED_CONTENT_TYPES.some((t) => contentType.includes(t))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if (contentLength > 0 && contentLength < MIN_IMAGE_SIZE) return false;
|
||||
if (contentLength > MAX_IMAGE_SIZE) return false;
|
||||
|
||||
const buffer = Buffer.from(await res.arrayBuffer());
|
||||
|
||||
// Double-check actual buffer size
|
||||
if (buffer.length < MIN_IMAGE_SIZE) return false;
|
||||
if (buffer.length > MAX_IMAGE_SIZE) return false;
|
||||
|
||||
// Determine correct extension from content type or URL
|
||||
let ext = extname(new URL(url).pathname).toLowerCase();
|
||||
if (!ALLOWED_EXTENSIONS.includes(ext)) {
|
||||
// Map from content type
|
||||
if (contentType.includes("jpeg") || contentType.includes("jpg")) ext = ".jpg";
|
||||
else if (contentType.includes("png")) ext = ".png";
|
||||
else if (contentType.includes("webp")) ext = ".webp";
|
||||
else ext = ".jpg"; // Default
|
||||
}
|
||||
|
||||
const filePath = destPath.replace(/\.\w+$/, ext);
|
||||
writeFileSync(filePath, buffer);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Download multiple images concurrently, respecting a per-download delay.
|
||||
*/
|
||||
async function downloadBatch(
|
||||
urls: string[],
|
||||
classDir: string,
|
||||
startIndex: number,
|
||||
): Promise<{ downloaded: number; failed: number; lastIndex: number }> {
|
||||
let downloaded = 0;
|
||||
let failed = 0;
|
||||
let index = startIndex;
|
||||
|
||||
// Process in chunks to control concurrency
|
||||
for (let i = 0; i < urls.length; i += CONCURRENT_DOWNLOADS) {
|
||||
const chunk = urls.slice(i, i + CONCURRENT_DOWNLOADS);
|
||||
|
||||
const results = await Promise.all(
|
||||
chunk.map(async (url) => {
|
||||
const paddedIndex = String(index).padStart(4, "0");
|
||||
const destPath = resolve(classDir, `img_${paddedIndex}.jpg`);
|
||||
|
||||
const success = await downloadImage(url, destPath);
|
||||
await sleep(DOWNLOAD_DELAY);
|
||||
return { success, index: index++ };
|
||||
}),
|
||||
);
|
||||
|
||||
for (const r of results) {
|
||||
if (r.success) downloaded++;
|
||||
else failed++;
|
||||
}
|
||||
}
|
||||
|
||||
return { downloaded, failed, lastIndex: index };
|
||||
}
|
||||
|
||||
// ─── Progress Tracking ──────────────────────────────────────────────────────
|
||||
|
||||
function loadProgress(): Progress {
|
||||
if (!existsSync(PROGRESS_FILE)) {
|
||||
return { lastUpdated: new Date().toISOString(), classes: {} };
|
||||
}
|
||||
return JSON.parse(readFileSync(PROGRESS_FILE, "utf-8")) as Progress;
|
||||
}
|
||||
|
||||
function saveProgress(progress: Progress): void {
|
||||
progress.lastUpdated = new Date().toISOString();
|
||||
writeFileSync(PROGRESS_FILE, JSON.stringify(progress, null, 2));
|
||||
}
|
||||
|
||||
function getClassProgress(progress: Progress, classId: string): ClassProgress {
|
||||
if (!progress.classes[classId]) {
|
||||
progress.classes[classId] = {
|
||||
count: 0,
|
||||
downloaded: 0,
|
||||
failed: 0,
|
||||
skipped: 0,
|
||||
seenUrls: [],
|
||||
exhausted: false,
|
||||
};
|
||||
}
|
||||
return progress.classes[classId];
|
||||
}
|
||||
|
||||
// ─── Search Query Building ──────────────────────────────────────────────────
|
||||
|
||||
function buildSearchQueries(disease: DiseaseSeed, plant: PlantSeed | null): string[] {
|
||||
const name = disease.name;
|
||||
const plantName = plant?.commonName || disease.plantId;
|
||||
|
||||
return [
|
||||
`${name} ${plantName} leaf disease`,
|
||||
`${plantName} ${name} symptoms`,
|
||||
`${name} plant disease`,
|
||||
`${plantName} diseased leaf`,
|
||||
];
|
||||
}
|
||||
|
||||
function buildHealthyQueries(plant: PlantSeed): string[] {
|
||||
return [
|
||||
`healthy ${plant.commonName} leaf`,
|
||||
`${plant.commonName} leaf closeup`,
|
||||
`healthy ${plant.commonName} plant`,
|
||||
`${plant.commonName} foliage`,
|
||||
];
|
||||
}
|
||||
|
||||
// ─── Dataset Collection ─────────────────────────────────────────────────────
|
||||
|
||||
async function collectClassImages(
|
||||
classId: string,
|
||||
queries: string[],
|
||||
target: number,
|
||||
progress: Progress,
|
||||
classDir: string,
|
||||
): Promise<void> {
|
||||
const cp = getClassProgress(progress, classId);
|
||||
const seenUrls = new Set(cp.seenUrls);
|
||||
|
||||
if (cp.count >= target) {
|
||||
console.log(` ✓ Already have ${cp.count}/${target} images`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (cp.exhausted) {
|
||||
console.log(` ✓ Already exhausted search results (${cp.count}/${target} images)`);
|
||||
return;
|
||||
}
|
||||
|
||||
mkdirSync(classDir, { recursive: true });
|
||||
|
||||
const totalUrls: string[] = [];
|
||||
let exhausted = false;
|
||||
|
||||
// Search with each query until we hit the target
|
||||
for (const query of queries) {
|
||||
if (totalUrls.length >= target) break;
|
||||
|
||||
console.log(` Searching: "${query}"...`);
|
||||
const result = await collectImages(query, target - totalUrls.length, seenUrls);
|
||||
|
||||
totalUrls.push(...result.urls);
|
||||
cp.seenUrls = Array.from(seenUrls);
|
||||
|
||||
if (result.exhausted) {
|
||||
exhausted = true;
|
||||
}
|
||||
|
||||
if (totalUrls.length >= target) break;
|
||||
}
|
||||
|
||||
if (totalUrls.length === 0) {
|
||||
cp.exhausted = exhausted;
|
||||
saveProgress(progress);
|
||||
console.log(` ✗ No images found for "${classId}"`);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(` Found ${totalUrls.length} unique image URLs. Downloading...`);
|
||||
|
||||
// Download the images
|
||||
const { downloaded, failed } = await downloadBatch(totalUrls, classDir, cp.count);
|
||||
|
||||
cp.count += downloaded;
|
||||
cp.downloaded += downloaded;
|
||||
cp.failed += failed;
|
||||
cp.exhausted = exhausted;
|
||||
|
||||
saveProgress(progress);
|
||||
|
||||
const pct = Math.round((cp.count / target) * 100);
|
||||
console.log(
|
||||
` ${downloaded > 0 ? "✓" : "✗"} Got ${downloaded} images (${failed} failed). Total: ${cp.count}/${target} (${pct}%)`,
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Main ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async function main() {
|
||||
console.log("=".repeat(60));
|
||||
console.log("PLANT DISEASE DATASET COLLECTOR");
|
||||
console.log("=".repeat(60));
|
||||
|
||||
// Load knowledge base
|
||||
const diseases = JSON.parse(readFileSync(DISEASES_JSON, "utf-8")) as DiseaseSeed[];
|
||||
const plants = JSON.parse(readFileSync(PLANTS_JSON, "utf-8")) as PlantSeed[];
|
||||
|
||||
const plantMap = new Map<string, PlantSeed>();
|
||||
for (const p of plants) {
|
||||
plantMap.set(p.id, p);
|
||||
}
|
||||
|
||||
console.log(`\nLoaded ${diseases.length} diseases, ${plants.length} plants`);
|
||||
console.log(
|
||||
`Target: ${TARGET_PER_DISEASE} images/disease (×${diseases.length} = ${diseases.length * TARGET_PER_DISEASE})`,
|
||||
);
|
||||
console.log(`Target: ${TARGET_HEALTHY} images for "healthy" class`);
|
||||
console.log(`Output: ${DATASET_DIR}/`);
|
||||
console.log("");
|
||||
|
||||
// Load progress
|
||||
mkdirSync(DATASET_DIR, { recursive: true });
|
||||
const progress = loadProgress();
|
||||
|
||||
const startTime = Date.now();
|
||||
|
||||
// ── Phase 1: Disease classes ──────────────────────────────────────────────
|
||||
|
||||
console.log("─".repeat(60));
|
||||
console.log("PHASE 1: Disease Images");
|
||||
console.log("─".repeat(60));
|
||||
|
||||
for (let i = 0; i < diseases.length; i++) {
|
||||
const disease = diseases[i];
|
||||
const plant = plantMap.get(disease.plantId) ?? null;
|
||||
const classDir = resolve(DATASET_DIR, disease.id);
|
||||
const queries = buildSearchQueries(disease, plant);
|
||||
|
||||
const pct = Math.round((i / diseases.length) * 100);
|
||||
console.log(`\n[${i + 1}/${diseases.length}] (${pct}%) ${disease.name} (${disease.id})`);
|
||||
|
||||
await collectClassImages(disease.id, queries, TARGET_PER_DISEASE, progress, classDir);
|
||||
}
|
||||
|
||||
// ── Phase 2: Healthy class ────────────────────────────────────────────────
|
||||
|
||||
console.log("\n" + "─".repeat(60));
|
||||
console.log("PHASE 2: Healthy Plant Images");
|
||||
console.log("─".repeat(60));
|
||||
|
||||
const healthyDir = resolve(DATASET_DIR, HEALTHY_CLASS);
|
||||
const healthyCp = getClassProgress(progress, HEALTHY_CLASS);
|
||||
const healthySeen = new Set(healthyCp.seenUrls);
|
||||
|
||||
if (healthyCp.count >= TARGET_HEALTHY) {
|
||||
console.log(`\n ✓ Already have ${healthyCp.count}/${TARGET_HEALTHY} healthy images`);
|
||||
} else {
|
||||
// Build a pool of healthy plant queries
|
||||
const allHealthyQueries: string[] = [];
|
||||
for (const plant of plants) {
|
||||
allHealthyQueries.push(...buildHealthyQueries(plant));
|
||||
}
|
||||
|
||||
const totalHealthyUrls: string[] = [];
|
||||
let healthyExhausted = false;
|
||||
|
||||
for (const query of allHealthyQueries) {
|
||||
if (totalHealthyUrls.length >= TARGET_HEALTHY) break;
|
||||
if (healthyExhausted) break;
|
||||
|
||||
console.log(`\n Searching: "${query}"...`);
|
||||
const result = await collectImages(
|
||||
query,
|
||||
TARGET_HEALTHY - totalHealthyUrls.length,
|
||||
healthySeen,
|
||||
);
|
||||
|
||||
totalHealthyUrls.push(...result.urls);
|
||||
|
||||
if (result.exhausted) {
|
||||
healthyExhausted = true;
|
||||
}
|
||||
}
|
||||
|
||||
healthyCp.seenUrls = Array.from(healthySeen);
|
||||
|
||||
if (totalHealthyUrls.length > 0) {
|
||||
console.log(`\n Found ${totalHealthyUrls.length} healthy image URLs. Downloading...`);
|
||||
const { downloaded, failed } = await downloadBatch(
|
||||
totalHealthyUrls,
|
||||
healthyDir,
|
||||
healthyCp.count,
|
||||
);
|
||||
|
||||
healthyCp.count += downloaded;
|
||||
healthyCp.downloaded += downloaded;
|
||||
healthyCp.failed += failed;
|
||||
healthyCp.exhausted = healthyExhausted;
|
||||
|
||||
const pct = Math.round((healthyCp.count / TARGET_HEALTHY) * 100);
|
||||
console.log(
|
||||
` Got ${downloaded} images (${failed} failed). Total: ${healthyCp.count}/${TARGET_HEALTHY} (${pct}%)`,
|
||||
);
|
||||
} else {
|
||||
healthyCp.exhausted = true;
|
||||
console.log(` ✗ No healthy images found`);
|
||||
}
|
||||
|
||||
saveProgress(progress);
|
||||
}
|
||||
|
||||
// ── Summary ────────────────────────────────────────────────────────────────
|
||||
|
||||
const elapsed = Math.round((Date.now() - startTime) / 1000);
|
||||
const mins = Math.floor(elapsed / 60);
|
||||
const secs = elapsed % 60;
|
||||
|
||||
let totalDownloaded = 0;
|
||||
let totalFailed = 0;
|
||||
let totalTarget = 0;
|
||||
|
||||
for (const [classId, cp] of Object.entries(progress.classes)) {
|
||||
totalDownloaded += cp.downloaded || 0;
|
||||
totalFailed += cp.failed || 0;
|
||||
totalTarget += classId === HEALTHY_CLASS ? TARGET_HEALTHY : TARGET_PER_DISEASE;
|
||||
}
|
||||
|
||||
const totalSize = await getDatasetSize();
|
||||
const sizeGb = (totalSize / (1024 * 1024 * 1024)).toFixed(2);
|
||||
|
||||
console.log("\n" + "=".repeat(60));
|
||||
console.log("COMPLETE");
|
||||
console.log("=".repeat(60));
|
||||
console.log(` Time: ${mins}m ${secs}s`);
|
||||
console.log(` Downloaded: ${totalDownloaded} images`);
|
||||
console.log(` Failed: ${totalFailed} images`);
|
||||
console.log(` Target: ${totalTarget} images`);
|
||||
console.log(` Dataset size: ${sizeGb} GB`);
|
||||
console.log(` Dataset location: ${DATASET_DIR}/`);
|
||||
console.log("");
|
||||
console.log("Next steps:");
|
||||
console.log(" 1. Run the fine-tuning script to train on this dataset");
|
||||
console.log(" 2. The fine-tuning script will resize to 160×160 and augment");
|
||||
console.log("=".repeat(60));
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate total size of the dataset directory.
|
||||
*/
|
||||
async function getDatasetSize(): Promise<number> {
|
||||
let total = 0;
|
||||
if (!existsSync(DATASET_DIR)) return 0;
|
||||
|
||||
const entries = readdirSync(DATASET_DIR, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
if (!entry.name.startsWith(".")) {
|
||||
const fullPath = resolve(DATASET_DIR, entry.name);
|
||||
if (entry.isDirectory()) {
|
||||
total += dirSize(fullPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
function dirSize(dirPath: string): number {
|
||||
let total = 0;
|
||||
try {
|
||||
const entries = readdirSync(dirPath, { withFileTypes: true });
|
||||
for (const entry of entries) {
|
||||
const fullPath = join(dirPath, entry.name);
|
||||
if (entry.isFile()) {
|
||||
total += statSync(fullPath).size;
|
||||
} else if (entry.isDirectory()) {
|
||||
total += dirSize(fullPath);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// skip errors
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Fatal error:", err);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -173,14 +173,14 @@ describe("imageToTensor", () => {
|
||||
|
||||
describe("tensorToBase64 / base64ToTensor", () => {
|
||||
it("round-trips tensor data correctly", () => {
|
||||
const imageData = createMockImageData(224, 224, 100, 150, 200);
|
||||
const imageData = createMockImageData(160, 160, 100, 150, 200);
|
||||
const original = imageToTensor(imageData);
|
||||
|
||||
const base64 = tensorToBase64(original);
|
||||
const decoded = base64ToTensor(base64);
|
||||
|
||||
expect(decoded.tensor.length).toBe(original.length);
|
||||
expect(decoded.shape).toEqual([3, 224, 224]);
|
||||
expect(decoded.shape).toEqual([3, 160, 160]);
|
||||
|
||||
// Check a few values match
|
||||
for (let i = 0; i < 10; i++) {
|
||||
@@ -197,9 +197,9 @@ describe("tensorToBase64 / base64ToTensor", () => {
|
||||
});
|
||||
|
||||
describe("getTensorShape", () => {
|
||||
it("returns [1, 3, 224, 224] by default", () => {
|
||||
it("returns [1, 3, 160, 160] by default", () => {
|
||||
const shape = getTensorShape();
|
||||
expect(shape).toEqual([1, 3, 224, 224]);
|
||||
expect(shape).toEqual([1, 3, 160, 160]);
|
||||
});
|
||||
|
||||
it("returns NCHW layout", () => {
|
||||
@@ -207,8 +207,8 @@ describe("getTensorShape", () => {
|
||||
expect(shape.length).toBe(4);
|
||||
expect(shape[0]).toBe(1); // batch
|
||||
expect(shape[1]).toBe(3); // channels
|
||||
expect(shape[2]).toBe(224); // height
|
||||
expect(shape[3]).toBe(224); // width
|
||||
expect(shape[2]).toBe(160); // height (model input size)
|
||||
expect(shape[3]).toBe(160); // width (model input size)
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -48,12 +48,7 @@ export const MAX_FILE_SIZE = 10 * 1024 * 1024;
|
||||
export const MIN_DIMENSION = 150;
|
||||
|
||||
/** Allowed MIME types */
|
||||
export const ALLOWED_MIME_TYPES = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
] as const;
|
||||
export const ALLOWED_MIME_TYPES = ["image/png", "image/jpeg", "image/jpg", "image/webp"] as const;
|
||||
|
||||
export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number];
|
||||
|
||||
@@ -66,9 +61,7 @@ export const MAX_UPLOADS = 100;
|
||||
* Validate that a file is an acceptable image for upload.
|
||||
* Returns `{ ok: true }` or `{ ok: false, error: string }`.
|
||||
*/
|
||||
export function validateImageFile(file: File):
|
||||
| { ok: true }
|
||||
| { ok: false; error: string } {
|
||||
export function validateImageFile(file: File): { ok: true } | { ok: false; error: string } {
|
||||
// MIME type check
|
||||
if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) {
|
||||
return {
|
||||
@@ -127,10 +120,7 @@ export function validateImageDimensions(
|
||||
* @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224.
|
||||
* @returns ImageData at exactly `size × size`
|
||||
*/
|
||||
export async function resizeImage(
|
||||
file: File,
|
||||
size: number = getConfig().size,
|
||||
): Promise<ImageData> {
|
||||
export async function resizeImage(file: File, size: number = getConfig().size): Promise<ImageData> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
@@ -193,8 +183,7 @@ export function imageToTensor(imageData: ImageData): Float32Array {
|
||||
|
||||
// Normalize with ImageNet mean/std
|
||||
for (let c = 0; c < 3; c++) {
|
||||
const channel =
|
||||
c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
|
||||
const channel = c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
|
||||
const m = mean[c];
|
||||
const s = std[c];
|
||||
for (let i = 0; i < totalPixels; i++) {
|
||||
@@ -253,5 +242,3 @@ export function base64ToTensor(base64: string): {
|
||||
shape: envelope.shape as [number, number, number],
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ describe("createZeroTensor", () => {
|
||||
|
||||
it("all values are zero", () => {
|
||||
const tensor = createZeroTensor();
|
||||
expect(tensor.every(v => v === 0)).toBe(true);
|
||||
expect(tensor.every((v) => v === 0)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -114,12 +114,12 @@ describe("createRandomTensor", () => {
|
||||
|
||||
it("all values are finite", () => {
|
||||
const tensor = createRandomTensor();
|
||||
expect(tensor.every(v => Number.isFinite(v))).toBe(true);
|
||||
expect(tensor.every((v) => Number.isFinite(v))).toBe(true);
|
||||
});
|
||||
|
||||
it("produces varied values", () => {
|
||||
const tensor = createRandomTensor();
|
||||
const uniqueValues = new Set(tensor.map(v => v.toFixed(4)));
|
||||
const uniqueValues = new Set(tensor.map((v) => v.toFixed(4)));
|
||||
expect(uniqueValues.size).toBeGreaterThan(100);
|
||||
});
|
||||
|
||||
@@ -172,7 +172,7 @@ describe("runInference", () => {
|
||||
const result = await runInference(tensor);
|
||||
for (let i = 0; i < result.predictions.length - 1; i++) {
|
||||
expect(result.predictions[i].probability).toBeGreaterThanOrEqual(
|
||||
result.predictions[i + 1].probability
|
||||
result.predictions[i + 1].probability,
|
||||
);
|
||||
}
|
||||
}, 10000);
|
||||
|
||||
@@ -69,9 +69,7 @@ export async function runInference(
|
||||
*/
|
||||
export function validateInput(tensor: Float32Array): void {
|
||||
if (!(tensor instanceof Float32Array)) {
|
||||
throw new Error(
|
||||
`Expected Float32Array input, got ${typeof tensor}`,
|
||||
);
|
||||
throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
|
||||
}
|
||||
|
||||
if (tensor.length !== INPUT_SIZE) {
|
||||
@@ -84,9 +82,7 @@ export function validateInput(tensor: Float32Array): void {
|
||||
// Check for NaN/Infinity values
|
||||
for (let i = 0; i < tensor.length; i++) {
|
||||
if (!Number.isFinite(tensor[i])) {
|
||||
throw new Error(
|
||||
`Tensor contains non-finite value at index ${i}: ${tensor[i]}`,
|
||||
);
|
||||
throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
/**
|
||||
* Unit tests for lib/ml/labels.ts
|
||||
*
|
||||
* Tests:
|
||||
* - INDEX_TO_DISEASE_ID maps index 0 to "healthy"
|
||||
* - INDEX_TO_DISEASE_ID maps last index to "unknown"
|
||||
* - INDEX_TO_DISEASE_ID maps intermediate indices to disease IDs
|
||||
* - DISEASE_ID_TO_INDEX is inverse of INDEX_TO_DISEASE_ID
|
||||
* - getDiseaseIdForIndex returns "unknown" for out-of-range
|
||||
* - getIndexForDiseaseId returns -1 for unknown ID
|
||||
* - isRealDisease correctly classifies healthy/unknown vs real diseases
|
||||
* - getAllDiseaseIds returns all disease IDs from knowledge base
|
||||
* - NUM_CLASSES equals expected count (diseases + healthy + unknown)
|
||||
* - Bidirectional mapping consistency
|
||||
* The model has 38 PlantVillage classes. Some map to the app's
|
||||
* knowledge base disease IDs, others map to "unknown".
|
||||
*
|
||||
* Known mappings:
|
||||
* - indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37 → "healthy"
|
||||
* - index 20 (Potato___Early_blight) → "early-blight"
|
||||
* - index 21 (Potato___Late_blight) → "late-blight"
|
||||
* - index 25 (Squash___Powdery_mildew) → "squash-powdery-mildew"
|
||||
* - index 26 (Strawberry___Leaf_scorch) → "strawberry-leaf-scorch"
|
||||
* - index 28 (Tomato___Bacterial_spot) → "bacterial-leaf-spot-tomato"
|
||||
* - index 29 (Tomato___Early_blight) → "early-blight" (duplicate)
|
||||
* - index 30 (Tomato___Late_blight) → "late-blight" (duplicate)
|
||||
* - index 32 (Tomato___Septoria_leaf_spot) → "septoria-leaf-spot"
|
||||
* - index 37 (Tomato___healthy) → "healthy"
|
||||
* - all others → "unknown"
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
@@ -23,143 +27,105 @@ import {
|
||||
isRealDisease,
|
||||
getAllDiseaseIds,
|
||||
NUM_CLASSES,
|
||||
HEALTHY_INDEX,
|
||||
FIRST_DISEASE_INDEX,
|
||||
UNKNOWN_INDEX,
|
||||
getPlantVillageClassName,
|
||||
} from "@/lib/ml/labels";
|
||||
import rawDiseases from "@/data/diseases.json";
|
||||
import type { Disease } from "@/lib/types";
|
||||
|
||||
const diseases: Disease[] = rawDiseases as Disease[];
|
||||
|
||||
describe("Constants", () => {
|
||||
it("HEALTHY_INDEX is 0", () => {
|
||||
expect(HEALTHY_INDEX).toBe(0);
|
||||
it("NUM_CLASSES is 38 (PlantVillage)", () => {
|
||||
expect(NUM_CLASSES).toBe(38);
|
||||
});
|
||||
|
||||
it("FIRST_DISEASE_INDEX is 1", () => {
|
||||
expect(FIRST_DISEASE_INDEX).toBe(1);
|
||||
});
|
||||
|
||||
it("UNKNOWN_INDEX is 1 + number of diseases", () => {
|
||||
expect(UNKNOWN_INDEX).toBe(1 + diseases.length);
|
||||
});
|
||||
|
||||
it("NUM_CLASSES is UNKNOWN_INDEX + 1", () => {
|
||||
expect(NUM_CLASSES).toBe(UNKNOWN_INDEX + 1);
|
||||
});
|
||||
|
||||
it("NUM_CLASSES equals diseases.length + 2 (healthy + unknown)", () => {
|
||||
expect(NUM_CLASSES).toBe(diseases.length + 2);
|
||||
});
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID", () => {
|
||||
it("maps index 0 to 'healthy'", () => {
|
||||
expect(INDEX_TO_DISEASE_ID[0]).toBe("healthy");
|
||||
});
|
||||
|
||||
it("maps last index to 'unknown'", () => {
|
||||
expect(INDEX_TO_DISEASE_ID[NUM_CLASSES - 1]).toBe("unknown");
|
||||
});
|
||||
|
||||
it("maps intermediate indices to disease IDs", () => {
|
||||
// Index 1 should be the first disease
|
||||
expect(INDEX_TO_DISEASE_ID[1]).toBe(diseases[0].id);
|
||||
// Index 2 should be the second disease
|
||||
expect(INDEX_TO_DISEASE_ID[2]).toBe(diseases[1].id);
|
||||
// Last disease index
|
||||
expect(INDEX_TO_DISEASE_ID[diseases.length]).toBe(diseases[diseases.length - 1].id);
|
||||
});
|
||||
|
||||
it("has exactly NUM_CLASSES entries", () => {
|
||||
const keys = Object.keys(INDEX_TO_DISEASE_ID);
|
||||
expect(keys.length).toBe(NUM_CLASSES);
|
||||
});
|
||||
|
||||
it("all mapped IDs are valid strings", () => {
|
||||
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
|
||||
expect(typeof id).toBe("string");
|
||||
expect(id.length).toBeGreaterThan(0);
|
||||
it("all 38 indices are mapped", () => {
|
||||
const keys = Object.keys(INDEX_TO_DISEASE_ID).map(Number);
|
||||
expect(keys.length).toBe(38);
|
||||
for (let i = 0; i < 38; i++) {
|
||||
expect(keys).toContain(i);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — healthy indices", () => {
|
||||
const healthyIndices = [3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37];
|
||||
|
||||
for (const idx of healthyIndices) {
|
||||
it(`index ${idx} maps to "healthy"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe("healthy");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — known disease mappings", () => {
|
||||
const cases: Array<{ index: number; expected: string; name: string }> = [
|
||||
{ index: 20, expected: "early-blight", name: "Potato___Early_blight" },
|
||||
{ index: 21, expected: "late-blight", name: "Potato___Late_blight" },
|
||||
{ index: 25, expected: "squash-powdery-mildew", name: "Squash___Powdery_mildew" },
|
||||
{ index: 26, expected: "strawberry-leaf-scorch", name: "Strawberry___Leaf_scorch" },
|
||||
{ index: 28, expected: "bacterial-leaf-spot-tomato", name: "Tomato___Bacterial_spot" },
|
||||
{ index: 29, expected: "early-blight", name: "Tomato___Early_blight" },
|
||||
{ index: 30, expected: "late-blight", name: "Tomato___Late_blight" },
|
||||
{ index: 32, expected: "septoria-leaf-spot", name: "Tomato___Septoria_leaf_spot" },
|
||||
];
|
||||
|
||||
for (const { index, expected, name } of cases) {
|
||||
it(`index ${index} (${name}) maps to "${expected}"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[index]).toBe(expected);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — unknown (unmapped) indices", () => {
|
||||
const unknownIndices = [0, 1, 2, 5, 7, 8, 9, 11, 12, 13, 15, 16, 18, 31, 33, 34, 35, 36];
|
||||
|
||||
for (const idx of unknownIndices) {
|
||||
it(`index ${idx} maps to "unknown"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe("unknown");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("DISEASE_ID_TO_INDEX", () => {
|
||||
it("maps 'healthy' to index 0", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(0);
|
||||
it("maps 'early-blight' to first occurrence (index 20)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["early-blight"]).toBe(20);
|
||||
});
|
||||
|
||||
it("maps 'unknown' to last index", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["unknown"]).toBe(NUM_CLASSES - 1);
|
||||
it("maps 'late-blight' to first occurrence (index 21)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["late-blight"]).toBe(21);
|
||||
});
|
||||
|
||||
it("maps disease IDs to correct indices", () => {
|
||||
for (let i = 0; i < diseases.length; i++) {
|
||||
expect(DISEASE_ID_TO_INDEX[diseases[i].id]).toBe(FIRST_DISEASE_INDEX + i);
|
||||
}
|
||||
it("maps 'septoria-leaf-spot' to index 32", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["septoria-leaf-spot"]).toBe(32);
|
||||
});
|
||||
|
||||
it("has exactly NUM_CLASSES entries", () => {
|
||||
const keys = Object.keys(DISEASE_ID_TO_INDEX);
|
||||
expect(keys.length).toBe(NUM_CLASSES);
|
||||
it("maps 'healthy' to index 3 (first healthy index)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Bidirectional mapping", () => {
|
||||
it("INDEX_TO_DISEASE_ID and DISEASE_ID_TO_INDEX are inverses", () => {
|
||||
for (const [idxStr, id] of Object.entries(INDEX_TO_DISEASE_ID)) {
|
||||
const idx = parseInt(idxStr);
|
||||
expect(DISEASE_ID_TO_INDEX[id]).toBe(idx);
|
||||
}
|
||||
});
|
||||
|
||||
it("round-trips for all disease IDs", () => {
|
||||
for (const [id, idx] of Object.entries(DISEASE_ID_TO_INDEX)) {
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
|
||||
}
|
||||
});
|
||||
|
||||
it("round-trips for all indices", () => {
|
||||
it("every index round-trips correctly", () => {
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
const id = INDEX_TO_DISEASE_ID[i];
|
||||
expect(DISEASE_ID_TO_INDEX[id]).toBe(i);
|
||||
const idx = DISEASE_ID_TO_INDEX[id];
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("getDiseaseIdForIndex", () => {
|
||||
it("returns 'healthy' for index 0", () => {
|
||||
expect(getDiseaseIdForIndex(0)).toBe("healthy");
|
||||
});
|
||||
|
||||
it("returns disease ID for valid disease index", () => {
|
||||
expect(getDiseaseIdForIndex(1)).toBe(diseases[0].id);
|
||||
});
|
||||
|
||||
it("returns 'unknown' for out-of-range positive index", () => {
|
||||
expect(getDiseaseIdForIndex(1000)).toBe("unknown");
|
||||
expect(getDiseaseIdForIndex(100)).toBe("unknown");
|
||||
});
|
||||
|
||||
it("returns 'unknown' for negative index", () => {
|
||||
expect(getDiseaseIdForIndex(-1)).toBe("unknown");
|
||||
});
|
||||
|
||||
it("returns 'unknown' for index past NUM_CLASSES", () => {
|
||||
expect(getDiseaseIdForIndex(NUM_CLASSES + 10)).toBe("unknown");
|
||||
it("returns correct ID for valid index", () => {
|
||||
expect(getDiseaseIdForIndex(20)).toBe("early-blight");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getIndexForDiseaseId", () => {
|
||||
it("returns 0 for 'healthy'", () => {
|
||||
expect(getIndexForDiseaseId("healthy")).toBe(0);
|
||||
});
|
||||
|
||||
it("returns correct index for known disease", () => {
|
||||
const idx = getIndexForDiseaseId(diseases[0].id);
|
||||
expect(idx).toBe(1);
|
||||
});
|
||||
|
||||
it("returns -1 for unknown disease ID", () => {
|
||||
expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1);
|
||||
});
|
||||
@@ -169,9 +135,7 @@ describe("getIndexForDiseaseId", () => {
|
||||
});
|
||||
|
||||
it("is case-insensitive", () => {
|
||||
const lowerIdx = getIndexForDiseaseId(diseases[0].id);
|
||||
const upperIdx = getIndexForDiseaseId(diseases[0].id.toUpperCase());
|
||||
expect(upperIdx).toBe(lowerIdx);
|
||||
expect(getIndexForDiseaseId("EARLY-BLIGHT")).toBe(20);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -184,10 +148,9 @@ describe("isRealDisease", () => {
|
||||
expect(isRealDisease("unknown")).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true for actual disease IDs", () => {
|
||||
for (const disease of diseases) {
|
||||
expect(isRealDisease(disease.id)).toBe(true);
|
||||
}
|
||||
it("returns true for known disease IDs", () => {
|
||||
expect(isRealDisease("early-blight")).toBe(true);
|
||||
expect(isRealDisease("septoria-leaf-spot")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true for arbitrary non-special strings", () => {
|
||||
@@ -195,27 +158,37 @@ describe("isRealDisease", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPlantVillageClassName", () => {
|
||||
it("returns correct class name for tomato healthy", () => {
|
||||
expect(getPlantVillageClassName(37)).toBe("Tomato___healthy");
|
||||
});
|
||||
|
||||
it("returns correct class name for potato early blight", () => {
|
||||
expect(getPlantVillageClassName(20)).toBe("Potato___Early_blight");
|
||||
});
|
||||
|
||||
it("returns 'unknown' for out-of-range index", () => {
|
||||
expect(getPlantVillageClassName(100)).toBe("unknown");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getAllDiseaseIds", () => {
|
||||
it("returns array of all disease IDs", () => {
|
||||
it("returns only mapped disease IDs", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
expect(ids.length).toBe(diseases.length);
|
||||
expect(ids).toContain("early-blight");
|
||||
expect(ids).toContain("late-blight");
|
||||
expect(ids).toContain("squash-powdery-mildew");
|
||||
expect(ids).toContain("strawberry-leaf-scorch");
|
||||
expect(ids).toContain("bacterial-leaf-spot-tomato");
|
||||
expect(ids).toContain("septoria-leaf-spot");
|
||||
});
|
||||
|
||||
it("excludes 'healthy'", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
expect(ids).not.toContain("healthy");
|
||||
expect(getAllDiseaseIds()).not.toContain("healthy");
|
||||
});
|
||||
|
||||
it("excludes 'unknown'", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
expect(ids).not.toContain("unknown");
|
||||
});
|
||||
|
||||
it("includes all disease IDs from knowledge base", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
for (const disease of diseases) {
|
||||
expect(ids).toContain(disease.id);
|
||||
}
|
||||
expect(getAllDiseaseIds()).not.toContain("unknown");
|
||||
});
|
||||
|
||||
it("has no duplicates", () => {
|
||||
|
||||
@@ -1,74 +1,197 @@
|
||||
/**
|
||||
* Class label mapping for the plant disease classifier model.
|
||||
*
|
||||
* Maps model output index → disease ID string.
|
||||
* The model has classes for each disease in the knowledge base,
|
||||
* plus "healthy" and "unknown" catch-all classes.
|
||||
* This model is a MobileNetV2 trained on the PlantVillage dataset
|
||||
* with 38 classes (14 crops × diseases/healthy).
|
||||
*
|
||||
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 95
|
||||
* (93 diseases + "healthy" + "unknown")
|
||||
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 38
|
||||
*
|
||||
* Index layout:
|
||||
* 0 → "healthy"
|
||||
* 1–93 → disease IDs (order matches diseases.json)
|
||||
* 94 → "unknown"
|
||||
* Index layout (from labels_pv_original.json):
|
||||
* 0 → Apple___Apple_scab
|
||||
* 1 → Apple___Black_rot
|
||||
* 2 → Apple___Cedar_apple_rust
|
||||
* 3 → Apple___healthy
|
||||
* 4 → Blueberry___healthy
|
||||
* 5 → Cherry_(including_sour)___Powdery_mildew
|
||||
* 6 → Cherry_(including_sour)___healthy
|
||||
* 7 → Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
|
||||
* 8 → Corn_(maize)___Common_rust_
|
||||
* 9 → Corn_(maize)___Northern_Leaf_Blight
|
||||
* 10 → Corn_(maize)___healthy
|
||||
* 11 → Grape___Black_rot
|
||||
* 12 → Grape___Esca_(Black_Measles)
|
||||
* 13 → Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
|
||||
* 14 → Grape___healthy
|
||||
* 15 → Orange___Haunglongbing_(Citrus_greening)
|
||||
* 16 → Peach___Bacterial_spot
|
||||
* 17 → Peach___healthy
|
||||
* 18 → Pepper,_bell___Bacterial_spot
|
||||
* 19 → Pepper,_bell___healthy
|
||||
* 20 → Potato___Early_blight
|
||||
* 21 → Potato___Late_blight
|
||||
* 22 → Potato___healthy
|
||||
* 23 → Raspberry___healthy
|
||||
* 24 → Soybean___healthy
|
||||
* 25 → Squash___Powdery_mildew
|
||||
* 26 → Strawberry___Leaf_scorch
|
||||
* 27 → Strawberry___healthy
|
||||
* 28 → Tomato___Bacterial_spot
|
||||
* 29 → Tomato___Early_blight
|
||||
* 30 → Tomato___Late_blight
|
||||
* 31 → Tomato___Leaf_Mold
|
||||
* 32 → Tomato___Septoria_leaf_spot
|
||||
* 33 → Tomato___Spider_mites Two-spotted_spider_mite
|
||||
* 34 → Tomato___Target_Spot
|
||||
* 35 → Tomato___Tomato_Yellow_Leaf_Curl_Virus
|
||||
* 36 → Tomato___Tomato_mosaic_virus
|
||||
* 37 → Tomato___healthy
|
||||
*
|
||||
* Some PlantVillage classes overlap with this app's knowledge base.
|
||||
* Exact class name → disease ID mappings:
|
||||
* Potato___Early_blight → "early-blight"
|
||||
* Potato___Late_blight → "late-blight"
|
||||
* Squash___Powdery_mildew → "squash-powdery-mildew"
|
||||
* Strawberry___Leaf_scorch → "strawberry-leaf-scorch"
|
||||
* Tomato___Bacterial_spot → "bacterial-leaf-spot-tomato"
|
||||
* Tomato___Early_blight → "early-blight"
|
||||
* Tomato___Late_blight → "late-blight"
|
||||
* Tomato___Septoria_leaf_spot → "septoria-leaf-spot"
|
||||
* All other classes map to "unknown" and are filtered out during enrichment.
|
||||
*
|
||||
* After fine-tuning to the app's 93 disease classes, this file will be
|
||||
* rewritten to match the new model's output layer.
|
||||
*/
|
||||
|
||||
import rawDiseases from "@/data/diseases.json";
|
||||
import type { Disease } from "@/lib/types";
|
||||
// ─── PlantVillage class names (in model output order) ────────────────────
|
||||
|
||||
const diseases: Disease[] = rawDiseases as Disease[];
|
||||
const PLANTVILLAGE_CLASSES: string[] = [
|
||||
"Apple___Apple_scab",
|
||||
"Apple___Black_rot",
|
||||
"Apple___Cedar_apple_rust",
|
||||
"Apple___healthy",
|
||||
"Blueberry___healthy",
|
||||
"Cherry_(including_sour)___Powdery_mildew",
|
||||
"Cherry_(including_sour)___healthy",
|
||||
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
|
||||
"Corn_(maize)___Common_rust_",
|
||||
"Corn_(maize)___Northern_Leaf_Blight",
|
||||
"Corn_(maize)___healthy",
|
||||
"Grape___Black_rot",
|
||||
"Grape___Esca_(Black_Measles)",
|
||||
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
|
||||
"Grape___healthy",
|
||||
"Orange___Haunglongbing_(Citrus_greening)",
|
||||
"Peach___Bacterial_spot",
|
||||
"Peach___healthy",
|
||||
"Pepper,_bell___Bacterial_spot",
|
||||
"Pepper,_bell___healthy",
|
||||
"Potato___Early_blight",
|
||||
"Potato___Late_blight",
|
||||
"Potato___healthy",
|
||||
"Raspberry___healthy",
|
||||
"Soybean___healthy",
|
||||
"Squash___Powdery_mildew",
|
||||
"Strawberry___Leaf_scorch",
|
||||
"Strawberry___healthy",
|
||||
"Tomato___Bacterial_spot",
|
||||
"Tomato___Early_blight",
|
||||
"Tomato___Late_blight",
|
||||
"Tomato___Leaf_Mold",
|
||||
"Tomato___Septoria_leaf_spot",
|
||||
"Tomato___Spider_mites Two-spotted_spider_mite",
|
||||
"Tomato___Target_Spot",
|
||||
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
|
||||
"Tomato___Tomato_mosaic_virus",
|
||||
"Tomato___healthy",
|
||||
] as const;
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
/** Index for the "healthy" class */
|
||||
export const HEALTHY_INDEX = 0;
|
||||
|
||||
/** First index for actual disease classes */
|
||||
export const FIRST_DISEASE_INDEX = 1;
|
||||
|
||||
/** Index for the "unknown" catch-all class */
|
||||
export const UNKNOWN_INDEX = 1 + diseases.length;
|
||||
|
||||
/** Total number of output classes */
|
||||
export const NUM_CLASSES = UNKNOWN_INDEX + 1;
|
||||
|
||||
// ─── Index → Disease ID mapping ──────────────────────────────────────────────
|
||||
// ─── PlantVillage → App disease ID mapping ──────────────────────────────
|
||||
|
||||
/**
|
||||
* Map from model output index to disease ID string.
|
||||
* Index 0 = "healthy", indices 1..N = disease IDs, last = "unknown".
|
||||
* Maps PlantVillage class names (in the form "Plant___Disease") to
|
||||
* this app's disease IDs. Unmapped classes resolve to "unknown".
|
||||
*/
|
||||
function plantVillageNameToDiseaseId(pvName: string): string {
|
||||
const parts = pvName.split("___");
|
||||
if (parts.length !== 2) {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
const disease = parts[1];
|
||||
|
||||
// Detect "healthy" variants
|
||||
if (disease === "healthy") {
|
||||
return "healthy";
|
||||
}
|
||||
|
||||
// Map exact PlantVillage class names to our disease IDs.
|
||||
// Only map classes where we're confident the correspondence holds.
|
||||
const exactMap: Record<string, string> = {
|
||||
Squash___Powdery_mildew: "squash-powdery-mildew",
|
||||
Strawberry___Leaf_scorch: "strawberry-leaf-scorch",
|
||||
Potato___Early_blight: "early-blight",
|
||||
Potato___Late_blight: "late-blight",
|
||||
Tomato___Bacterial_spot: "bacterial-leaf-spot-tomato",
|
||||
Tomato___Early_blight: "early-blight",
|
||||
Tomato___Late_blight: "late-blight",
|
||||
Tomato___Septoria_leaf_spot: "septoria-leaf-spot",
|
||||
};
|
||||
|
||||
return exactMap[pvName] ?? "unknown";
|
||||
}
|
||||
|
||||
// ─── Constants ──────────────────────────────────────────────────────────
|
||||
|
||||
/** Total number of model output classes */
|
||||
export const NUM_CLASSES = PLANTVILLAGE_CLASSES.length; // 38
|
||||
|
||||
/** Index for the "healthy" class — multiple PV indices map to this */
|
||||
export const HEALTHY_INDEX = 0; // First PV healthy class, others also map to this string
|
||||
|
||||
/** First disease index (unused in PV mapping, kept for compatibility) */
|
||||
export const FIRST_DISEASE_INDEX = 0;
|
||||
|
||||
/** Index for the "unknown" catch-all — PV classes we can't map */
|
||||
export const UNKNOWN_INDEX = NUM_CLASSES - 1; // 37 (Tomato___healthy maps to "healthy", not unknown)
|
||||
|
||||
// ─── Index → Disease ID mapping ─────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Map from model output index to app disease ID string.
|
||||
* Built dynamically from PlantVillage class names.
|
||||
*/
|
||||
export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze(
|
||||
(() => {
|
||||
const map: Record<number, string> = {};
|
||||
map[HEALTHY_INDEX] = "healthy";
|
||||
for (let i = 0; i < diseases.length; i++) {
|
||||
map[FIRST_DISEASE_INDEX + i] = diseases[i].id;
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
map[i] = plantVillageNameToDiseaseId(PLANTVILLAGE_CLASSES[i]);
|
||||
}
|
||||
map[UNKNOWN_INDEX] = "unknown";
|
||||
return map;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ─── Disease ID → Index mapping ──────────────────────────────────────────────
|
||||
// ─── Disease ID → Index mapping ─────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Map from disease ID string to model output index.
|
||||
* For duplicates (e.g., both potato and tomato "Early_blight" → "early-blight"),
|
||||
* returns the first matching index.
|
||||
*/
|
||||
export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze(
|
||||
(() => {
|
||||
const map: Record<string, number> = {};
|
||||
map["healthy"] = HEALTHY_INDEX;
|
||||
for (let i = 0; i < diseases.length; i++) {
|
||||
map[diseases[i].id] = FIRST_DISEASE_INDEX + i;
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
const id = INDEX_TO_DISEASE_ID[i];
|
||||
// First occurrence wins (potato before tomato for early/late blight)
|
||||
if (map[id] === undefined) {
|
||||
map[id] = i;
|
||||
}
|
||||
}
|
||||
map["unknown"] = UNKNOWN_INDEX;
|
||||
return map;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ─── Lookup helpers ──────────────────────────────────────────────────────────
|
||||
// ─── Lookup helpers ─────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get the disease ID for a given model output index.
|
||||
@@ -93,9 +216,22 @@ export function isRealDisease(diseaseId: string): boolean {
|
||||
return diseaseId !== "healthy" && diseaseId !== "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the PlantVillage display name for a given model output index.
|
||||
*/
|
||||
export function getPlantVillageClassName(index: number): string {
|
||||
return PLANTVILLAGE_CLASSES[index] ?? "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all known disease IDs (excluding "healthy" and "unknown").
|
||||
*/
|
||||
export function getAllDiseaseIds(): string[] {
|
||||
return diseases.map((d) => d.id);
|
||||
const ids = new Set<string>();
|
||||
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
|
||||
if (id !== "healthy" && id !== "unknown") {
|
||||
ids.add(id);
|
||||
}
|
||||
}
|
||||
return Array.from(ids);
|
||||
}
|
||||
|
||||
@@ -93,7 +93,10 @@ export async function getModel(): Promise<PlantDiseaseModel> {
|
||||
const model = await Promise.race([
|
||||
loadingPromise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)), MODEL_LOAD_TIMEOUT),
|
||||
setTimeout(
|
||||
() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
|
||||
MODEL_LOAD_TIMEOUT,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
@@ -172,6 +175,18 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let tf: any;
|
||||
|
||||
// Monkey-patch: add util.isNullOrUndefined for Node.js 26 compatibility.
|
||||
// @tensorflow/tfjs-node references this function which was removed in Node 15+.
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
const nodeUtil = require("util");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
|
||||
return x === null || x === undefined;
|
||||
};
|
||||
}
|
||||
|
||||
// Try tfjs-node first (server-side, uses native bindings).
|
||||
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
|
||||
// as required dependencies — they are truly optional.
|
||||
@@ -197,7 +212,9 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
|
||||
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 160, 160])
|
||||
// Reshape NCHW flat array [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3]
|
||||
const inputTensor = tf
|
||||
.tensor3d(Array.from(tensor), [3, 160, 160])
|
||||
.transpose([1, 2, 0])
|
||||
.expandDims(0);
|
||||
|
||||
@@ -352,7 +369,7 @@ function generateMockLogits(tensor: Float32Array): Float32Array {
|
||||
logits[topIndex] = 3.5;
|
||||
|
||||
// Second highest
|
||||
const secondIndex = (topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1) + 1;
|
||||
const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
|
||||
logits[secondIndex] = 2.5;
|
||||
|
||||
logits[numClasses - 1] = -2; // "unknown" gets low score
|
||||
|
||||
Reference in New Issue
Block a user