scripting

This commit is contained in:
2026-06-06 15:45:21 -04:00
parent 06295c83ca
commit 47609e5e42
11 changed files with 4411 additions and 205 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,537 @@
#!/usr/bin/env python3
"""
fine-tune-model.py
Fine-tunes the PlantVillage MobileNetV2 model on a custom 95-class dataset
(93 diseases + healthy + unknown).
Pipeline:
1. Load `best_mnv2_pv_original.keras` (MobileNetV2 backbone + 38-class head)
2. Replace the 38-class head with 95 classes (order matches diseases.json + healthy + unknown)
3. Freeze backbone, train only the new classification head
4. Unfreeze the last ~25 backbone layers, fine-tune at a lower learning rate
5. Export to TF.js Layers format (via tensorflowjs `save_keras_model`)
6. Export to .keras for future retraining
Usage: .tfjs-venv/bin/python scripts/fine-tune-model.py
"""
import json
import os
import sys
import shutil
from pathlib import Path
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TF info/warnings
import numpy as np
import tensorflow as tf
import keras
from keras import layers, optimizers, regularizers
# ─── Constants ───────────────────────────────────────────────────────────────
# Project root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Pre-trained model that build_model() loads as the starting point.
MODEL_PATH = (
PROJECT_ROOT
/ "public"
/ "models"
/ "plant-disease-classifier"
/ "best_mnv2_pv_original.keras"
)
# JSON list of disease records; each record's "id" becomes a class name.
DISEASES_JSON = PROJECT_ROOT / "src" / "data" / "diseases.json"
# One sub-directory per class id, each holding that class's images.
DATASET_DIR = PROJECT_ROOT / "data" / "dataset"
# Destination for the exported .keras model, class mapping and TF.js export.
OUTPUT_DIR = PROJECT_ROOT / "public" / "models" / "plant-disease-classifier"
TFJS_OUTPUT = OUTPUT_DIR / "tfjs_finetuned"
IMG_SIZE = 160 # Model input size
BATCH_SIZE = 32
EPOCHS_HEAD = 15 # Train just the new head
EPOCHS_FINETUNE = 10 # Unfreeze and fine-tune
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 1e-5
# Fraction of the shuffled dataset held out for validation.
VALIDATION_SPLIT = 0.15
# NOTE(review): hard-coded; must stay in sync with the number of entries in
# diseases.json (len(diseases) + 2).
NUM_CLASSES = 95 # healthy(0) + 93 diseases + unknown(94)
# ─── Class Mapping ───────────────────────────────────────────────────────────
def build_class_mapping():
    """
    Build the mapping from dataset directory names to model class indices.

    The ordering follows diseases.json (and is meant to match labels.ts):
      - index 0 is "healthy"
      - indices 1..N are the disease ids, in file order
      - index N + 1 is "unknown" (reserved, no training images); this is 94
        for the expected 93 diseases

    Returns:
        A ``(mapping, index_to_class)`` tuple, where ``mapping`` maps a class
        id to its index and ``index_to_class`` is the reverse lookup.
    """
    with open(DISEASES_JSON, encoding="utf-8") as f:
        diseases = json.load(f)

    mapping = {"healthy": 0}
    for index, disease in enumerate(diseases, start=1):
        mapping[disease["id"]] = index

    # Derive the reserved slot from the data instead of hard-coding 94, so an
    # edited diseases.json can never make "unknown" collide with a disease.
    mapping["unknown"] = len(diseases) + 1

    # Reverse mapping for predictions
    index_to_class = {v: k for k, v in mapping.items()}
    return mapping, index_to_class
def verify_dataset(mapping):
    """
    Count the usable images on disk for every class in *mapping*.

    A class is reported only when its directory under DATASET_DIR exists and
    holds at least one .jpg/.jpeg/.png/.webp file.

    Returns:
        ``(available, total)`` where ``available`` maps class id to
        ``{"index": class_index, "count": image_count}`` and ``total`` is the
        sum of all counts.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")
    available = {}
    for class_id, class_idx in mapping.items():
        class_dir = DATASET_DIR / class_id
        if not class_dir.exists():
            continue
        count = sum(
            1 for entry in class_dir.glob("*") if entry.suffix.lower() in image_suffixes
        )
        if count:
            available[class_id] = {"index": class_idx, "count": count}
    total = sum(info["count"] for info in available.values())
    return available, total
def print_dataset_summary(available, total, total_classes=None):
    """
    Print a human-readable summary of the dataset found on disk.

    Args:
        available: Mapping of class id to ``{"index": ..., "count": ...}`` as
            returned by verify_dataset().
        total: Total number of images across all classes.
        total_classes: Number of classes defined in the class mapping. When
            omitted it is read once via build_class_mapping().
    """
    if total_classes is None:
        # Read diseases.json once instead of once per printed line.
        total_classes = len(build_class_mapping()[0])

    # The original multiplied an empty string, which printed nothing; draw an
    # actual separator rule.
    rule = "─" * 60

    print(f"\n{rule}")
    print("DATASET SUMMARY")
    print(f"{rule}")
    print(f" Total images: {total}")
    print(f" Classes found: {len(available)} / {total_classes}")
    print(f" Missing classes with no images: {total_classes - len(available)}")

    # Per-class listing, ordered alphabetically by class id.
    counts = [(v["index"], k, v["count"]) for k, v in available.items()]
    counts.sort(key=lambda x: x[1])
    print("\n Images per class:")
    for idx, class_id, count in counts:
        label = f" {idx:3d}. {class_id:<35s} {count:>4d} images"
        if class_id == "healthy":
            label += " ← 2× target"
        print(label)

    # Stats (skipped when no class has images, to avoid min()/max() on empty)
    class_counts = [v["count"] for v in available.values()]
    if class_counts:
        print(
            f"\n Min: {min(class_counts)} Max: {max(class_counts)} Avg: {sum(class_counts) / len(class_counts):.0f}"
        )
    print(f"{rule}\n")
# ─── Data Loading ────────────────────────────────────────────────────────────
def load_dataset(mapping, available):
    """
    Build the training and validation tf.data pipelines.

    Image paths are gathered from DATASET_DIR for every class in *available*,
    shuffled with a fixed seed, and split by VALIDATION_SPLIT. Training images
    are augmented; both splits are resized to IMG_SIZE and normalised with the
    ImageNet mean/std.

    Args:
        mapping: Class id -> index mapping (unused; kept for the existing
            call signature).
        available: Class id -> ``{"index", "count"}`` as returned by
            verify_dataset().

    Returns:
        ``(train_ds, val_ds)``, both batched and prefetched.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp")

    # Build file paths and labels
    file_paths = []
    labels = []
    for class_id, info in available.items():
        class_dir = DATASET_DIR / class_id
        for img_path in sorted(class_dir.glob("*")):
            if img_path.suffix.lower() in image_suffixes:
                file_paths.append(str(img_path))
                labels.append(info["index"])

    # Explicit dtypes keep the arrays string/int typed even when empty.
    file_paths = np.array(file_paths, dtype=str)
    labels = np.array(labels, dtype=np.int64)

    # Deterministic shuffle so the train/validation split is reproducible.
    indices = np.random.RandomState(42).permutation(len(file_paths))
    file_paths = file_paths[indices]
    labels = labels[indices]

    # Split train/validation
    split = int(len(file_paths) * (1 - VALIDATION_SPLIT))
    train_paths, val_paths = file_paths[:split], file_paths[split:]
    train_labels, val_labels = labels[:split], labels[split:]
    print(f" Train: {len(train_paths)} images")
    print(f" Val: {len(val_paths)} images")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def parse_image(image_path, label):
        """Decode one file into a float RGB image in [0, 1] at IMG_SIZE."""
        img = tf.io.read_file(image_path)
        # NOTE(review): confirm the installed TensorFlow's decode_image can
        # decode the .webp files accepted above.
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
        img.set_shape([IMG_SIZE, IMG_SIZE, 3])
        img = tf.cast(img, tf.float32) / 255.0
        return img, label

    def augment(image, label):
        """Training-only augmentation, applied to [0, 1] RGB images.

        The colour ops run before normalisation because the saturation and
        hue adjustments go through HSV and expect RGB values in [0, 1].
        """
        # Random horizontal and vertical flips
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # Random photometric jitter
        image = tf.image.random_brightness(image, 0.15)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.05)
        # Back into the valid range after the colour jitter
        image = tf.clip_by_value(image, 0.0, 1.0)
        # Random crop: scale up slightly, then take a random IMG_SIZE window.
        image = tf.image.resize(image, [IMG_SIZE + 12, IMG_SIZE + 12])
        image = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])
        return image, label

    def normalize(image, label):
        """ImageNet mean/std normalisation (shared by train and validation)."""
        mean = tf.constant(imagenet_mean)
        std = tf.constant(imagenet_std)
        return (image - mean) / std, label

    # Create tf.data datasets. Shuffling the cheap (path, label) pairs with a
    # full-size buffer gives a true per-epoch shuffle without holding decoded
    # images in memory.
    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    train_ds = train_ds.shuffle(max(len(train_paths), 1))
    train_ds = train_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    val_ds = val_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    val_ds = val_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds
# ─── Model Building ──────────────────────────────────────────────────────────
def build_model():
    """
    Load the PlantVillage model and replace its classification head with a
    NUM_CLASSES-way softmax.

    The loaded model must contain a layer named "global_average_pooling2d";
    the new head (dropout + dense) is attached to that layer's output. Every
    loaded layer except the old "dense" head is frozen.

    Returns:
        The new, uncompiled ``keras.Model``.

    Exits the process with status 1 when MODEL_PATH does not exist.
    """

    def count_params(weights):
        # tuple() accepts both TensorShape objects and plain tuple shapes.
        return sum(int(np.prod(tuple(w.shape))) for w in weights)

    print(f"\nLoading base model from: {MODEL_PATH}")
    if not MODEL_PATH.exists():
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    base_model = keras.models.load_model(str(MODEL_PATH))
    print(f" Base model loaded: {type(base_model).__name__}")
    print(f" Input shape: {base_model.input_shape}")
    print(f" Output shape: {base_model.output_shape}")

    # Extract backbone — everything up to the GlobalAveragePooling2D.
    # The model structure is:
    # input_layer_2 → mobilenetv2_1.00_160 → global_average_pooling2d → dropout → dense(38)
    backbone_output = base_model.get_layer("global_average_pooling2d").output
    print(" Using backbone output: global_average_pooling2d")

    # Freeze all backbone layers initially (unfrozen later for fine-tuning).
    for layer in base_model.layers:
        if layer.name != "dense":  # The old head is replaced anyway
            layer.trainable = False

    # Build new classification head
    x = layers.Dropout(0.3, name="dropout_new")(backbone_output)
    x = layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        name="dense_new",
        kernel_regularizer=regularizers.l2(1e-4),
    )(x)

    # Create new model
    model = keras.Model(
        inputs=base_model.input, outputs=x, name="plant-disease-classifier-v2"
    )
    print(f" New model input: {model.input_shape}")
    print(f" New model output: {model.output_shape} ({NUM_CLASSES} classes)")

    # Count ALL weights of the frozen layers: after freezing, their
    # trainable_weights lists are empty and would always report 0.
    backbone_frozen = sum(
        count_params(layer.weights)
        for layer in base_model.layers
        if layer.name != "dense"
    )
    head_trainable = count_params(model.get_layer("dense_new").trainable_weights)
    print(f" Backbone frozen: {backbone_frozen:,} params (not training)")
    print(f" New head: {head_trainable:,} params (training)")
    return model
# ─── Training ────────────────────────────────────────────────────────────────
def train_head(model, train_ds, val_ds):
    """
    Stage 1: train only the new classification head.

    Every layer except "dense_new" is frozen, the model is compiled with Adam
    at LEARNING_RATE_HEAD, and it is fitted for up to EPOCHS_HEAD epochs with
    early stopping (best weights restored) and LR reduction on plateau.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """

    def count_params(weights):
        # tuple() accepts both TensorShape objects and plain tuple shapes.
        return sum(int(np.prod(tuple(w.shape))) for w in weights)

    print(f"\n{'=' * 60}")
    print("STAGE 1: Training classification head")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_HEAD}")
    print(f" Learning rate: {LEARNING_RATE_HEAD}")
    print(f" Batch size: {BATCH_SIZE}")

    # Only the new dense head is trainable in this stage.
    for layer in model.layers:
        layer.trainable = layer.name == "dense_new"

    # Verify
    trainable = count_params(model.trainable_weights)
    total = count_params(model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_HEAD),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_HEAD,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-6,
            ),
        ],
    )

    # Early stopping restores the best weights, so report the best epoch
    # rather than whatever the final epoch happened to score.
    best_val_acc = max(history.history["val_accuracy"])
    print(f"\n Stage 1 complete! Best val accuracy: {best_val_acc:.4f}")
    return history
def train_finetune(model, train_ds, val_ds):
    """
    Stage 2: unfreeze the last ~25 backbone layers and fine-tune.

    The nested "mobilenetv2_1.00_160" model is re-enabled, its last 25 layers
    (except BatchNormalization layers) are made trainable, and the model is
    recompiled with Adam at LEARNING_RATE_FINETUNE and fitted for up to
    EPOCHS_FINETUNE epochs.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """

    def count_params(weights):
        # tuple() accepts both TensorShape objects and plain tuple shapes.
        return sum(int(np.prod(tuple(w.shape))) for w in weights)

    print(f"\n{'=' * 60}")
    print("STAGE 2: Fine-tuning backbone (last ~25 layers)")
    print(f"{'=' * 60}")
    print(f" Epochs: {EPOCHS_FINETUNE}")
    print(f" Learning rate: {LEARNING_RATE_FINETUNE}")

    # The backbone is a Functional model nested inside the full model.
    mobilenet_layer = model.get_layer("mobilenetv2_1.00_160")

    # Stage 1 froze the nested model itself. A container with trainable=False
    # can hide its sub-layers' weights from training, so re-enable the
    # container first and then freeze the layers that must stay fixed.
    mobilenet_layer.trainable = True

    total_backbone_layers = len(mobilenet_layer.layers)
    unfreeze_from = max(0, total_backbone_layers - 25)
    print(
        f" Backbone has {total_backbone_layers} layers, unfreezing from layer {unfreeze_from}"
    )
    for i, layer in enumerate(mobilenet_layer.layers):
        # BatchNormalization layers stay frozen so their moving statistics are
        # not overwritten while fine-tuning on a small dataset.
        is_batch_norm = isinstance(layer, layers.BatchNormalization)
        layer.trainable = i >= unfreeze_from and not is_batch_norm

    # Also unfreeze the new head
    model.get_layer("dense_new").trainable = True
    model.get_layer("dropout_new").trainable = True

    trainable = count_params(model.trainable_weights)
    total = count_params(model.weights)
    print(f" Trainable params: {trainable:,} / {total:,} total")

    # Recompile so the changed trainable flags take effect.
    model.compile(
        optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_FINETUNE),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS_FINETUNE,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_accuracy",
                patience=3,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-7,
            ),
        ],
    )

    # Early stopping restores the best weights, so report the best epoch.
    best_val_acc = max(history.history["val_accuracy"])
    print(f"\n Stage 2 complete! Best val accuracy: {best_val_acc:.4f}")
    return history
# ─── Export ──────────────────────────────────────────────────────────────────
def export_models(model, class_mapping, index_to_class):
    """
    Write the trained model and its metadata to OUTPUT_DIR.

    Three artefacts are produced: ``model-finetuned.keras``, a
    ``class_mapping.json`` describing the class indices, and a TF.js export
    under TFJS_OUTPUT. A failed TF.js export is reported together with the
    manual converter command instead of aborting.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("EXPORTING")
    print(f"{banner}")

    # 1. Save as .keras (for future retraining)
    keras_path = OUTPUT_DIR / "model-finetuned.keras"
    model.save(str(keras_path))
    print(f" ✓ Saved .keras: {keras_path}")

    # 2. Save class mapping alongside the model
    mapping_path = OUTPUT_DIR / "class_mapping.json"
    metadata = {
        "index_to_class": index_to_class,
        "class_to_index": class_mapping,
        "num_classes": NUM_CLASSES,
        "input_size": IMG_SIZE,
    }
    with open(mapping_path, "w") as out:
        json.dump(metadata, out, indent=2)
    print(f" ✓ Saved class mapping: {mapping_path}")

    # 3. Export to TF.js format, replacing any previous export
    tfjs_path = str(TFJS_OUTPUT)
    if TFJS_OUTPUT.exists():
        shutil.rmtree(tfjs_path)
    try:
        import tensorflowjs as tfjs

        tfjs.converters.save_keras_model(model, tfjs_path)
        print(f" ✓ Saved TF.js: {tfjs_path}/")
        for artifact in sorted(TFJS_OUTPUT.iterdir()):
            size = artifact.stat().st_size
            print(f" {artifact.name:<30s} {size:>10,} bytes")
    except Exception as e:
        # Best effort: a missing or failing converter must not lose the
        # .keras file that was already saved above.
        print(f" ⚠ TF.js export failed: {e}")
        print(
            f" Run later: tensorflowjs_converter --input_format=keras {keras_path} {tfjs_path}"
        )
# ─── Cleanup Old Model Files ────────────────────────────────────────────────
def cleanup_old_model():
    """Delete a previous model.json and its group1-shard* files from OUTPUT_DIR."""
    for pattern in ("model.json", "group1-shard*"):
        for stale in OUTPUT_DIR.glob(pattern):
            print(f" Removing old: {stale.name}")
            stale.unlink()
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
    """
    Run the whole pipeline: class mapping, dataset check, data loading,
    two-stage training, and export.

    Exits with status 1 when the dataset directory is missing or contains no
    usable images.
    """
    print("=" * 60)
    print("PLANT DISEASE MODEL FINE-TUNER")
    print("=" * 60)

    # 1. Build class mapping
    print("\n[1/5] Building class mapping...")
    class_mapping, index_to_class = build_class_mapping()
    print(
        f" {len(class_mapping)} classes defined (0=healthy, 1-93=diseases, 94=unknown)"
    )

    # 2. Verify dataset
    print("\n[2/5] Verifying dataset...")
    if not DATASET_DIR.exists():
        print(f" ERROR: Dataset not found at {DATASET_DIR}")
        print(" Run the scraper first: npx tsx scripts/scrape-training-dataset.ts")
        sys.exit(1)
    available, total = verify_dataset(class_mapping)
    print_dataset_summary(available, total)

    # Bail out before loading data or the model: with no images there is
    # nothing to build a pipeline from.
    if total == 0:
        print("\n Skipping training — no dataset available.")
        sys.exit(1)
    if total < 100:
        # Warning only; no answer is read, so do not print a y/n prompt.
        print(f" WARNING: Only {total} images. Consider scraping more data.")
        print(" Continuing anyway.")

    # 3. Load dataset
    print("\n[3/5] Loading and augmenting dataset...")
    train_ds, val_ds = load_dataset(class_mapping, available)

    # 4. Build and train model
    print("\n[4/5] Building model...")
    model = build_model()
    model.summary()
    train_head(model, train_ds, val_ds)
    train_finetune(model, train_ds, val_ds)

    # 5. Export. The old model.json/shards are removed only after the new
    # artefacts are written, so a failed save leaves the deployed model intact.
    print("\n[5/5] Exporting models...")
    export_models(model, class_mapping, index_to_class)
    cleanup_old_model()

    # ── Final Summary ────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print("DONE! Model fine-tuned and exported.")
    print(f"{'=' * 60}")
    print("\nFiles created:")
    print(f" {OUTPUT_DIR / 'model-finetuned.keras'}")
    print(f" {OUTPUT_DIR / 'class_mapping.json'}")
    print(f" {TFJS_OUTPUT / 'model.json'}")
    print("\nTo update your app:")
    print(" 1. Replace model files:")
    print(f" cp {TFJS_OUTPUT / 'model.json'} {OUTPUT_DIR / 'model.json'}")
    # Joining a Path with "/" yields the filesystem root, so build the
    # destination directory string explicitly.
    print(f" cp {TFJS_OUTPUT / 'group1-shard*'} {OUTPUT_DIR}/")
    print(" 2. Restart the dev server")
    print(" 3. Test with: POST /api/identify")
    print("\nNote: Update labels.ts if the class order changed.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,660 @@
#!/usr/bin/env node
/**
* scrape-training-dataset.ts
*
* Collects a training dataset for fine-tuning by scraping DuckDuckGo image search.
*
* Targets:
* - 200 images per disease class (93 diseases)
* - 400 images for the "healthy" class
* - Full resolution images stored in data/dataset/{class_id}/
*
* DuckDuckGo approach (no API key needed):
* 1. Fetch the main search page to extract a vqd (query) token
* 2. Use the vqd token to paginate through image results
* 3. Download each image to the dataset directory
*
* Usage: cd apps/web && npx tsx scripts/scrape-training-dataset.ts
*
* Progress is tracked in data/dataset/.progress.json — interrupt and resume safely.
*/
import "dotenv/config";
import { readFileSync, writeFileSync, existsSync, mkdirSync, readdirSync, statSync } from "fs";
import { resolve, extname, join } from "path";
// ─── Config ─────────────────────────────────────────────────────────────────
// Input data files and output locations, resolved relative to this script.
const DISEASES_JSON = resolve(__dirname, "../src/data/diseases.json");
const PLANTS_JSON = resolve(__dirname, "../src/data/plants.json");
const DATASET_DIR = resolve(__dirname, "../data/dataset");
// Resume state written by saveProgress() and read by loadProgress().
const PROGRESS_FILE = resolve(DATASET_DIR, ".progress.json");
/** Target images per disease class */
const TARGET_PER_DISEASE = 200;
/** Target images for the "healthy" class (2× normal) */
const TARGET_HEALTHY = 400;
/** Delay between DuckDuckGo search API calls (ms) */
const SEARCH_DELAY = 1500;
/** Delay between image downloads (ms) */
const DOWNLOAD_DELAY = 300;
/** Max concurrent downloads */
const CONCURRENT_DOWNLOADS = 5;
/** Minimum image size in bytes to accept (reject tiny placeholders) */
const MIN_IMAGE_SIZE = 10_000; // 10KB
/** Maximum image size in bytes */
const MAX_IMAGE_SIZE = 10 * 1024 * 1024; // 10MB
/** Allowed image content types */
// NOTE(review): "image/gif" is accepted here but ".gif" is not in
// ALLOWED_EXTENSIONS, so downloadImage() stores GIF responses under ".jpg".
const ALLOWED_CONTENT_TYPES = ["image/jpeg", "image/jpg", "image/png", "image/webp", "image/gif"];
/** Allowed file extensions */
const ALLOWED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"];
/** User agent for requests */
const UA =
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36";
// ─── Types ──────────────────────────────────────────────────────────────────
/** One record of diseases.json; only the fields this script reads are typed. */
interface DiseaseSeed {
id: string;
plantId: string;
name: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[key: string]: any;
}
/** One record of plants.json; only the fields this script reads are typed. */
interface PlantSeed {
id: string;
commonName: string;
scientificName: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[key: string]: any;
}
/** One entry of the `results` array returned by the image search endpoint. */
interface DuckDuckGoImageResult {
image: string;
title: string;
url: string;
thumbnail: string;
height: number;
width: number;
}
/** Per-class resume state persisted in the progress file. */
interface ClassProgress {
// Successful downloads so far; compared against the target and passed to
// downloadBatch() as the starting file index.
count: number;
// Running total of successful downloads (incremented together with `count`).
downloaded: number;
// Running total of failed downloads.
failed: number;
// Initialised to 0; not updated anywhere in the visible code.
skipped: number;
/** URLs we've already seen (to avoid duplicates) */
seenUrls: string[];
/** Whether we've exhausted search results */
exhausted: boolean;
}
/** Root of the progress file: a timestamp plus per-class state keyed by class id. */
interface Progress {
lastUpdated: string;
classes: Record<string, ClassProgress>;
}
/** Class ID for healthy plants */
const HEALTHY_CLASS = "healthy";
// ─── DuckDuckGo API ─────────────────────────────────────────────────────────
/**
 * Fetch the DuckDuckGo search page for `query` and pull the `vqd` token out
 * of the returned HTML. The token is required by the image results endpoint.
 *
 * @throws Error when the request fails or no token is found in the page.
 */
async function getVqdToken(query: string): Promise<string> {
  const searchUrl = `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`;
  const response = await fetch(searchUrl, {
    headers: { "User-Agent": UA, Accept: "text/html" },
    signal: AbortSignal.timeout(15_000),
  });
  if (!response.ok) {
    throw new Error(`Failed to get vqd token: ${response.status}`);
  }

  const page = await response.text();
  // Matches vqd='<token>', vqd="<token>" and the `vqd: "<token>"` variants.
  const found = /vqd['"]?\s*[:=]\s*['"]([a-f0-9-]+)['"]/.exec(page);
  if (found === null) {
    throw new Error(`Could not extract vqd token from DuckDuckGo response for "${query}"`);
  }
  return found[1];
}
/**
 * Fetch one page of DuckDuckGo image results.
 *
 * A 429 response is retried a bounded number of times with a growing wait;
 * a 403 response is treated as an expired token and yields an empty page.
 *
 * @param query Search text.
 * @param vqd   Token obtained from getVqdToken() for the same query.
 * @param page  Page number sent to the endpoint.
 * @returns The `results` array of the response (empty when absent).
 * @throws Error on any other non-OK status, or when rate limiting persists.
 */
async function searchImagesDuckDuckGo(
  query: string,
  vqd: string,
  page: number,
): Promise<DuckDuckGoImageResult[]> {
  // Upper bound on 429 retries; the original retried recursively forever.
  const MAX_RATE_LIMIT_RETRIES = 5;
  // NOTE(review): confirm that `p` is the pagination parameter of i.js. If the
  // endpoint pages by an offset parameter instead, every call returns the
  // first page.
  const url = `https://duckduckgo.com/i.js?q=${encodeURIComponent(query)}&vqd=${vqd}&o=json&p=${page}&f=,,,`;

  for (let attempt = 0; ; attempt++) {
    const res = await fetch(url, {
      headers: {
        "User-Agent": UA,
        Accept: "application/json",
        Referer: `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`,
      },
      signal: AbortSignal.timeout(15_000),
    });

    if (res.ok) {
      const data = (await res.json()) as { results: DuckDuckGoImageResult[] };
      return data.results ?? [];
    }
    if (res.status === 429 && attempt < MAX_RATE_LIMIT_RETRIES) {
      const waitMs = 10_000 * (attempt + 1);
      console.warn(` ⚠ Rate limited (429). Waiting ${waitMs / 1000}s...`);
      await sleep(waitMs);
      continue;
    }
    if (res.status === 403) {
      console.warn(" ⚠ Forbidden (403). Token may have expired.");
      return []; // Token expired — no more pages
    }
    throw new Error(`DuckDuckGo search failed: ${res.status}`);
  }
}
/**
 * Search DuckDuckGo images for `query`, paging until `target` new URLs have
 * been collected or the search stops producing new results.
 *
 * @param query    Search text.
 * @param target   Maximum number of new URLs to return.
 * @param seenUrls URLs already known; every URL returned is added to it.
 * @returns The new URLs, plus `exhausted: true` when the token could not be
 *          fetched or several consecutive pages produced nothing new.
 */
async function collectImages(
  query: string,
  target: number,
  seenUrls: Set<string>,
): Promise<{ urls: string[]; exhausted: boolean }> {
  // Consecutive pages without a single new URL before giving up. Without this
  // bound the loop never ends once the endpoint keeps repeating seen results.
  const MAX_UNPRODUCTIVE_PAGES = 3;

  const results: string[] = [];
  let page = 1;
  let exhausted = false;
  let unproductivePages = 0;

  // Get vqd token
  let vqd: string;
  try {
    vqd = await getVqdToken(query);
  } catch (err) {
    console.warn(` ⚠ Failed to get vqd token: ${err instanceof Error ? err.message : "unknown"}`);
    return { urls: [], exhausted: true };
  }

  while (results.length < target) {
    await sleep(SEARCH_DELAY);

    let pageResults: DuckDuckGoImageResult[];
    try {
      pageResults = await searchImagesDuckDuckGo(query, vqd, page);
    } catch (err) {
      console.warn(` ⚠ Search error: ${err instanceof Error ? err.message : "unknown"}`);
      break;
    }

    let newCount = 0;
    for (const r of pageResults) {
      if (results.length >= target) break;
      const imgUrl = r.image || r.url;
      // Skip missing URLs and ones we've already seen
      if (!imgUrl || seenUrls.has(imgUrl)) continue;
      // Skip anything that is not a parseable http(s) URL; an unparseable
      // string would otherwise throw and abort the whole run.
      let protocol: string;
      try {
        protocol = new URL(imgUrl).protocol;
      } catch {
        continue;
      }
      if (protocol !== "http:" && protocol !== "https:") continue;

      seenUrls.add(imgUrl);
      results.push(imgUrl);
      newCount++;
    }

    if (newCount === 0) {
      // Empty page, or a page holding only URLs we already have.
      unproductivePages++;
      if (unproductivePages >= MAX_UNPRODUCTIVE_PAGES) {
        exhausted = true;
        break;
      }
    } else {
      unproductivePages = 0;
    }
    page++;
  }

  return { urls: results.slice(0, target), exhausted };
}
// ─── Image Download ─────────────────────────────────────────────────────────
/**
 * Download one image and write it next to `destPath`, with the file
 * extension replaced by one derived from the URL or the content type.
 *
 * @returns true when the file was written; false on any HTTP, validation,
 *          network or filesystem failure.
 */
async function downloadImage(url: string, destPath: string): Promise<boolean> {
  try {
    const response = await fetch(url, {
      headers: { "User-Agent": UA, Accept: "image/webp,image/png,image/jpeg" },
      signal: AbortSignal.timeout(15_000),
    });
    if (!response.ok) return false;

    const contentType = response.headers.get("content-type") || "";
    const declaredSize = parseInt(response.headers.get("content-length") || "0", 10);

    // Reject non-image responses
    const typeAccepted = ALLOWED_CONTENT_TYPES.some((allowed) => contentType.includes(allowed));
    if (!typeAccepted) return false;

    // Reject on the declared size first (a declared size of 0 means "unknown")
    const declaredTooSmall = declaredSize > 0 && declaredSize < MIN_IMAGE_SIZE;
    if (declaredTooSmall) return false;
    if (declaredSize > MAX_IMAGE_SIZE) return false;

    // Then re-check against the bytes actually received
    const payload = Buffer.from(await response.arrayBuffer());
    if (payload.length < MIN_IMAGE_SIZE || payload.length > MAX_IMAGE_SIZE) return false;

    // Keep the URL's extension when it is an allowed one; otherwise derive
    // it from the content type, defaulting to .jpg.
    let ext = extname(new URL(url).pathname).toLowerCase();
    if (!ALLOWED_EXTENSIONS.includes(ext)) {
      const byContentType: Array<[string, string]> = [
        ["jpeg", ".jpg"],
        ["jpg", ".jpg"],
        ["png", ".png"],
        ["webp", ".webp"],
      ];
      const hit = byContentType.find(([needle]) => contentType.includes(needle));
      ext = hit ? hit[1] : ".jpg";
    }

    writeFileSync(destPath.replace(/\.\w+$/, ext), payload);
    return true;
  } catch {
    return false;
  }
}
/**
 * Download `urls` into `classDir` in chunks of CONCURRENT_DOWNLOADS, naming
 * the files img_NNNN.<ext> with a unique, increasing index.
 *
 * @param urls       Image URLs to fetch.
 * @param classDir   Destination directory (created when missing).
 * @param startIndex Lowest file index to use; raised automatically when the
 *                   directory already holds files with higher indices.
 * @returns Counts of successful and failed downloads, and the next free index.
 */
async function downloadBatch(
  urls: string[],
  classDir: string,
  startIndex: number,
): Promise<{ downloaded: number; failed: number; lastIndex: number }> {
  let downloaded = 0;
  let failed = 0;

  // The directory must exist, otherwise every write fails silently.
  mkdirSync(classDir, { recursive: true });

  // Never reuse an index that is already on disk: failed downloads leave gaps,
  // so the number of files can be lower than the highest index in use.
  let index = startIndex;
  for (const name of readdirSync(classDir)) {
    const existing = /^img_(\d+)\./.exec(name);
    if (existing) index = Math.max(index, Number(existing[1]) + 1);
  }

  // Process in chunks to control concurrency
  for (let i = 0; i < urls.length; i += CONCURRENT_DOWNLOADS) {
    const chunk = urls.slice(i, i + CONCURRENT_DOWNLOADS);
    const results = await Promise.all(
      chunk.map(async (url) => {
        // Claim the index synchronously, before the first await. Incrementing
        // after the await gave every download in the chunk the same file name.
        const fileIndex = index++;
        const paddedIndex = String(fileIndex).padStart(4, "0");
        const destPath = resolve(classDir, `img_${paddedIndex}.jpg`);
        const success = await downloadImage(url, destPath);
        await sleep(DOWNLOAD_DELAY);
        return { success, index: fileIndex };
      }),
    );
    for (const r of results) {
      if (r.success) downloaded++;
      else failed++;
    }
  }

  return { downloaded, failed, lastIndex: index };
}
// ─── Progress Tracking ──────────────────────────────────────────────────────
/** Read the saved progress file, or return a fresh empty state when none exists. */
function loadProgress(): Progress {
  if (existsSync(PROGRESS_FILE)) {
    const raw = readFileSync(PROGRESS_FILE, "utf-8");
    return JSON.parse(raw) as Progress;
  }
  return { lastUpdated: new Date().toISOString(), classes: {} };
}
/** Stamp `progress` with the current time and write it to the progress file. */
function saveProgress(progress: Progress): void {
  progress.lastUpdated = new Date().toISOString();
  const serialized = JSON.stringify(progress, null, 2);
  writeFileSync(PROGRESS_FILE, serialized);
}
/**
 * Return the progress record for `classId`, creating and registering an
 * empty record in `progress.classes` the first time the class is requested.
 */
function getClassProgress(progress: Progress, classId: string): ClassProgress {
  const existing = progress.classes[classId];
  if (existing) return existing;

  const fresh: ClassProgress = {
    count: 0,
    downloaded: 0,
    failed: 0,
    skipped: 0,
    seenUrls: [],
    exhausted: false,
  };
  progress.classes[classId] = fresh;
  return fresh;
}
// ─── Search Query Building ──────────────────────────────────────────────────
/**
 * Build the four search phrases used for one disease. The plant's common
 * name is used when known; otherwise the disease's plantId stands in.
 */
function buildSearchQueries(disease: DiseaseSeed, plant: PlantSeed | null): string[] {
  const diseaseName = disease.name;
  const host = plant?.commonName || disease.plantId;
  const queries: string[] = [];
  queries.push(`${diseaseName} ${host} leaf disease`);
  queries.push(`${host} ${diseaseName} symptoms`);
  queries.push(`${diseaseName} plant disease`);
  queries.push(`${host} diseased leaf`);
  return queries;
}
/** Build the four "healthy plant" search phrases for one plant. */
function buildHealthyQueries(plant: PlantSeed): string[] {
  const label = plant.commonName;
  const queries: string[] = [];
  queries.push(`healthy ${label} leaf`);
  queries.push(`${label} leaf closeup`);
  queries.push(`healthy ${label} plant`);
  queries.push(`${label} foliage`);
  return queries;
}
// ─── Dataset Collection ─────────────────────────────────────────────────────
/**
 * Collect images for one class: search with each query until enough URLs are
 * found, download them into `classDir`, and persist the updated progress.
 *
 * @param classId  Key of the class in the progress file.
 * @param queries  Search phrases, tried in order.
 * @param target   Total number of images wanted for the class.
 * @param progress Progress state; updated in place and saved.
 * @param classDir Directory the images are written to.
 */
async function collectClassImages(
  classId: string,
  queries: string[],
  target: number,
  progress: Progress,
  classDir: string,
): Promise<void> {
  const cp = getClassProgress(progress, classId);
  const seenUrls = new Set(cp.seenUrls);

  if (cp.count >= target) {
    console.log(` ✓ Already have ${cp.count}/${target} images`);
    return;
  }
  if (cp.exhausted) {
    console.log(` ✓ Already exhausted search results (${cp.count}/${target} images)`);
    return;
  }

  mkdirSync(classDir, { recursive: true });

  // On a resumed run only the missing images are needed, not `target` more.
  const remaining = target - cp.count;
  const totalUrls: string[] = [];
  let exhaustedQueries = 0;

  // Search with each query until we have enough URLs
  for (const query of queries) {
    if (totalUrls.length >= remaining) break;
    console.log(` Searching: "${query}"...`);
    const result = await collectImages(query, remaining - totalUrls.length, seenUrls);
    totalUrls.push(...result.urls);
    cp.seenUrls = Array.from(seenUrls);
    if (result.exhausted) exhaustedQueries++;
  }

  // The class is exhausted only when every query ran dry. A single dry query
  // must not block the remaining ones on a later resume.
  const exhausted = exhaustedQueries === queries.length;

  if (totalUrls.length === 0) {
    cp.exhausted = exhausted;
    saveProgress(progress);
    console.log(` ✗ No images found for "${classId}"`);
    return;
  }

  console.log(` Found ${totalUrls.length} unique image URLs. Downloading...`);

  // Download the images
  const { downloaded, failed } = await downloadBatch(totalUrls, classDir, cp.count);
  cp.count += downloaded;
  cp.downloaded += downloaded;
  cp.failed += failed;
  cp.exhausted = exhausted;
  saveProgress(progress);

  const pct = Math.round((cp.count / target) * 100);
  console.log(
    ` ${downloaded > 0 ? "✓" : "✗"} Got ${downloaded} images (${failed} failed). Total: ${cp.count}/${target} (${pct}%)`,
  );
}
// ─── Main ───────────────────────────────────────────────────────────────────
/**
 * Entry point: load the knowledge base, collect images for every disease
 * class (phase 1) and for the "healthy" class (phase 2), then print a summary.
 * Progress is saved along the way so an interrupted run can be resumed.
 */
async function main() {
  console.log("=".repeat(60));
  console.log("PLANT DISEASE DATASET COLLECTOR");
  console.log("=".repeat(60));

  // Load knowledge base
  const diseases = JSON.parse(readFileSync(DISEASES_JSON, "utf-8")) as DiseaseSeed[];
  const plants = JSON.parse(readFileSync(PLANTS_JSON, "utf-8")) as PlantSeed[];
  const plantMap = new Map<string, PlantSeed>();
  for (const p of plants) {
    plantMap.set(p.id, p);
  }

  console.log(`\nLoaded ${diseases.length} diseases, ${plants.length} plants`);
  console.log(
    `Target: ${TARGET_PER_DISEASE} images/disease (×${diseases.length} = ${diseases.length * TARGET_PER_DISEASE})`,
  );
  console.log(`Target: ${TARGET_HEALTHY} images for "healthy" class`);
  console.log(`Output: ${DATASET_DIR}/`);
  console.log("");

  // Load progress
  mkdirSync(DATASET_DIR, { recursive: true });
  const progress = loadProgress();
  const startTime = Date.now();

  // ── Phase 1: Disease classes ──────────────────────────────────────────────
  console.log("─".repeat(60));
  console.log("PHASE 1: Disease Images");
  console.log("─".repeat(60));
  for (let i = 0; i < diseases.length; i++) {
    const disease = diseases[i];
    const plant = plantMap.get(disease.plantId) ?? null;
    const classDir = resolve(DATASET_DIR, disease.id);
    const queries = buildSearchQueries(disease, plant);
    const pct = Math.round((i / diseases.length) * 100);
    console.log(`\n[${i + 1}/${diseases.length}] (${pct}%) ${disease.name} (${disease.id})`);
    await collectClassImages(disease.id, queries, TARGET_PER_DISEASE, progress, classDir);
  }

  // ── Phase 2: Healthy class ────────────────────────────────────────────────
  console.log("\n" + "─".repeat(60));
  console.log("PHASE 2: Healthy Plant Images");
  console.log("─".repeat(60));
  const healthyDir = resolve(DATASET_DIR, HEALTHY_CLASS);
  const healthyCp = getClassProgress(progress, HEALTHY_CLASS);
  const healthySeen = new Set(healthyCp.seenUrls);

  if (healthyCp.count >= TARGET_HEALTHY) {
    console.log(`\n ✓ Already have ${healthyCp.count}/${TARGET_HEALTHY} healthy images`);
  } else {
    // Without the directory every healthy download would fail on write.
    mkdirSync(healthyDir, { recursive: true });

    // Build a pool of healthy plant queries
    const allHealthyQueries: string[] = [];
    for (const plant of plants) {
      allHealthyQueries.push(...buildHealthyQueries(plant));
    }

    // Only the images still missing are needed on a resumed run.
    const remaining = TARGET_HEALTHY - healthyCp.count;
    const totalHealthyUrls: string[] = [];
    let exhaustedQueries = 0;

    // One dry query says nothing about the next plant's queries, so keep
    // going until enough URLs are found or the pool is used up.
    for (const query of allHealthyQueries) {
      if (totalHealthyUrls.length >= remaining) break;
      console.log(`\n Searching: "${query}"...`);
      const result = await collectImages(
        query,
        remaining - totalHealthyUrls.length,
        healthySeen,
      );
      totalHealthyUrls.push(...result.urls);
      if (result.exhausted) exhaustedQueries++;
    }
    const healthyExhausted = exhaustedQueries === allHealthyQueries.length;
    healthyCp.seenUrls = Array.from(healthySeen);

    if (totalHealthyUrls.length > 0) {
      console.log(`\n Found ${totalHealthyUrls.length} healthy image URLs. Downloading...`);
      const { downloaded, failed } = await downloadBatch(
        totalHealthyUrls,
        healthyDir,
        healthyCp.count,
      );
      healthyCp.count += downloaded;
      healthyCp.downloaded += downloaded;
      healthyCp.failed += failed;
      healthyCp.exhausted = healthyExhausted;
      const pct = Math.round((healthyCp.count / TARGET_HEALTHY) * 100);
      console.log(
        ` Got ${downloaded} images (${failed} failed). Total: ${healthyCp.count}/${TARGET_HEALTHY} (${pct}%)`,
      );
    } else {
      healthyCp.exhausted = healthyExhausted;
      console.log(` ✗ No healthy images found`);
    }
    saveProgress(progress);
  }

  // ── Summary ────────────────────────────────────────────────────────────────
  const elapsed = Math.round((Date.now() - startTime) / 1000);
  const mins = Math.floor(elapsed / 60);
  const secs = elapsed % 60;

  let totalDownloaded = 0;
  let totalFailed = 0;
  let totalTarget = 0;
  for (const [classId, cp] of Object.entries(progress.classes)) {
    totalDownloaded += cp.downloaded || 0;
    totalFailed += cp.failed || 0;
    totalTarget += classId === HEALTHY_CLASS ? TARGET_HEALTHY : TARGET_PER_DISEASE;
  }

  const totalSize = await getDatasetSize();
  const sizeGb = (totalSize / (1024 * 1024 * 1024)).toFixed(2);

  console.log("\n" + "=".repeat(60));
  console.log("COMPLETE");
  console.log("=".repeat(60));
  console.log(` Time: ${mins}m ${secs}s`);
  console.log(` Downloaded: ${totalDownloaded} images`);
  console.log(` Failed: ${totalFailed} images`);
  console.log(` Target: ${totalTarget} images`);
  console.log(` Dataset size: ${sizeGb} GB`);
  console.log(` Dataset location: ${DATASET_DIR}/`);
  console.log("");
  console.log("Next steps:");
  console.log(" 1. Run the fine-tuning script to train on this dataset");
  console.log(" 2. The fine-tuning script will resize to 160×160 and augment");
  console.log("=".repeat(60));
}
/**
 * Sum the on-disk size, in bytes, of every non-hidden class directory
 * directly under DATASET_DIR. Returns 0 when the dataset directory is absent.
 */
async function getDatasetSize(): Promise<number> {
  if (!existsSync(DATASET_DIR)) return 0;
  return readdirSync(DATASET_DIR, { withFileTypes: true })
    .filter((entry) => !entry.name.startsWith(".") && entry.isDirectory())
    .reduce((sum, entry) => sum + dirSize(resolve(DATASET_DIR, entry.name)), 0);
}
/**
 * Recursively total the size, in bytes, of every regular file under `dirPath`.
 *
 * This is a best-effort measurement and never throws:
 *  - if the directory itself cannot be listed, it contributes 0;
 *  - if a single file cannot be stat'ed (e.g. it was removed while we were
 *    walking, or is unreadable), only that file is skipped and the remaining
 *    entries are still counted.
 *
 * Entries that are neither regular files nor directories (symlinks, sockets,
 * ...) are ignored.
 *
 * @param dirPath - Directory to measure.
 * @returns Sum of file sizes in bytes (0 for a missing/unreadable directory).
 */
function dirSize(dirPath: string): number {
  let entries;
  try {
    entries = readdirSync(dirPath, { withFileTypes: true });
  } catch {
    // Missing or unreadable directory: nothing to count.
    return 0;
  }

  let total = 0;
  for (const entry of entries) {
    const fullPath = join(dirPath, entry.name);
    if (entry.isFile()) {
      try {
        total += statSync(fullPath).size;
      } catch {
        // Skip just this file; keep counting its siblings.
      }
    } else if (entry.isDirectory()) {
      // Recursive call handles its own errors and never throws.
      total += dirSize(fullPath);
    }
  }
  return total;
}
/**
 * Return a promise that settles (with no value) once `ms` milliseconds
 * have elapsed, as scheduled through `setTimeout`.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((wake) => {
    setTimeout(wake, ms);
  });
}
// Script entry point: run main() and, if its promise rejects, report the
// error on stderr and terminate the process with exit code 1.
main().catch(function reportFatal(failure) {
  console.error("Fatal error:", failure);
  process.exit(1);
});

View File

@@ -173,14 +173,14 @@ describe("imageToTensor", () => {
describe("tensorToBase64 / base64ToTensor", () => { describe("tensorToBase64 / base64ToTensor", () => {
it("round-trips tensor data correctly", () => { it("round-trips tensor data correctly", () => {
const imageData = createMockImageData(224, 224, 100, 150, 200); const imageData = createMockImageData(160, 160, 100, 150, 200);
const original = imageToTensor(imageData); const original = imageToTensor(imageData);
const base64 = tensorToBase64(original); const base64 = tensorToBase64(original);
const decoded = base64ToTensor(base64); const decoded = base64ToTensor(base64);
expect(decoded.tensor.length).toBe(original.length); expect(decoded.tensor.length).toBe(original.length);
expect(decoded.shape).toEqual([3, 224, 224]); expect(decoded.shape).toEqual([3, 160, 160]);
// Check a few values match // Check a few values match
for (let i = 0; i < 10; i++) { for (let i = 0; i < 10; i++) {
@@ -197,9 +197,9 @@ describe("tensorToBase64 / base64ToTensor", () => {
}); });
describe("getTensorShape", () => { describe("getTensorShape", () => {
it("returns [1, 3, 224, 224] by default", () => { it("returns [1, 3, 160, 160] by default", () => {
const shape = getTensorShape(); const shape = getTensorShape();
expect(shape).toEqual([1, 3, 224, 224]); expect(shape).toEqual([1, 3, 160, 160]);
}); });
it("returns NCHW layout", () => { it("returns NCHW layout", () => {
@@ -207,8 +207,8 @@ describe("getTensorShape", () => {
expect(shape.length).toBe(4); expect(shape.length).toBe(4);
expect(shape[0]).toBe(1); // batch expect(shape[0]).toBe(1); // batch
expect(shape[1]).toBe(3); // channels expect(shape[1]).toBe(3); // channels
expect(shape[2]).toBe(224); // height expect(shape[2]).toBe(160); // height (model input size)
expect(shape[3]).toBe(224); // width expect(shape[3]).toBe(160); // width (model input size)
}); });
}); });

View File

@@ -17,7 +17,7 @@
const DEFAULT_MODEL_SIZE = 160; const DEFAULT_MODEL_SIZE = 160;
const DEFAULT_MEAN = [0.485, 0.456, 0.406] as const; // ImageNet RGB means const DEFAULT_MEAN = [0.485, 0.456, 0.406] as const; // ImageNet RGB means
const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds
function getConfig(): { function getConfig(): {
size: number; size: number;
@@ -48,12 +48,7 @@ export const MAX_FILE_SIZE = 10 * 1024 * 1024;
export const MIN_DIMENSION = 150; export const MIN_DIMENSION = 150;
/** Allowed MIME types */ /** Allowed MIME types */
export const ALLOWED_MIME_TYPES = [ export const ALLOWED_MIME_TYPES = ["image/png", "image/jpeg", "image/jpg", "image/webp"] as const;
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
] as const;
export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number]; export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number];
@@ -66,9 +61,7 @@ export const MAX_UPLOADS = 100;
* Validate that a file is an acceptable image for upload. * Validate that a file is an acceptable image for upload.
* Returns `{ ok: true }` or `{ ok: false, error: string }`. * Returns `{ ok: true }` or `{ ok: false, error: string }`.
*/ */
export function validateImageFile(file: File): export function validateImageFile(file: File): { ok: true } | { ok: false; error: string } {
| { ok: true }
| { ok: false; error: string } {
// MIME type check // MIME type check
if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) { if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) {
return { return {
@@ -127,10 +120,7 @@ export function validateImageDimensions(
* @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224. * @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224.
* @returns ImageData at exactly `size × size` * @returns ImageData at exactly `size × size`
*/ */
export async function resizeImage( export async function resizeImage(file: File, size: number = getConfig().size): Promise<ImageData> {
file: File,
size: number = getConfig().size,
): Promise<ImageData> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const img = new Image(); const img = new Image();
img.onload = () => { img.onload = () => {
@@ -193,8 +183,7 @@ export function imageToTensor(imageData: ImageData): Float32Array {
// Normalize with ImageNet mean/std // Normalize with ImageNet mean/std
for (let c = 0; c < 3; c++) { for (let c = 0; c < 3; c++) {
const channel = const channel = c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
const m = mean[c]; const m = mean[c];
const s = std[c]; const s = std[c];
for (let i = 0; i < totalPixels; i++) { for (let i = 0; i < totalPixels; i++) {
@@ -253,5 +242,3 @@ export function base64ToTensor(base64: string): {
shape: envelope.shape as [number, number, number], shape: envelope.shape as [number, number, number],
}; };
} }

View File

@@ -97,7 +97,7 @@ describe("createZeroTensor", () => {
it("all values are zero", () => { it("all values are zero", () => {
const tensor = createZeroTensor(); const tensor = createZeroTensor();
expect(tensor.every(v => v === 0)).toBe(true); expect(tensor.every((v) => v === 0)).toBe(true);
}); });
}); });
@@ -114,12 +114,12 @@ describe("createRandomTensor", () => {
it("all values are finite", () => { it("all values are finite", () => {
const tensor = createRandomTensor(); const tensor = createRandomTensor();
expect(tensor.every(v => Number.isFinite(v))).toBe(true); expect(tensor.every((v) => Number.isFinite(v))).toBe(true);
}); });
it("produces varied values", () => { it("produces varied values", () => {
const tensor = createRandomTensor(); const tensor = createRandomTensor();
const uniqueValues = new Set(tensor.map(v => v.toFixed(4))); const uniqueValues = new Set(tensor.map((v) => v.toFixed(4)));
expect(uniqueValues.size).toBeGreaterThan(100); expect(uniqueValues.size).toBeGreaterThan(100);
}); });
@@ -172,7 +172,7 @@ describe("runInference", () => {
const result = await runInference(tensor); const result = await runInference(tensor);
for (let i = 0; i < result.predictions.length - 1; i++) { for (let i = 0; i < result.predictions.length - 1; i++) {
expect(result.predictions[i].probability).toBeGreaterThanOrEqual( expect(result.predictions[i].probability).toBeGreaterThanOrEqual(
result.predictions[i + 1].probability result.predictions[i + 1].probability,
); );
} }
}, 10000); }, 10000);

View File

@@ -69,9 +69,7 @@ export async function runInference(
*/ */
export function validateInput(tensor: Float32Array): void { export function validateInput(tensor: Float32Array): void {
if (!(tensor instanceof Float32Array)) { if (!(tensor instanceof Float32Array)) {
throw new Error( throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
`Expected Float32Array input, got ${typeof tensor}`,
);
} }
if (tensor.length !== INPUT_SIZE) { if (tensor.length !== INPUT_SIZE) {
@@ -84,9 +82,7 @@ export function validateInput(tensor: Float32Array): void {
// Check for NaN/Infinity values // Check for NaN/Infinity values
for (let i = 0; i < tensor.length; i++) { for (let i = 0; i < tensor.length; i++) {
if (!Number.isFinite(tensor[i])) { if (!Number.isFinite(tensor[i])) {
throw new Error( throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
`Tensor contains non-finite value at index ${i}: ${tensor[i]}`,
);
} }
} }
} }

View File

@@ -1,17 +1,21 @@
/** /**
* Unit tests for lib/ml/labels.ts * Unit tests for lib/ml/labels.ts
* *
* Tests: * The model has 38 PlantVillage classes. Some map to the app's
* - INDEX_TO_DISEASE_ID maps index 0 to "healthy" * knowledge base disease IDs, others map to "unknown".
* - INDEX_TO_DISEASE_ID maps last index to "unknown" *
* - INDEX_TO_DISEASE_ID maps intermediate indices to disease IDs * Known mappings:
* - DISEASE_ID_TO_INDEX is inverse of INDEX_TO_DISEASE_ID * - indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37 → "healthy"
* - getDiseaseIdForIndex returns "unknown" for out-of-range * - index 20 (Potato___Early_blight) → "early-blight"
* - getIndexForDiseaseId returns -1 for unknown ID * - index 21 (Potato___Late_blight) → "late-blight"
* - isRealDisease correctly classifies healthy/unknown vs real diseases * - index 25 (Squash___Powdery_mildew) → "squash-powdery-mildew"
* - getAllDiseaseIds returns all disease IDs from knowledge base * - index 26 (Strawberry___Leaf_scorch) → "strawberry-leaf-scorch"
* - NUM_CLASSES equals expected count (diseases + healthy + unknown) * - index 28 (Tomato___Bacterial_spot) → "bacterial-leaf-spot-tomato"
* - Bidirectional mapping consistency * - index 29 (Tomato___Early_blight) → "early-blight" (duplicate)
* - index 30 (Tomato___Late_blight) → "late-blight" (duplicate)
* - index 32 (Tomato___Septoria_leaf_spot) → "septoria-leaf-spot"
* - index 37 (Tomato___healthy) → "healthy"
* - all others → "unknown"
*/ */
import { describe, it, expect } from "vitest"; import { describe, it, expect } from "vitest";
@@ -23,143 +27,105 @@ import {
isRealDisease, isRealDisease,
getAllDiseaseIds, getAllDiseaseIds,
NUM_CLASSES, NUM_CLASSES,
HEALTHY_INDEX, getPlantVillageClassName,
FIRST_DISEASE_INDEX,
UNKNOWN_INDEX,
} from "@/lib/ml/labels"; } from "@/lib/ml/labels";
import rawDiseases from "@/data/diseases.json";
import type { Disease } from "@/lib/types";
const diseases: Disease[] = rawDiseases as Disease[];
describe("Constants", () => { describe("Constants", () => {
it("HEALTHY_INDEX is 0", () => { it("NUM_CLASSES is 38 (PlantVillage)", () => {
expect(HEALTHY_INDEX).toBe(0); expect(NUM_CLASSES).toBe(38);
}); });
it("FIRST_DISEASE_INDEX is 1", () => { it("all 38 indices are mapped", () => {
expect(FIRST_DISEASE_INDEX).toBe(1); const keys = Object.keys(INDEX_TO_DISEASE_ID).map(Number);
}); expect(keys.length).toBe(38);
for (let i = 0; i < 38; i++) {
it("UNKNOWN_INDEX is 1 + number of diseases", () => { expect(keys).toContain(i);
expect(UNKNOWN_INDEX).toBe(1 + diseases.length); }
});
it("NUM_CLASSES is UNKNOWN_INDEX + 1", () => {
expect(NUM_CLASSES).toBe(UNKNOWN_INDEX + 1);
});
it("NUM_CLASSES equals diseases.length + 2 (healthy + unknown)", () => {
expect(NUM_CLASSES).toBe(diseases.length + 2);
}); });
}); });
describe("INDEX_TO_DISEASE_ID", () => { describe("INDEX_TO_DISEASE_ID — healthy indices", () => {
it("maps index 0 to 'healthy'", () => { const healthyIndices = [3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37];
expect(INDEX_TO_DISEASE_ID[0]).toBe("healthy");
});
it("maps last index to 'unknown'", () => { for (const idx of healthyIndices) {
expect(INDEX_TO_DISEASE_ID[NUM_CLASSES - 1]).toBe("unknown"); it(`index ${idx} maps to "healthy"`, () => {
}); expect(INDEX_TO_DISEASE_ID[idx]).toBe("healthy");
});
}
});
it("maps intermediate indices to disease IDs", () => { describe("INDEX_TO_DISEASE_ID — known disease mappings", () => {
// Index 1 should be the first disease const cases: Array<{ index: number; expected: string; name: string }> = [
expect(INDEX_TO_DISEASE_ID[1]).toBe(diseases[0].id); { index: 20, expected: "early-blight", name: "Potato___Early_blight" },
// Index 2 should be the second disease { index: 21, expected: "late-blight", name: "Potato___Late_blight" },
expect(INDEX_TO_DISEASE_ID[2]).toBe(diseases[1].id); { index: 25, expected: "squash-powdery-mildew", name: "Squash___Powdery_mildew" },
// Last disease index { index: 26, expected: "strawberry-leaf-scorch", name: "Strawberry___Leaf_scorch" },
expect(INDEX_TO_DISEASE_ID[diseases.length]).toBe(diseases[diseases.length - 1].id); { index: 28, expected: "bacterial-leaf-spot-tomato", name: "Tomato___Bacterial_spot" },
}); { index: 29, expected: "early-blight", name: "Tomato___Early_blight" },
{ index: 30, expected: "late-blight", name: "Tomato___Late_blight" },
{ index: 32, expected: "septoria-leaf-spot", name: "Tomato___Septoria_leaf_spot" },
];
it("has exactly NUM_CLASSES entries", () => { for (const { index, expected, name } of cases) {
const keys = Object.keys(INDEX_TO_DISEASE_ID); it(`index ${index} (${name}) maps to "${expected}"`, () => {
expect(keys.length).toBe(NUM_CLASSES); expect(INDEX_TO_DISEASE_ID[index]).toBe(expected);
}); });
}
});
it("all mapped IDs are valid strings", () => { describe("INDEX_TO_DISEASE_ID — unknown (unmapped) indices", () => {
for (const id of Object.values(INDEX_TO_DISEASE_ID)) { const unknownIndices = [0, 1, 2, 5, 7, 8, 9, 11, 12, 13, 15, 16, 18, 31, 33, 34, 35, 36];
expect(typeof id).toBe("string");
expect(id.length).toBeGreaterThan(0); for (const idx of unknownIndices) {
} it(`index ${idx} maps to "unknown"`, () => {
}); expect(INDEX_TO_DISEASE_ID[idx]).toBe("unknown");
});
}
}); });
describe("DISEASE_ID_TO_INDEX", () => { describe("DISEASE_ID_TO_INDEX", () => {
it("maps 'healthy' to index 0", () => { it("maps 'early-blight' to first occurrence (index 20)", () => {
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(0); expect(DISEASE_ID_TO_INDEX["early-blight"]).toBe(20);
}); });
it("maps 'unknown' to last index", () => { it("maps 'late-blight' to first occurrence (index 21)", () => {
expect(DISEASE_ID_TO_INDEX["unknown"]).toBe(NUM_CLASSES - 1); expect(DISEASE_ID_TO_INDEX["late-blight"]).toBe(21);
}); });
it("maps disease IDs to correct indices", () => { it("maps 'septoria-leaf-spot' to index 32", () => {
for (let i = 0; i < diseases.length; i++) { expect(DISEASE_ID_TO_INDEX["septoria-leaf-spot"]).toBe(32);
expect(DISEASE_ID_TO_INDEX[diseases[i].id]).toBe(FIRST_DISEASE_INDEX + i);
}
}); });
it("has exactly NUM_CLASSES entries", () => { it("maps 'healthy' to index 3 (first healthy index)", () => {
const keys = Object.keys(DISEASE_ID_TO_INDEX); expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(3);
expect(keys.length).toBe(NUM_CLASSES);
}); });
}); });
describe("Bidirectional mapping", () => { describe("Bidirectional mapping", () => {
it("INDEX_TO_DISEASE_ID and DISEASE_ID_TO_INDEX are inverses", () => { it("every index round-trips correctly", () => {
for (const [idxStr, id] of Object.entries(INDEX_TO_DISEASE_ID)) {
const idx = parseInt(idxStr);
expect(DISEASE_ID_TO_INDEX[id]).toBe(idx);
}
});
it("round-trips for all disease IDs", () => {
for (const [id, idx] of Object.entries(DISEASE_ID_TO_INDEX)) {
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
}
});
it("round-trips for all indices", () => {
for (let i = 0; i < NUM_CLASSES; i++) { for (let i = 0; i < NUM_CLASSES; i++) {
const id = INDEX_TO_DISEASE_ID[i]; const id = INDEX_TO_DISEASE_ID[i];
expect(DISEASE_ID_TO_INDEX[id]).toBe(i); const idx = DISEASE_ID_TO_INDEX[id];
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
} }
}); });
}); });
describe("getDiseaseIdForIndex", () => { describe("getDiseaseIdForIndex", () => {
it("returns 'healthy' for index 0", () => {
expect(getDiseaseIdForIndex(0)).toBe("healthy");
});
it("returns disease ID for valid disease index", () => {
expect(getDiseaseIdForIndex(1)).toBe(diseases[0].id);
});
it("returns 'unknown' for out-of-range positive index", () => { it("returns 'unknown' for out-of-range positive index", () => {
expect(getDiseaseIdForIndex(1000)).toBe("unknown"); expect(getDiseaseIdForIndex(100)).toBe("unknown");
}); });
it("returns 'unknown' for negative index", () => { it("returns 'unknown' for negative index", () => {
expect(getDiseaseIdForIndex(-1)).toBe("unknown"); expect(getDiseaseIdForIndex(-1)).toBe("unknown");
}); });
it("returns 'unknown' for index past NUM_CLASSES", () => { it("returns correct ID for valid index", () => {
expect(getDiseaseIdForIndex(NUM_CLASSES + 10)).toBe("unknown"); expect(getDiseaseIdForIndex(20)).toBe("early-blight");
}); });
}); });
describe("getIndexForDiseaseId", () => { describe("getIndexForDiseaseId", () => {
it("returns 0 for 'healthy'", () => {
expect(getIndexForDiseaseId("healthy")).toBe(0);
});
it("returns correct index for known disease", () => {
const idx = getIndexForDiseaseId(diseases[0].id);
expect(idx).toBe(1);
});
it("returns -1 for unknown disease ID", () => { it("returns -1 for unknown disease ID", () => {
expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1); expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1);
}); });
@@ -169,9 +135,7 @@ describe("getIndexForDiseaseId", () => {
}); });
it("is case-insensitive", () => { it("is case-insensitive", () => {
const lowerIdx = getIndexForDiseaseId(diseases[0].id); expect(getIndexForDiseaseId("EARLY-BLIGHT")).toBe(20);
const upperIdx = getIndexForDiseaseId(diseases[0].id.toUpperCase());
expect(upperIdx).toBe(lowerIdx);
}); });
}); });
@@ -184,10 +148,9 @@ describe("isRealDisease", () => {
expect(isRealDisease("unknown")).toBe(false); expect(isRealDisease("unknown")).toBe(false);
}); });
it("returns true for actual disease IDs", () => { it("returns true for known disease IDs", () => {
for (const disease of diseases) { expect(isRealDisease("early-blight")).toBe(true);
expect(isRealDisease(disease.id)).toBe(true); expect(isRealDisease("septoria-leaf-spot")).toBe(true);
}
}); });
it("returns true for arbitrary non-special strings", () => { it("returns true for arbitrary non-special strings", () => {
@@ -195,27 +158,37 @@ describe("isRealDisease", () => {
}); });
}); });
describe("getPlantVillageClassName", () => {
it("returns correct class name for tomato healthy", () => {
expect(getPlantVillageClassName(37)).toBe("Tomato___healthy");
});
it("returns correct class name for potato early blight", () => {
expect(getPlantVillageClassName(20)).toBe("Potato___Early_blight");
});
it("returns 'unknown' for out-of-range index", () => {
expect(getPlantVillageClassName(100)).toBe("unknown");
});
});
describe("getAllDiseaseIds", () => { describe("getAllDiseaseIds", () => {
it("returns array of all disease IDs", () => { it("returns only mapped disease IDs", () => {
const ids = getAllDiseaseIds(); const ids = getAllDiseaseIds();
expect(ids.length).toBe(diseases.length); expect(ids).toContain("early-blight");
expect(ids).toContain("late-blight");
expect(ids).toContain("squash-powdery-mildew");
expect(ids).toContain("strawberry-leaf-scorch");
expect(ids).toContain("bacterial-leaf-spot-tomato");
expect(ids).toContain("septoria-leaf-spot");
}); });
it("excludes 'healthy'", () => { it("excludes 'healthy'", () => {
const ids = getAllDiseaseIds(); expect(getAllDiseaseIds()).not.toContain("healthy");
expect(ids).not.toContain("healthy");
}); });
it("excludes 'unknown'", () => { it("excludes 'unknown'", () => {
const ids = getAllDiseaseIds(); expect(getAllDiseaseIds()).not.toContain("unknown");
expect(ids).not.toContain("unknown");
});
it("includes all disease IDs from knowledge base", () => {
const ids = getAllDiseaseIds();
for (const disease of diseases) {
expect(ids).toContain(disease.id);
}
}); });
it("has no duplicates", () => { it("has no duplicates", () => {

View File

@@ -1,74 +1,197 @@
/** /**
* Class label mapping for the plant disease classifier model. * Class label mapping for the plant disease classifier model.
* *
* Maps model output index → disease ID string. * This model is a MobileNetV2 trained on the PlantVillage dataset
* The model has classes for each disease in the knowledge base, * with 38 classes (14 crops × diseases/healthy).
* plus "healthy" and "unknown" catch-all classes.
* *
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 95 * Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 38
* (93 diseases + "healthy" + "unknown")
* *
* Index layout: * Index layout (from labels_pv_original.json):
* 0 → "healthy" * 0 → Apple___Apple_scab
* 193 → disease IDs (order matches diseases.json) * 1 → Apple___Black_rot
* 94"unknown" * 2 Apple___Cedar_apple_rust
* 3 → Apple___healthy
* 4 → Blueberry___healthy
* 5 → Cherry_(including_sour)___Powdery_mildew
* 6 → Cherry_(including_sour)___healthy
* 7 → Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
* 8 → Corn_(maize)___Common_rust_
* 9 → Corn_(maize)___Northern_Leaf_Blight
* 10 → Corn_(maize)___healthy
* 11 → Grape___Black_rot
* 12 → Grape___Esca_(Black_Measles)
* 13 → Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
* 14 → Grape___healthy
* 15 → Orange___Haunglongbing_(Citrus_greening)
* 16 → Peach___Bacterial_spot
* 17 → Peach___healthy
* 18 → Pepper,_bell___Bacterial_spot
* 19 → Pepper,_bell___healthy
* 20 → Potato___Early_blight
* 21 → Potato___Late_blight
* 22 → Potato___healthy
* 23 → Raspberry___healthy
* 24 → Soybean___healthy
* 25 → Squash___Powdery_mildew
* 26 → Strawberry___Leaf_scorch
* 27 → Strawberry___healthy
* 28 → Tomato___Bacterial_spot
* 29 → Tomato___Early_blight
* 30 → Tomato___Late_blight
* 31 → Tomato___Leaf_Mold
* 32 → Tomato___Septoria_leaf_spot
* 33 → Tomato___Spider_mites Two-spotted_spider_mite
* 34 → Tomato___Target_Spot
* 35 → Tomato___Tomato_Yellow_Leaf_Curl_Virus
* 36 → Tomato___Tomato_mosaic_virus
* 37 → Tomato___healthy
*
* Some PlantVillage classes overlap with this app's knowledge base.
* Exact class name → disease ID mappings:
* Potato___Early_blight → "early-blight"
* Potato___Late_blight → "late-blight"
* Squash___Powdery_mildew → "squash-powdery-mildew"
* Strawberry___Leaf_scorch → "strawberry-leaf-scorch"
* Tomato___Bacterial_spot → "bacterial-leaf-spot-tomato"
* Tomato___Early_blight → "early-blight"
* Tomato___Late_blight → "late-blight"
* Tomato___Septoria_leaf_spot → "septoria-leaf-spot"
* All other classes map to "unknown" and are filtered out during enrichment.
*
* After fine-tuning to the app's 93 disease classes, this file will be
* rewritten to match the new model's output layer.
*/ */
import rawDiseases from "@/data/diseases.json"; // ─── PlantVillage class names (in model output order) ────────────────────
import type { Disease } from "@/lib/types";
const diseases: Disease[] = rawDiseases as Disease[]; const PLANTVILLAGE_CLASSES: string[] = [
"Apple___Apple_scab",
"Apple___Black_rot",
"Apple___Cedar_apple_rust",
"Apple___healthy",
"Blueberry___healthy",
"Cherry_(including_sour)___Powdery_mildew",
"Cherry_(including_sour)___healthy",
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
"Corn_(maize)___Common_rust_",
"Corn_(maize)___Northern_Leaf_Blight",
"Corn_(maize)___healthy",
"Grape___Black_rot",
"Grape___Esca_(Black_Measles)",
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
"Grape___healthy",
"Orange___Haunglongbing_(Citrus_greening)",
"Peach___Bacterial_spot",
"Peach___healthy",
"Pepper,_bell___Bacterial_spot",
"Pepper,_bell___healthy",
"Potato___Early_blight",
"Potato___Late_blight",
"Potato___healthy",
"Raspberry___healthy",
"Soybean___healthy",
"Squash___Powdery_mildew",
"Strawberry___Leaf_scorch",
"Strawberry___healthy",
"Tomato___Bacterial_spot",
"Tomato___Early_blight",
"Tomato___Late_blight",
"Tomato___Leaf_Mold",
"Tomato___Septoria_leaf_spot",
"Tomato___Spider_mites Two-spotted_spider_mite",
"Tomato___Target_Spot",
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
"Tomato___Tomato_mosaic_virus",
"Tomato___healthy",
] as const;
// ─── Constants ─────────────────────────────────────────────────────────────── // ─── PlantVillage → App disease ID mapping ──────────────────────────────
/** Index for the "healthy" class */
export const HEALTHY_INDEX = 0;
/** First index for actual disease classes */
export const FIRST_DISEASE_INDEX = 1;
/** Index for the "unknown" catch-all class */
export const UNKNOWN_INDEX = 1 + diseases.length;
/** Total number of output classes */
export const NUM_CLASSES = UNKNOWN_INDEX + 1;
// ─── Index → Disease ID mapping ──────────────────────────────────────────────
/** /**
* Map from model output index to disease ID string. * Maps PlantVillage class names (in the form "Plant___Disease") to
* Index 0 = "healthy", indices 1..N = disease IDs, last = "unknown". * this app's disease IDs. Unmapped classes resolve to "unknown".
*/
function plantVillageNameToDiseaseId(pvName: string): string {
const parts = pvName.split("___");
if (parts.length !== 2) {
return "unknown";
}
const disease = parts[1];
// Detect "healthy" variants
if (disease === "healthy") {
return "healthy";
}
// Map exact PlantVillage class names to our disease IDs.
// Only map classes where we're confident the correspondence holds.
const exactMap: Record<string, string> = {
Squash___Powdery_mildew: "squash-powdery-mildew",
Strawberry___Leaf_scorch: "strawberry-leaf-scorch",
Potato___Early_blight: "early-blight",
Potato___Late_blight: "late-blight",
Tomato___Bacterial_spot: "bacterial-leaf-spot-tomato",
Tomato___Early_blight: "early-blight",
Tomato___Late_blight: "late-blight",
Tomato___Septoria_leaf_spot: "septoria-leaf-spot",
};
return exactMap[pvName] ?? "unknown";
}
// ─── Constants ──────────────────────────────────────────────────────────
/** Total number of model output classes */
export const NUM_CLASSES = PLANTVILLAGE_CLASSES.length; // 38
/** Index for the "healthy" class — multiple PV indices map to this */
export const HEALTHY_INDEX = 0; // First PV healthy class, others also map to this string
/** First disease index (unused in PV mapping, kept for compatibility) */
export const FIRST_DISEASE_INDEX = 0;
/** Index for the "unknown" catch-all — PV classes we can't map */
export const UNKNOWN_INDEX = NUM_CLASSES - 1; // 37 (Tomato___healthy maps to "healthy", not unknown)
// ─── Index → Disease ID mapping ─────────────────────────────────────────
/**
* Map from model output index to app disease ID string.
* Built dynamically from PlantVillage class names.
*/ */
export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze( export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze(
(() => { (() => {
const map: Record<number, string> = {}; const map: Record<number, string> = {};
map[HEALTHY_INDEX] = "healthy"; for (let i = 0; i < NUM_CLASSES; i++) {
for (let i = 0; i < diseases.length; i++) { map[i] = plantVillageNameToDiseaseId(PLANTVILLAGE_CLASSES[i]);
map[FIRST_DISEASE_INDEX + i] = diseases[i].id;
} }
map[UNKNOWN_INDEX] = "unknown";
return map; return map;
})(), })(),
); );
// ─── Disease ID → Index mapping ────────────────────────────────────────────── // ─── Disease ID → Index mapping ─────────────────────────────────────────
/** /**
* Map from disease ID string to model output index. * Map from disease ID string to model output index.
* For duplicates (e.g., both potato and tomato "Early_blight" → "early-blight"),
* returns the first matching index.
*/ */
export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze( export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze(
(() => { (() => {
const map: Record<string, number> = {}; const map: Record<string, number> = {};
map["healthy"] = HEALTHY_INDEX; for (let i = 0; i < NUM_CLASSES; i++) {
for (let i = 0; i < diseases.length; i++) { const id = INDEX_TO_DISEASE_ID[i];
map[diseases[i].id] = FIRST_DISEASE_INDEX + i; // First occurrence wins (potato before tomato for early/late blight)
if (map[id] === undefined) {
map[id] = i;
}
} }
map["unknown"] = UNKNOWN_INDEX;
return map; return map;
})(), })(),
); );
// ─── Lookup helpers ────────────────────────────────────────────────────────── // ─── Lookup helpers ─────────────────────────────────────────────────────
/** /**
* Get the disease ID for a given model output index. * Get the disease ID for a given model output index.
@@ -93,9 +216,22 @@ export function isRealDisease(diseaseId: string): boolean {
return diseaseId !== "healthy" && diseaseId !== "unknown"; return diseaseId !== "healthy" && diseaseId !== "unknown";
} }
/**
* Get the PlantVillage display name for a given model output index.
*/
export function getPlantVillageClassName(index: number): string {
return PLANTVILLAGE_CLASSES[index] ?? "unknown";
}
/** /**
* Get all known disease IDs (excluding "healthy" and "unknown"). * Get all known disease IDs (excluding "healthy" and "unknown").
*/ */
export function getAllDiseaseIds(): string[] { export function getAllDiseaseIds(): string[] {
return diseases.map((d) => d.id); const ids = new Set<string>();
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
if (id !== "healthy" && id !== "unknown") {
ids.add(id);
}
}
return Array.from(ids);
} }

View File

@@ -93,7 +93,10 @@ export async function getModel(): Promise<PlantDiseaseModel> {
const model = await Promise.race([ const model = await Promise.race([
loadingPromise, loadingPromise,
new Promise<never>((_, reject) => new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)), MODEL_LOAD_TIMEOUT), setTimeout(
() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
MODEL_LOAD_TIMEOUT,
),
), ),
]); ]);
@@ -172,6 +175,18 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
let tf: any; let tf: any;
// Monkey-patch: add util.isNullOrUndefined for Node.js 26 compatibility.
// @tensorflow/tfjs-node references this function, which was deprecated in Node.js 4 and removed in Node.js 23.
// eslint-disable-next-line @typescript-eslint/no-require-imports
const nodeUtil = require("util");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
return x === null || x === undefined;
};
}
// Try tfjs-node first (server-side, uses native bindings). // Try tfjs-node first (server-side, uses native bindings).
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these // Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
// as required dependencies — they are truly optional. // as required dependencies — they are truly optional.
@@ -197,7 +212,9 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
const startTime = performance.now(); const startTime = performance.now();
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js // Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 160, 160]) // Reshape NCHW flat array [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3]
const inputTensor = tf
.tensor3d(Array.from(tensor), [3, 160, 160])
.transpose([1, 2, 0]) .transpose([1, 2, 0])
.expandDims(0); .expandDims(0);
@@ -352,7 +369,7 @@ function generateMockLogits(tensor: Float32Array): Float32Array {
logits[topIndex] = 3.5; logits[topIndex] = 3.5;
// Second highest // Second highest
const secondIndex = (topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1) + 1; const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
logits[secondIndex] = 2.5; logits[secondIndex] = 2.5;
logits[numClasses - 1] = -2; // "unknown" gets low score logits[numClasses - 1] = -2; // "unknown" gets low score