diff --git a/checkpoints/species_warmup/species_warmup_epoch=04.pt b/checkpoints/species_warmup/species_warmup_epoch=04.pt new file mode 100644 index 0000000..2bfcdf5 Binary files /dev/null and b/checkpoints/species_warmup/species_warmup_epoch=04.pt differ diff --git a/data/.gitignore b/data/.gitignore index caba867..8283297 100644 --- a/data/.gitignore +++ b/data/.gitignore @@ -1 +1,2 @@ dataset +organized diff --git a/launch-training.sh b/launch-training.sh new file mode 100755 index 0000000..d2971b2 --- /dev/null +++ b/launch-training.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Launch hierarchical training — Stage A (species classifier) +# Batch 1536 to utilize the 78GB free memory + +cd /home/mike/Plant-Health-ID +source .venv/bin/activate +mkdir -p logs checkpoints + +nohup python3 scripts/train_hierarchical.py \ + --stage species \ + --batch-size 512 \ + --no-wandb \ + >logs/species_training.log 2>&1 & + +echo "Training launched (PID: $!)" +echo "Monitor with: tail -f logs/species_training.log" diff --git a/logs/species_training.log b/logs/species_training.log new file mode 100644 index 0000000..af91594 --- /dev/null +++ b/logs/species_training.log @@ -0,0 +1,33 @@ + GPU: AMD Radeon 8060S Graphics + Memory: 133.1 GB + ROCm version: 7.2.53211 + +============================================================ +Preparing data... +============================================================ + [train] 1291202 images, 531 species + [val] 234551 images, 531 species + Train: 1291202 images, 2521 batches + Val: 234551 images, 230 batches + +============================================================ +Building model... +============================================================ + Model: Swin-Tiny backbone (768-dim features) + Species head: 531 classes + Disease heads: 521 species-specific heads + +============================================================ +Stage A — Species Classifier Training +============================================================ + +── Sub-stage A1: Head warmup (backbone frozen) ── + species_warmup: 0%| | 0/2521 [00:00 (plant, disease) lookup +dir_to_class = {} +for plant, disease in reparsed: + # Approximate the original dir name + key = f"{plant}-{disease}" + dir_to_class[key] = (plant, disease) + +plant_image_totals = Counter() +for d, size in class_sizes.items(): + # Find matching entry + # The dir name might not exactly match the key + if d in dir_to_class: + plant, disease = dir_to_class[d] + plant_image_totals[plant] += size + +print(f"\nTop 15 plants by total images:") +for plant, cnt in plant_image_totals.most_common(15): + print(f" {plant}: {cnt:,}") diff --git a/scripts/organize-dataset.py b/scripts/organize-dataset.py new file mode 100644 index 0000000..6a08c45 --- /dev/null +++ b/scripts/organize-dataset.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +""" +Phase 1 — Dataset Reorganization for Hierarchical Model Training. + +Reorganizes flat data/dataset/plant-disease-name/ directories into: + data/organized/ + train/{species}/{disease}/ + val/{species}/{disease}/ + species_index.json + class_hierarchy.json + dataset_stats.json + +Usage: python3 scripts/organize-dataset.py +""" + +import json +import os +import random +from collections import Counter, defaultdict +from pathlib import Path + +from PIL import Image +from joblib import Parallel, delayed +from tqdm import tqdm + +# ─── Config ─────────────────────────────────────────────────────────────────── + +BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +DATASET_DIR = BASE_DIR / "data" / "dataset" +ORGANIZED_DIR = BASE_DIR / "data" / "organized" +TRAIN_DIR = ORGANIZED_DIR / "train" +VAL_DIR = ORGANIZED_DIR / "val" + +RANDOM_SEED = 42 +TRAIN_RATIO = 0.85 +VAL_RATIO = 1.0 - TRAIN_RATIO + +MAX_DIM = 512 +JPEG_QUALITY = 90 +N_JOBS = 16 + +random.seed(RANDOM_SEED) + +# Known disease-prefix words — words that start disease names but should NOT +# be part of a plant name. If a plant part ends with one of these, we know +# the split point is wrong. +DISEASE_PREFIX_WORDS = { + "bacterial", "fungal", "viral", "downy", "powdery", + "alternaria", "phytophthora", "phoma", "phymatotrichum", + "pythium", "rhizoctonia", "sclerotinia", "fusarium", + "verticillium", "cercospora", "septoria", "anthracnose", + "black", "white", "gray", "brown", "green", "pink", "blue", + "soft", "hard", "sour", "bitter", + "southern", "northern", "common", "false", "true", + "european", "american", "aspen", "bacterial-blight", + "cercospora-leaf", "septoria-leaf", "alternaria-leaf", +} + +# Valid multi-word plant suffixes (these CAN follow a hyphen in plant names) +VALID_MULTI_WORD_PLANTS = { + "squash", "bean", "berry", "apple", "fern", "tree", "vine", + "cactus", "grass", "weed", "mint", "root", "seed", "leaf", + "flower", "fruit", "bark", "wood", "nut", "pea", "lily", + "rose", "moss", "palm", "fern", "orchid", "fig", "cress", + "plant", "sage", "thyme", "leaf-fig", "nest-fern", "tongue", + "tail", "ear", "eye", "nut-tree", "bean-tree", +} + +IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"} + +# ─── Load KB Data ───────────────────────────────────────────────────────────── + +def load_kb(): + with open(BASE_DIR / "src" / "data" / "plants.json") as f: + plants = json.load(f) + with open(BASE_DIR / "src" / "data" / "diseases.json") as f: + diseases = json.load(f) + return plants, diseases + +PLANTS, DISEASES = load_kb() +KB_PLANT_IDS = {p["id"] for p in PLANTS} + +def get_dataset_dirs(): + """Get all non-hidden subdirectories in the dataset folder.""" + dirs = sorted([ + d for d in os.listdir(DATASET_DIR) + if os.path.isdir(DATASET_DIR / d) and not d.startswith(".") + ]) + return dirs + +def count_images(path): + """Count image files in a directory.""" + if not path.exists(): + return 0 + return len([ + f for f in os.listdir(path) + if os.path.isfile(path / f) and os.path.splitext(f)[1].lower() in IMAGE_EXTS + ]) + +# ─── Phase 1: Parse directory names ──────────────────────────────────────────── + +def build_plant_and_disease_dictionaries(dirs): + """ + Build verified plant names and disease suffixes from the dataset. + Returns (parsed_dict, unmatched_list). + """ + # Phase 1: Verify plant names from prefixes that appear with >=3 diseases + plant_candidates = defaultdict(set) + for d in dirs: + parts = d.split("-") + if len(parts) < 2: + continue + for split in range(1, min(len(parts), 6)): + plant = "-".join(parts[:split]) + disease = "-".join(parts[split:]) + if plant and disease and len(disease) > 2: + plant_candidates[plant].add(disease) + + verified_plants = set(KB_PLANT_IDS) + for plant, diseases in plant_candidates.items(): + if len(diseases) >= 3 and plant not in verified_plants: + verified_plants.add(plant) + + print(f" Verified plants: {len(verified_plants)} ({len(verified_plants & KB_PLANT_IDS)} from KB)") + + # Phase 2: Match dirs by plant prefix (longest plant first) + sorted_plants = sorted(verified_plants, key=len, reverse=True) + plant_matched = {} + not_matched = [] + + for d in dirs: + matched = False + for plant in sorted_plants: + prefix = plant + "-" + if d.startswith(prefix): + disease = d[len(prefix):] + if disease: + plant_matched[d] = (plant, disease) + matched = True + break + if not matched: + if d.endswith("-healthy"): + plant = d[:-len("-healthy")] + plant_matched[d] = (plant, "healthy") + else: + not_matched.append(d) + + # Collect disease suffixes from Phase 2 matches + disease_suffixes = set(p[1] for p in plant_matched.values()) + print(f" Plant-matched dirs: {len(plant_matched)}, disease suffixes: {len(disease_suffixes)}") + + # Phase 3: Match remaining dirs by disease suffix (longest suffix first) + sorted_disease_suffixes = sorted(disease_suffixes, key=len, reverse=True) + still_not_matched = [] + + for d in not_matched: + matched = False + for suffix in sorted_disease_suffixes: + if d.endswith("-" + suffix): + plant_part = d[:-len("-" + suffix)] + if plant_part and not plant_part.endswith("-"): + plant_matched[d] = (plant_part, suffix) + matched = True + break + if not matched: + still_not_matched.append(d) + + print(f" Phase 3 matched: {len(not_matched) - len(still_not_matched)}") + print(f" Phase 3 remaining: {len(still_not_matched)}") + + # Phase 4: Handle trailing-hyphen dirs and healthy parent dir + final_unmatched = [] + for d in still_not_matched: + if d.endswith("-"): + plant = d[:-1] + if plant: + plant_matched[d] = (plant, "unlabeled") + elif d == "healthy": + healthy_dir = DATASET_DIR / "healthy" + if healthy_dir.exists(): + plant_subdirs = [ + s for s in os.listdir(healthy_dir) + if os.path.isdir(healthy_dir / s) and not s.startswith(".") + ] + for sub_plant in plant_subdirs: + # Use healthy/{sub_plant} as key so we know where to find the images + plant_matched[f"healthy/{sub_plant}"] = (sub_plant, "healthy") + print(f" Healthy dir: {len(plant_subdirs)} per-plant healthy classes") + else: + final_unmatched.append(d) + + print(f" Phase 4 handled {len(still_not_matched) - len(final_unmatched)} edge cases") + print(f" Final unmatched: {len(final_unmatched)}") + if final_unmatched: + print(f" E.g.: {final_unmatched[:10]}") + + # Phase 5: Post-processing — fix species names that ate disease-prefix words + fix_count = 0 + for d in list(plant_matched.keys()): + if d.startswith("healthy/"): + continue # Skip healthy subdirs — these are correct + species, disease = plant_matched[d] + parts = species.split("-") + if len(parts) >= 2 and parts[-1] in DISEASE_PREFIX_WORDS: + # Move the last word from species to disease + new_species = "-".join(parts[:-1]) + new_disease = parts[-1] + "-" + disease + plant_matched[d] = (new_species, new_disease) + fix_count += 1 + + print(f" Post-process fixes (species ending with disease-prefix): {fix_count}") + + return plant_matched, final_unmatched + +# ─── Image Processing ──────────────────────────────────────────────────────── + +def process_image(args): + """Resize and convert a single image to 512px max JPEG q90.""" + src_path, dst_path = args + try: + img = Image.open(src_path) + if img.mode != "RGB": + img = img.convert("RGB") + w, h = img.size + if max(w, h) > MAX_DIM: + ratio = MAX_DIM / max(w, h) + img = img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS) + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + img.save(dst_path, "JPEG", quality=JPEG_QUALITY, optimize=True) + return (src_path, True, None) + except Exception as e: + return (src_path, False, str(e)) + +def copy_and_split_class(src_dir, dst_train_dir, dst_val_dir, train_ratio=TRAIN_RATIO): + """ + Copy images from src_dir to train/val dirs, splitting at the IMAGE level. + Returns (train_processed, train_failed, val_processed, val_failed). + """ + # Check both possible source paths (regular dir or healthy subdir) + if not src_dir.exists(): + return (0, 0, 0, 0) + + src_files = sorted([ + f for f in os.listdir(src_dir) + if os.path.isfile(src_dir / f) and os.path.splitext(f)[1].lower() in IMAGE_EXTS + ]) + if not src_files: + return (0, 0, 0, 0) + + # Split files at IMAGE level + random.shuffle(src_files) + split_idx = max(1, int(len(src_files) * train_ratio)) + train_files = src_files[:split_idx] + val_files = src_files[split_idx:] + + # Process train images + train_pairs = [ + (str(src_dir / f), str(dst_train_dir / f"img_{i:04d}.jpg")) + for i, f in enumerate(train_files) + ] + val_pairs = [ + (str(src_dir / f), str(dst_val_dir / f"img_{i:04d}.jpg")) + for i, f in enumerate(val_files) + ] + + results = Parallel(n_jobs=N_JOBS, prefer="threads")( + delayed(process_image)(pair) for pair in train_pairs + val_pairs + ) + + train_ok = sum(1 for i, (_, ok, _) in enumerate(results) if ok and i < len(train_pairs)) + train_fail = sum(1 for i, (_, ok, _) in enumerate(results) if not ok and i < len(train_pairs)) + val_ok = sum(1 for i, (_, ok, _) in enumerate(results) if ok and i >= len(train_pairs)) + val_fail = sum(1 for i, (_, ok, _) in enumerate(results) if not ok and i >= len(train_pairs)) + + return (train_ok, train_fail, val_ok, val_fail) + +# ─── Build Metadata ────────────────────────────────────────────────────────── + +def build_metadata(parsed, train_counts, val_counts, unmatched): + """Build species_index.json, class_hierarchy.json, dataset_stats.json.""" + species_disease_map = defaultdict(set) + for species, disease in parsed.values(): + species_disease_map[species].add(disease) + species_index = {k: sorted(v) for k, v in sorted(species_disease_map.items())} + + class_hierarchy = { + "version": "1.0", + "description": "Hierarchical plant disease classification dataset", + "num_species": len(species_index), + "num_classes": len(parsed), + "species": {species: sorted(diseases) for species, diseases in species_index.items()} + } + + # Aggregate counts + total_train = sum(cnt for sp, di, cnt in train_counts) + total_val = sum(cnt for sp, di, cnt in val_counts) + total_all = total_train + total_val + + all_counts = [cnt for _, _, cnt in (train_counts + val_counts)] + + species_disease_counts = defaultdict(lambda: defaultdict(int)) + for sp, di, cnt in train_counts + val_counts: + species_disease_counts[sp][di] += cnt + + # Also count classes from the parsed dict (unique species/disease combos) + parsed_classes = set((sp, di) for sp, di in parsed.values()) + + stats = { + "total_images": total_all, + "total_species": len(species_index), + "total_classes": len(parsed_classes), + "train_images": total_train, + "val_images": total_val, + "images_per_class": { + "min": min(all_counts) if all_counts else 0, + "max": max(all_counts) if all_counts else 0, + "mean": round(sum(all_counts) / len(all_counts)) if all_counts else 0, + "median": sorted(all_counts)[len(all_counts) // 2] if all_counts else 0, + }, + "train_pct": round(total_train / total_all * 100, 1) if total_all else 0, + "val_pct": round(total_val / total_all * 100, 1) if total_all else 0, + "unmatched_dirs": len(unmatched), + "unmatched_dir_names": unmatched[:100] if unmatched else [], + "species_disease_counts": { + species: dict(diseases) for species, diseases in species_disease_counts.items() + } + } + + return species_index, class_hierarchy, stats + +# ─── Main Pipeline ─────────────────────────────────────────────────────────── + +def main(): + print("=" * 60) + print("Phase 1 — Dataset Reorganization") + print("=" * 60) + print(f"Dataset: {DATASET_DIR}") + print(f"Output: {ORGANIZED_DIR}") + print() + + # Step 1: Scan + print("─" * 40) + print("Step 1: Scanning dataset directories...") + print("─" * 40) + dirs = get_dataset_dirs() + print(f" Found {len(dirs)} class directories") + + # Step 2: Parse directory names into (species, disease) pairs + print() + print("─" * 40) + print("Step 2: Parsing directory names...") + print("─" * 40) + parsed, unmatched = build_plant_and_disease_dictionaries(dirs) + + species_set = set(s for s, _ in parsed.values()) + disease_set = set(d for _, d in parsed.values()) + raw_classes = len(parsed) + unique_classes = len(set((s, d) for s, d in parsed.values())) + print(f"\n Parsed: {raw_classes} entries") + print(f" Unique species: {len(species_set)}") + print(f" Unique disease labels: {len(disease_set)}") + print(f" Unique (species, disease) pairs: {unique_classes}") + + # Step 3: Process images with image-level train/val split + print() + print("─" * 40) + print("Step 3: Processing images (resize + train/val split)...") + print(f" Max dimension: {MAX_DIM}px, JPEG q{JPEG_QUALITY}") + print(f" Workers: {N_JOBS}") + print(f" Split: {TRAIN_RATIO*100:.0f}/{VAL_RATIO*100:.0f} (image-level)") + print("─" * 40) + + train_counts = [] # (species, disease, count) + val_counts = [] + total_skipped = 0 + + # Process regular dirs + regular_items = [(d, sp, di) for d, (sp, di) in parsed.items() + if not d.startswith("healthy/") and d in dirs] + healthy_items = [(d, sp, di) for d, (sp, di) in parsed.items() + if d.startswith("healthy/")] + + # Organize healthy items by plant + healthy_by_plant = {} + for d, sp, di in healthy_items: + healthy_by_plant[sp] = d # d is like "healthy/tomato" + + print(f"\n Processing {len(regular_items)} disease + {len(healthy_items)} healthy classes...") + + for d, species, disease in tqdm(regular_items, desc=" Disease classes"): + src_dir = DATASET_DIR / d + dst_train = TRAIN_DIR / species / disease + dst_val = VAL_DIR / species / disease + + # Skip if already done (check a few files) + if dst_train.exists() and dst_val.exists() and \ + len(os.listdir(dst_train)) + len(os.listdir(dst_val)) >= count_images(src_dir): + total_skipped += count_images(src_dir) + continue + + tr_ok, tr_fail, va_ok, va_fail = copy_and_split_class(src_dir, dst_train, dst_val) + train_counts.append((species, disease, tr_ok)) + val_counts.append((species, disease, va_ok)) + + # Process healthy subdirs + for sp, hkey in tqdm(healthy_by_plant.items(), desc=" Healthy classes"): + src_dir = DATASET_DIR / hkey # e.g. data/dataset/healthy/tomato + dst_train = TRAIN_DIR / sp / "healthy" + dst_val = VAL_DIR / sp / "healthy" + + if dst_train.exists() and dst_val.exists() and \ + len(os.listdir(dst_train)) + len(os.listdir(dst_val)) >= count_images(src_dir): + total_skipped += count_images(src_dir) + continue + + tr_ok, tr_fail, va_ok, va_fail = copy_and_split_class(src_dir, dst_train, dst_val) + train_counts.append((sp, "healthy", tr_ok)) + val_counts.append((sp, "healthy", va_ok)) + + total_train = sum(c for _, _, c in train_counts) + total_val = sum(c for _, _, c in val_counts) + print(f"\n Train images: {total_train:,}") + print(f" Val images: {total_val:,}") + print(f" Skipped previously processed: {total_skipped:,}") + + # Step 4: Build metadata + print() + print("─" * 40) + print("Step 4: Building metadata files...") + print("─" * 40) + ORGANIZED_DIR.mkdir(parents=True, exist_ok=True) + + species_index, class_hierarchy, stats = build_metadata( + parsed, train_counts, val_counts, unmatched + ) + + with open(ORGANIZED_DIR / "species_index.json", "w") as f: + json.dump(species_index, f, indent=2) + print(f" ✓ species_index.json ({len(species_index)} species)") + + with open(ORGANIZED_DIR / "class_hierarchy.json", "w") as f: + json.dump(class_hierarchy, f, indent=2) + print(f" ✓ class_hierarchy.json") + + with open(ORGANIZED_DIR / "dataset_stats.json", "w") as f: + json.dump(stats, f, indent=2) + print(f" ✓ dataset_stats.json") + + # Summary + print() + print("=" * 60) + print("Done!") + print("=" * 60) + print(f" Total images: {stats['total_images']:,}") + print(f" Species: {stats['total_species']}") + print(f" Classes: {stats['total_classes']}") + print(f" Train: {stats['train_images']:,} ({stats['train_pct']}%)") + print(f" Val: {stats['val_images']:,} ({stats['val_pct']}%)") + print(f" Unmatched dirs: {stats['unmatched_dirs']}") + print(f" Train dir: {TRAIN_DIR}") + print(f" Val dir: {VAL_DIR}") + + if stats['unmatched_dirs'] > 0: + print(f"\n ⚠ Manual review needed for {stats['unmatched_dirs']} dirs:") + for u in stats['unmatched_dir_names'][:20]: + print(f" {u}") + + return stats + +if __name__ == "__main__": + main() diff --git a/scripts/train_hierarchical.py b/scripts/train_hierarchical.py new file mode 100644 index 0000000..6a5ba2d --- /dev/null +++ b/scripts/train_hierarchical.py @@ -0,0 +1,1037 @@ +#!/usr/bin/env python3 +""" +Phase 2 — Hierarchical Model Training (Swin-Tiny). + +Trains a hierarchical plant disease classifier with: + - Swin-Tiny backbone (timm, ImageNet-21K pretrained) + - Species head (~530 classes — identifies the plant) + - Disease heads (one per species, 30-300 classes each — identifies the disease) + +Two-stage training protocol: + Stage A (Species): Train species classifier with frozen backbone, then full fine-tune. + Stage B (Disease): Train disease-specific heads, then end-to-end joint training. + +Usage: + python3 scripts/train_hierarchical.py # Full training + python3 scripts/train_hierarchical.py --stage species # Stage A only + python3 scripts/train_hierarchical.py --stage disease # Stage B only (requires Stage A checkpoint) + python3 scripts/train_hierarchical.py --resume path.ckpt # Resume from checkpoint + +Requirements: + pip install torch torchvision timm pytorch-lightning albumentations wandb +""" + +import argparse +import json +import os +import random +import sys +from collections import defaultdict +from pathlib import Path + +import albumentations as A +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from albumentations.pytorch import ToTensorV2 +from PIL import Image +from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler +from tqdm import tqdm + +# Conditionally import pytorch-lightning (optional — for multi-GPU) +try: + import pytorch_lightning as pl + from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + from pytorch_lightning.loggers import WandbLogger + HAS_LIGHTNING = True +except ImportError: + HAS_LIGHTNING = False + +# Conditionally import wandb +try: + import wandb + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + +# ─── Config ─────────────────────────────────────────────────────────────────── + +BASE_DIR = Path(__file__).resolve().parent.parent +DATA_DIR = BASE_DIR / "data" / "organized" +TRAIN_DIR = DATA_DIR / "train" +VAL_DIR = DATA_DIR / "val" +CHECKPOINT_DIR = BASE_DIR / "checkpoints" +LOG_DIR = BASE_DIR / "logs" + +# ImageNet normalization constants +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + +# Default training config +DEFAULT_CONFIG = { + # Model + "backbone": "swin_tiny_patch4_window7_224", + "img_size": 224, + "dropout": 0.1, + + # Stage A (Species) + "species_epochs_warmup": 5, + "species_epochs_finetune": 20, + "species_epochs_final": 5, + "species_lr_warmup": 1e-3, + "species_lr_finetune": 1e-4, + "species_lr_final": 5e-6, + "species_batch_size": 512, + + # Stage B (Disease) + "disease_epochs_heads": 15, + "disease_epochs_rare": 5, + "disease_epochs_e2e": 10, + "disease_lr_heads": 1e-3, + "disease_lr_rare": 5e-4, + "disease_lr_e2e": 1e-5, + "disease_loss_weight": 0.7, + + # Data + "num_workers": 8, + "prefetch_factor": 4, + "val_split_pct": 0.15, + + # Augmentation + "mixup_alpha": 0.2, + "cutmix_alpha": 1.0, + "label_smoothing": 0.1, + + # Training + "gradient_clip_val": 1.0, + "weight_decay": 0.05, + "warmup_epochs": 3, + + # System + "seed": 42, + "device": "cuda" if torch.cuda.is_available() else "cpu", + "precision": "16-mixed", # mixed precision training + "accumulate_grad_batches": 1, +} + + +# ─── Dataset ────────────────────────────────────────────────────────────────── + +class HierarchicalDataset(Dataset): + """ + Reads organized dataset from data/organized/{train,val}/{species}/{disease}/. + + Loads species_index.json to build class mappings, then walks directories. + """ + + def __init__(self, data_dir: Path, split: str, transform=None, config=None): + self.data_dir = data_dir / split + self.split = split + self.transform = transform + self.config = config or DEFAULT_CONFIG + + # Load metadata + self.species_index = self._load_json(data_dir / "species_index.json") + self.dataset_stats = self._load_json(data_dir / "dataset_stats.json") + + # Build mappings + self.species_list = sorted(self.species_index.keys()) + self.species_to_idx = {s: i for i, s in enumerate(self.species_list)} + self.num_species = len(self.species_list) + + # Build disease mappings per species + self.disease_to_idx = {} # species → {disease_name → index} + self.disease_counts = {} # species → {disease_name → count} + self.samples = [] # List of (image_path, species_idx, disease_name) + + self._build_index() + + def _load_json(self, path): + with open(path) as f: + return json.load(f) + + def _build_index(self): + """Walk the data directory and build sample index.""" + species_dirs = sorted([ + d for d in os.listdir(self.data_dir) + if os.path.isdir(self.data_dir / d) and not d.startswith(".") + ]) + + # Add any species from species_index that exist as directories + for species in species_dirs: + if species not in self.species_to_idx: + print(f" [WARN] Species '{species}' not in species_index, skipping") + continue + + species_idx = self.species_to_idx[species] + disease_dirs = sorted([ + d for d in os.listdir(self.data_dir / species) + if os.path.isdir(self.data_dir / species / d) and not d.startswith(".") + ]) + + for disease in disease_dirs: + if disease not in self.disease_to_idx.get(species, {}): + self.disease_to_idx.setdefault(species, {})[disease] = len( + self.disease_to_idx.get(species, {}) + ) + + disease_dir = self.data_dir / species / disease + image_files = sorted([ + f for f in os.listdir(disease_dir) + if os.path.isfile(disease_dir / f) + and f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")) + ]) + + self.disease_counts.setdefault(species, {}).setdefault(disease, 0) + self.disease_counts[species][disease] += len(image_files) + + for img_file in image_files: + self.samples.append(( + disease_dir / img_file, + species_idx, + disease, + species, + )) + + print(f" [{self.split}] {len(self.samples)} images, " + f"{self.num_species} species") + + def get_species_class_distribution(self): + """Return list of sample counts per species.""" + counts = [0] * self.num_species + for _, sp_idx, _, _ in self.samples: + counts[sp_idx] += 1 + return counts + + def get_disease_class_weights(self, species): + """Return per-sample weights for disease class balancing.""" + counts = self.disease_counts.get(species, {}) + if not counts: + return None + total = sum(counts.values()) + weights = {} + for disease, count in counts.items(): + # Weight inversely proportional to sqrt(count) — moderate upweighting + weights[disease] = total / (len(counts) * np.sqrt(count) * np.sqrt(count / total)) + return weights + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + img_path, species_idx, disease_name, species = self.samples[idx] + + # Load image + image = Image.open(img_path).convert("RGB") + image_np = np.array(image) + + # Apply transforms + if self.transform: + augmented = self.transform(image=image_np) + image_tensor = augmented["image"] + else: + # Default: resize + normalize + import torchvision.transforms.functional as TF + from torchvision.transforms import InterpolationMode + image = image.resize((224, 224), Image.BILINEAR) + image_tensor = TF.to_tensor(image) + image_tensor = TF.normalize(image_tensor, IMAGENET_MEAN, IMAGENET_STD) + + # Disease label + disease_idx = self.disease_to_idx[species].get(disease_name, 0) + num_diseases = len(self.disease_to_idx[species]) + + return { + "image": image_tensor, + "species_idx": torch.tensor(species_idx, dtype=torch.long), + "disease_idx": torch.tensor(disease_idx, dtype=torch.long), + "disease_name": disease_name, + "species": species, + "num_diseases": num_diseases, + } + + +# ─── Augmentations ──────────────────────────────────────────────────────────── + +def get_train_transform(config=None): + """Get training augmentation pipeline (Tier 1 + 2).""" + config = config or DEFAULT_CONFIG + img_size = config.get("img_size", 224) + + return A.Compose([ + # Tier 1 — Core geometric & photometric + A.RandomResizedCrop(size=(img_size, img_size), scale=(0.6, 1.0), ratio=(0.75, 1.33)), + A.HorizontalFlip(p=0.5), + A.Rotate(limit=45, p=0.5), + A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8), + + # Tier 2 — Degradation & quality simulation + A.OneOf([ + A.GaussianBlur(blur_limit=(3, 7), p=1.0), + A.GaussNoise(std_range=(0.01, 0.05), p=1.0), + A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0), + ], p=0.3), + A.ImageCompression(quality_range=(60, 95), p=0.2), + A.RandomShadow(p=0.15), + A.ToGray(p=0.05), + + # Normalize + A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ToTensorV2(), + ]) + + +def get_val_transform(config=None): + """Get validation transform (minimal, deterministic).""" + config = config or DEFAULT_CONFIG + img_size = config.get("img_size", 224) + + return A.Compose([ + A.Resize(img_size, img_size), + A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ToTensorV2(), + ]) + + +# ─── Model ──────────────────────────────────────────────────────────────────── + +class HierarchicalSwin(nn.Module): + """ + Hierarchical Swin-Tiny model with species and disease heads. + + Architecture: + Input (224×224×3) + → Swin-Tiny backbone (timm, 768-dim features) + → Species head (Linear: 768 → num_species) + → Disease heads (one Linear: 768 → num_diseases_per_species) + + Inference: one forward pass through backbone, then route 768-dim feature + to correct disease head based on predicted species. + """ + + def __init__(self, config, species_index, disease_to_idx): + super().__init__() + self.config = config + self.species_list = sorted(species_index.keys()) + self.num_species = len(self.species_list) + self.disease_to_idx = disease_to_idx + + # Backbone + import timm + self.backbone = timm.create_model( + config["backbone"], + pretrained=True, + num_classes=0, # No classification head — we use features only + global_pool="avg", + ) + self.feature_dim = self.backbone.num_features # 768 for Swin-Tiny + + # Species head + self.species_head = nn.Sequential( + nn.Dropout(config.get("dropout", 0.1)), + nn.Linear(self.feature_dim, self.num_species), + ) + + # Disease heads — one per species, built dynamically + self.disease_heads = nn.ModuleDict() + for species in self.species_list: + num_diseases = len(species_index[species]) + if num_diseases > 1: + self.disease_heads[species] = nn.Linear(self.feature_dim, num_diseases) + + # Store num_diseases per species for routing + self.num_diseases_per_species = { + species: max(v.values()) + 1 for species, v in disease_to_idx.items() + } if disease_to_idx else {} + + print(f" Model: Swin-Tiny backbone ({self.feature_dim}-dim features)") + print(f" Species head: {self.num_species} classes") + print(f" Disease heads: {len(self.disease_heads)} species-specific heads") + + def forward(self, x): + """Forward pass: return features, species_logits, and route to disease heads.""" + features = self.backbone(x) # [B, 768] + species_logits = self.species_head(features) # [B, num_species] + return features, species_logits + + def forward_disease(self, features, species_labels=None, species_logits=None): + """ + Disease forward pass using species-specific heads. + + Args: + features: [B, 768] backbone features + species_labels: [B] ground truth species indices (for training) + species_logits: [B, num_species] predicted species scores (for inference) + + Returns: + disease_logits_dict: dict mapping species name → disease logits for that sample + """ + if species_labels is not None: + # Training: use ground-truth species for routing + species_names = [self.species_list[i] for i in species_labels] + elif species_logits is not None: + # Inference: use predicted species + pred_species = species_logits.argmax(dim=1) + species_names = [self.species_list[i] for i in pred_species] + else: + raise ValueError("Need either species_labels (train) or species_logits (inference)") + + batch_size = features.shape[0] + disease_logits = torch.zeros(batch_size, 1, device=features.device) + + for i, species_name in enumerate(species_names): + if species_name in self.disease_heads: + head = self.disease_heads[species_name] + logits = head(features[i:i+1]) # [1, num_diseases] + if i == 0: + disease_logits = logits + else: + # Pad to match max num diseases across batch + pad_size = logits.shape[1] - disease_logits.shape[1] + if pad_size > 0: + disease_logits = F.pad(disease_logits, (0, pad_size)) + elif pad_size < 0: + logits = F.pad(logits, (0, -pad_size)) + disease_logits = torch.cat([disease_logits, logits], dim=0) + else: + # Species with only 1 disease (e.g. "healthy" only) + disease_logits = F.pad( + disease_logits if i > 0 else disease_logits[:0], + (0, 0, 0, 1) + ) + + return disease_logits, species_names + + def get_trainable_params(self, stage="species_head"): + """Get parameter groups with appropriate learning rates.""" + if stage == "species_head": + # Only train species head, freeze backbone + return [ + {"params": self.species_head.parameters(), "lr": self.config["species_lr_warmup"]}, + ] + elif stage == "species_full": + # Train all species-related params with discriminative LR + backbone_params = [] + head_params = [] + for name, param in self.backbone.named_parameters(): + backbone_params.append(param) + for name, param in self.species_head.named_parameters(): + head_params.append(param) + + return [ + {"params": backbone_params, "lr": self.config["species_lr_finetune"] * 0.1}, + {"params": head_params, "lr": self.config["species_lr_finetune"]}, + ] + elif stage == "disease_heads": + # Freeze backbone, train disease heads + return [ + {"params": self.disease_heads.parameters(), "lr": self.config["disease_lr_heads"]}, + ] + elif stage == "disease_rare": + # Freeze backbone, train disease heads with rare-class focus + return [ + {"params": self.disease_heads.parameters(), "lr": self.config["disease_lr_rare"]}, + ] + elif stage == "e2e": + # End-to-end: unfreeze everything + backbone_params = [] + head_params = [] + disease_params = [] + for name, param in self.backbone.named_parameters(): + backbone_params.append(param) + for name, param in self.species_head.named_parameters(): + head_params.append(param) + for name, param in self.disease_heads.named_parameters(): + disease_params.append(param) + + return [ + {"params": backbone_params, "lr": self.config["disease_lr_e2e"] * 0.1}, + {"params": head_params, "lr": self.config["disease_lr_e2e"]}, + {"params": disease_params, "lr": self.config["disease_lr_e2e"]}, + ] + + +# ─── MixUp / CutMix ────────────────────────────────────────────────────────── + +def mixup_data(x, y_species, y_disease, alpha=0.2): + """Apply MixUp augmentation to a batch.""" + if alpha == 0: + return x, y_species, y_disease, None, None, None + + lam = np.random.beta(alpha, alpha) + batch_size = x.size(0) + index = torch.randperm(batch_size).to(x.device) + + mixed_x = lam * x + (1 - lam) * x[index] + y_species_a, y_species_b = y_species, y_species[index] + y_disease_a, y_disease_b = y_disease, y_disease[index] + + return mixed_x, y_species_a, y_species_b, y_disease_a, y_disease_b, lam + + +def mixup_criterion(criterion, pred, y_a, y_b, lam): + """MixUp loss: linear blend of losses between two label distributions.""" + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + + +# ─── Training Loop ──────────────────────────────────────────────────────────── + +class HierarchicalTrainer: + """Custom training loop for hierarchical model (no Lightning dependency).""" + + def __init__(self, config=None): + self.config = config or DEFAULT_CONFIG.copy() + self.device = torch.device(self.config["device"]) + self.seed_everything() + + # Setup logging + self.log_file = LOG_DIR / "training_log.json" + LOG_DIR.mkdir(parents=True, exist_ok=True) + self.metrics_history = [] + + def seed_everything(self): + random.seed(self.config["seed"]) + np.random.seed(self.config["seed"]) + torch.manual_seed(self.config["seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(self.config["seed"]) + + def prepare_data(self): + """Load datasets and create dataloaders.""" + print("\n" + "=" * 60) + print("Preparing data...") + print("=" * 60) + + self.train_transform = get_train_transform(self.config) + self.val_transform = get_val_transform(self.config) + + self.train_dataset = HierarchicalDataset( + DATA_DIR, "train", transform=self.train_transform, config=self.config + ) + self.val_dataset = HierarchicalDataset( + DATA_DIR, "val", transform=self.val_transform, config=self.config + ) + + self.species_index = self.train_dataset.species_index + self.disease_to_idx = self.train_dataset.disease_to_idx + + # Training dataloader with optional class-balanced sampling + species_counts = self.train_dataset.get_species_class_distribution() + # Use standard shuffle for now; weighted sampling can be added for disease heads + + self.train_loader = DataLoader( + self.train_dataset, + batch_size=self.config["species_batch_size"], + shuffle=True, + num_workers=self.config["num_workers"], + prefetch_factor=self.config["prefetch_factor"], + pin_memory=True, + persistent_workers=True, + drop_last=True, + ) + + self.val_loader = DataLoader( + self.val_dataset, + batch_size=self.config["species_batch_size"] * 2, + shuffle=False, + num_workers=self.config["num_workers"], + pin_memory=True, + persistent_workers=True, + ) + + print(f" Train: {len(self.train_dataset)} images, {len(self.train_loader)} batches") + print(f" Val: {len(self.val_dataset)} images, {len(self.val_loader)} batches") + + def create_model(self): + """Create the hierarchical model.""" + print("\n" + "=" * 60) + print("Building model...") + print("=" * 60) + + self.model = HierarchicalSwin( + self.config, + self.species_index, + self.disease_to_idx, + ).to(self.device) + + self.species_criterion = nn.CrossEntropyLoss( + label_smoothing=self.config["label_smoothing"] + ) + self.disease_criterion = nn.CrossEntropyLoss() + + def save_checkpoint(self, path, stage, epoch, metrics): + """Save model checkpoint.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + checkpoint = { + "stage": stage, + "epoch": epoch, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict() if hasattr(self, 'optimizer') else None, + "scheduler_state_dict": self.scheduler.state_dict() if hasattr(self, 'scheduler') else None, + "metrics": metrics, + "config": self.config, + "species_index": self.species_index, + "disease_to_idx": self.disease_to_idx, + } + torch.save(checkpoint, path) + print(f" ✓ Checkpoint saved: {path}") + + def load_checkpoint(self, path): + """Load model from checkpoint.""" + checkpoint = torch.load(path, map_location=self.device, weights_only=False) + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.species_index = checkpoint.get("species_index", self.species_index) + self.disease_to_idx = checkpoint.get("disease_to_idx", self.disease_to_idx) + print(f" ✓ Loaded checkpoint: {path} (stage={checkpoint.get('stage')}, " + f"epoch={checkpoint.get('epoch')})") + return checkpoint + + def log_metrics(self, stage, epoch, metrics): + """Log metrics to file and console.""" + metrics["stage"] = stage + metrics["epoch"] = epoch + self.metrics_history.append(metrics) + + # Write to log file + with open(self.log_file, "w") as f: + json.dump(self.metrics_history, f, indent=2) + + # Print + print(f" [{stage}] Epoch {epoch}: " + f"loss={metrics.get('train_loss', 0):.4f}, " + f"val_loss={metrics.get('val_loss', 0):.4f}, " + f"species_acc={metrics.get('val_species_acc', 0):.2%}") + + # Wandb logging + if HAS_WANDB and wandb.run is not None: + wandb.log({f"{stage}/{k}": v for k, v in metrics.items()}, step=epoch) + + # ─── Stage A: Species Classifier ─────────────────────────────────────── + + def train_species_stage(self, resume_from=None): + """ + Stage A: Train species classifier. + + Sub-stages: + 1. Head warmup (backbone frozen) + 2. Full fine-tune (unfrozen backbone, cosine LR) + 3. Final fine-tune (discriminative LR) + """ + print("\n" + "=" * 60) + print("Stage A — Species Classifier Training") + print("=" * 60) + + if resume_from: + checkpoint = self.load_checkpoint(resume_from) + start_epoch = checkpoint.get("epoch", 0) + 1 + else: + start_epoch = 0 + + # ── Sub-stage A1: Head warmup ── + if start_epoch <= self.config["species_epochs_warmup"]: + print("\n── Sub-stage A1: Head warmup (backbone frozen) ──") + params = self.model.get_trainable_params("species_head") + self._run_training_stage( + "species_warmup", + params, + self.config["species_lr_warmup"], + self.config["species_epochs_warmup"], + start_epoch, + freeze_backbone=True, + ) + + # ── Sub-stage A2: Full fine-tune ── + a2_start = self.config["species_epochs_warmup"] + a2_end = a2_start + self.config["species_epochs_finetune"] + if start_epoch <= a2_end: + print("\n── Sub-stage A2: Full fine-tune (cosine LR) ──") + params = self.model.get_trainable_params("species_full") + self._run_training_stage( + "species_finetune", + params, + self.config["species_lr_finetune"], + a2_end, + max(start_epoch, a2_start), + freeze_backbone=False, + cosine_schedule=True, + warmup_epochs=self.config["warmup_epochs"], + ) + + # ── Sub-stage A3: Final fine-tune ── + a3_start = a2_end + a3_end = a3_start + self.config["species_epochs_final"] + if start_epoch <= a3_end: + print("\n── Sub-stage A3: Final fine-tune (discriminative LR) ──") + params = self.model.get_trainable_params("species_full") + self._run_training_stage( + "species_final", + params, + self.config["species_lr_final"], + a3_end, + max(start_epoch, a3_start), + freeze_backbone=False, + ) + + print("\n ✓ Stage A complete!") + + # ─── Stage B: Disease Classifiers ────────────────────────────────────── + + def train_disease_stage(self, resume_from=None): + """ + Stage B: Train disease-specific heads. + + Sub-stages: + 1. All disease heads (backbone frozen) + 2. Rare-class boost (oversampled rare classes) + 3. End-to-end (unfrozen backbone, joint loss) + """ + print("\n" + "=" * 60) + print("Stage B — Disease Classifier Training") + print("=" * 60) + + if resume_from: + self.load_checkpoint(resume_from) + + # ── Sub-stage B1: Disease heads ── + print("\n── Sub-stage B1: Disease heads (backbone frozen) ──") + params = self.model.get_trainable_params("disease_heads") + self._run_training_stage( + "disease_heads", + params, + self.config["disease_lr_heads"], + self.config["disease_epochs_heads"], + freeze_backbone=True, + use_disease_loss=True, + ) + + # ── Sub-stage B2: Rare-class boost ── + print("\n── Sub-stage B2: Rare-class boost ──") + params = self.model.get_trainable_params("disease_rare") + self._run_training_stage( + "disease_rare", + params, + self.config["disease_lr_rare"], + self.config["disease_epochs_rare"], + freeze_backbone=True, + use_disease_loss=True, + oversample_rare=True, + ) + + # ── Sub-stage B3: End-to-end ── + print("\n── Sub-stage B3: End-to-end joint training ──") + params = self.model.get_trainable_params("e2e") + self._run_training_stage( + "disease_e2e", + params, + self.config["disease_lr_e2e"], + self.config["disease_epochs_e2e"], + freeze_backbone=False, + use_disease_loss=True, + ) + + print("\n ✓ Stage B complete!") + + # ─── Core Training Loop ──────────────────────────────────────────────── + + def _run_training_stage(self, stage_name, param_groups, base_lr, num_epochs, + start_epoch=0, freeze_backbone=True, + cosine_schedule=False, warmup_epochs=0, + use_disease_loss=False, oversample_rare=False): + """Run a training stage with the given parameters.""" + + # Freeze/unfreeze backbone + for param in self.model.backbone.parameters(): + param.requires_grad = not freeze_backbone + + # Setup optimizer + self.optimizer = torch.optim.AdamW( + param_groups, + weight_decay=self.config["weight_decay"], + ) + + # Setup scheduler + if cosine_schedule: + total_steps = num_epochs * len(self.train_loader) + warmup_steps = warmup_epochs * len(self.train_loader) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, T_max=total_steps - warmup_steps + ) + # Wrap with linear warmup + from torch.optim.lr_scheduler import SequentialLR, LinearLR + warmup_scheduler = LinearLR( + self.optimizer, start_factor=0.1, total_iters=warmup_steps + ) + self.scheduler = SequentialLR( + self.optimizer, + schedulers=[warmup_scheduler, self.scheduler], + milestones=[warmup_steps], + ) + else: + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, mode="min", factor=0.5, patience=3 + ) + + # Default dataloader + train_loader = self.train_loader + + # Create oversampled dataloader for rare classes if requested + if oversample_rare: + # Add per-class weighting for rare disease classes + weights = self._compute_disease_sample_weights() + if weights is not None: + sampler = WeightedRandomSampler( + weights, len(weights), replacement=True + ) + train_loader = DataLoader( + self.train_dataset, + batch_size=self.config["species_batch_size"], + sampler=sampler, + num_workers=self.config["num_workers"], + prefetch_factor=self.config["prefetch_factor"], + pin_memory=True, + persistent_workers=True, + drop_last=True, + ) + print(f" Using weighted sampler for rare-class boost") + + # Training loop + scaler = torch.amp.GradScaler(enabled=(self.config["precision"] == "16-mixed")) + + for epoch in range(start_epoch, num_epochs): + # Train one epoch + train_loss = self._train_epoch( + train_loader, use_disease_loss, scaler, stage_name + ) + + # Validate + val_metrics = self._validate(use_disease_loss) + + # Step scheduler + if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(val_metrics["val_loss"]) + else: + self.scheduler.step() + + # Log + metrics = {"train_loss": train_loss, **val_metrics} + self.log_metrics(stage_name, epoch, metrics) + + # Save checkpoint every 5 epochs + if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs: + ckpt_name = f"{stage_name}_epoch={epoch:02d}.pt" + self.save_checkpoint( + CHECKPOINT_DIR / stage_name / ckpt_name, + stage_name, epoch, metrics + ) + + def _train_epoch(self, loader, use_disease_loss, scaler, stage_name): + """Train for one epoch.""" + self.model.train() + total_loss = 0 + num_batches = 0 + + pbar = tqdm(loader, desc=f" {stage_name}", leave=False) + for batch_idx, batch in enumerate(pbar): + images = batch["image"].to(self.device) + species_labels = batch["species_idx"].to(self.device) + + with torch.amp.autocast( + device_type=self.device.type, + enabled=(self.config["precision"] == "16-mixed") + ): + # Forward + features, species_logits = self.model(images) + loss_species = self.species_criterion(species_logits, species_labels) + + if use_disease_loss: + disease_logits, _ = self.model.forward_disease( + features, species_labels=species_labels + ) + disease_labels = batch["disease_idx"].to(self.device) + loss_disease = self.disease_criterion(disease_logits, disease_labels) + loss = loss_species + self.config["disease_loss_weight"] * loss_disease + else: + loss = loss_species + + # Backward + self.optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.config["gradient_clip_val"] + ) + scaler.step(self.optimizer) + scaler.update() + + total_loss += loss.item() + num_batches += 1 + + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + return total_loss / num_batches + + @torch.no_grad() + def _validate(self, use_disease_loss=False): + """Run validation.""" + self.model.eval() + total_loss = 0 + correct_species = 0 + total_samples = 0 + correct_disease = 0 + + for batch in tqdm(self.val_loader, desc=" val", leave=False): + images = batch["image"].to(self.device) + species_labels = batch["species_idx"].to(self.device) + + features, species_logits = self.model(images) + + # Species accuracy + species_preds = species_logits.argmax(dim=1) + correct_species += (species_preds == species_labels).sum().item() + + # Species loss + loss_species = self.species_criterion(species_logits, species_labels) + total_loss += loss_species.item() + + # Disease accuracy (if using disease heads) + if use_disease_loss: + disease_logits, _ = self.model.forward_disease( + features, species_labels=species_labels + ) + disease_labels = batch["disease_idx"].to(self.device) + + # Pad disease labels to match logits width + if disease_logits.shape[1] > 1: + loss_disease = self.disease_criterion(disease_logits, disease_labels) + total_loss += self.config["disease_loss_weight"] * loss_disease.item() + + disease_preds = disease_logits.argmax(dim=1) + disease_labels_clipped = disease_labels.clamp(0, disease_logits.shape[1] - 1) + correct_disease += (disease_preds == disease_labels_clipped).sum().item() + + total_samples += images.size(0) + + metrics = { + "val_loss": total_loss / max(len(self.val_loader), 1), + "val_species_acc": correct_species / max(total_samples, 1), + } + + if use_disease_loss: + metrics["val_disease_acc"] = correct_disease / max(total_samples, 1) + + return metrics + + def _compute_disease_sample_weights(self): + """Compute sample weights for rare-class oversampling.""" + weights = [] + for species in self.train_dataset.species_list: + class_weights = self.train_dataset.get_disease_class_weights(species) + if class_weights is None: + continue + # Map weights to each sample + for idx, (_, sp_idx, disease_name, sp) in enumerate(self.train_dataset.samples): + if sp == species: + w = class_weights.get(disease_name, 1.0) + weights.append(w) + + return torch.tensor(weights, dtype=torch.float) if weights else None + + +# ─── Main ──────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description="Hierarchical Swin-Tiny Training") + parser.add_argument("--stage", choices=["species", "disease", "all"], default="all", + help="Training stage to run (default: all)") + parser.add_argument("--resume", type=str, default=None, + help="Resume from checkpoint path") + parser.add_argument("--batch-size", type=int, default=None, + help="Override batch size") + parser.add_argument("--epochs", type=int, default=None, + help="Override number of epochs (overrides all stage epochs)") + parser.add_argument("--device", type=str, default=None, + help="Override device (cuda, cpu)") + parser.add_argument("--no-wandb", action="store_true", + help="Disable wandb logging") + parser.add_argument("--dry-run", action="store_true", + help="Load data and model, run 1 batch, then exit") + args = parser.parse_args() + + # Setup config + config = DEFAULT_CONFIG.copy() + if args.batch_size: + config["species_batch_size"] = args.batch_size + if args.device: + config["device"] = args.device + if args.epochs: + for key in ["species_epochs_warmup", "species_epochs_finetune", + "species_epochs_final", "disease_epochs_heads", + "disease_epochs_rare", "disease_epochs_e2e"]: + config[key] = args.epochs + + # Fix device name for ROCm + if config["device"] == "cuda" and torch.cuda.is_available(): + print(f" GPU: {torch.cuda.get_device_name(0)}") + print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") + if hasattr(torch.version, "hip") and torch.version.hip: + print(f" ROCm version: {torch.version.hip}") + + # Create trainer + trainer = HierarchicalTrainer(config) + + # Prepare data + trainer.prepare_data() + + # Create or load model + if args.resume: + trainer.create_model() + trainer.load_checkpoint(args.resume) + else: + trainer.create_model() + + # Dry run + if args.dry_run: + print("\n── Dry run: 1 batch ──") + batch = next(iter(trainer.train_loader)) + images = batch["image"].to(trainer.device) + species_labels = batch["species_idx"].to(trainer.device) + features, species_logits = trainer.model(images) + loss = trainer.species_criterion(species_logits, species_labels) + loss.backward() + print(f" Forward + backward: ✓ (loss={loss.item():.4f})") + print(f" Features shape: {features.shape}") + print(f" Species logits shape: {species_logits.shape}") + print(" Dry run complete!") + return + + # Initialize wandb + if HAS_WANDB and not args.no_wandb: + wandb.init( + project="plant-disease-hierarchical", + config=config, + ) + + # Run stages + if args.stage in ("species", "all"): + trainer.train_species_stage(resume_from=args.resume if args.stage == "species" else None) + + if args.stage in ("disease", "all"): + # For 'all', continue from species checkpoint + species_checkpoint = CHECKPOINT_DIR / "species_final" / "species_final_final.pt" + trainer.train_disease_stage( + resume_from=str(species_checkpoint) if species_checkpoint.exists() else None + ) + + # Close wandb + if HAS_WANDB and wandb.run is not None: + wandb.finish() + + print("\n" + "=" * 60) + print("Training complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tasks/multi-image-user-feedback/01-api-types-and-schema.md b/tasks/multi-image-user-feedback/01-api-types-and-schema.md new file mode 100644 index 0000000..49feeb1 --- /dev/null +++ b/tasks/multi-image-user-feedback/01-api-types-and-schema.md @@ -0,0 +1,68 @@ +# 01. Extend Types and Add Feedback DB Schema + +meta: +id: multi-image-user-feedback-01 +feature: multi-image-user-feedback +priority: P0 +depends_on: [] +tags: [types, schema, database] + +objective: + +- Update shared TypeScript types to support multi-image requests, species-constrained inference, top-5 combo predictions, and post-evaluation feedback. +- Add a new database table for storing feedback entries. + +deliverables: + +- `src/lib/types.ts` — updated with new interfaces +- `src/lib/db/schema.ts` — updated with `diagnosisFeedback` table + +steps: + +1. Add these new types to `src/lib/types.ts`: + - `IdentifyOptions` — optional fields sent in the identify request: `secondImageId?: string`, `userSpecies?: string`, `useForTraining?: boolean` + - `IdentifyRequest` — extend to include `options?: IdentifyOptions` + - `TopPrediction` — a prediction with both species and disease info: `{ speciesName: string, diseaseName: string, diseaseId: string, confidence: ConfidenceResult, rank: number }` + - `IdentifyResponse` — extend to include `topSpeciesDisease?: TopPrediction[]`, `speciesConfidence?: ConfidenceResult`, `infoProvided: string[]` (which optional inputs the user gave) + - `AccuracyRating` — `"correct" | "incorrect" | "unsure"` + - `DiagnosisFeedback` — full feedback shape: `{ sessionId: string, imageIds: string[], userSpecies?: string, predictedDiseaseId: string, accuracyRating: AccuracyRating, consentToStoreImages: boolean, userCorrectedSpecies?: string, notes?: string, createdAt: string }` + - `FeedbackRequest` — POST body for the feedback endpoint + - `FeedbackResponse` — confirmation response + +2. Add a `diagnosisFeedback` table to `src/lib/db/schema.ts`: + - `id` — text primary key (UUID v4) + - `sessionId` — text, session identifier for grouping + - `imageIds` — JSON text array of stored image IDs + - `userSpecies` — text, nullable + - `predictedDiseaseId` — text, the top model prediction + - `accuracyRating` — text enum: `"correct" | "incorrect" | "unsure"` + - `consentToStoreImages` — integer (boolean) + - `userCorrectedSpecies` — text, nullable (only when accuracy=incorrect or unsure) + - `notes` — text, nullable + - `modelVersion` — text, the model version used + - `createdAt` — text, auto timestamp + - Add indexes on `sessionId`, `accuracyRating`, `createdAt` + +3. Export `DiagnosisFeedbackRow` and `DiagnosisFeedbackInsert` type helpers. + +tests: + +- Unit: verify new types compile correctly +- Unit: verify schema migration produces correct table DDL +- Unit: verify INSERT and SELECT on feedback table through Drizzle + +acceptance_criteria: + +- All new types are exported from `src/lib/types.ts` +- `diagnosisFeedback` table exists in schema with all required columns +- `DiagnosisFeedbackRow` and `DiagnosisFeedbackInsert` are exported + +validation: + +- `npx tsc --noEmit` passes +- Drizzle Kit (`npx drizzle-kit generate`) produces valid migration SQL + +notes: + +- The `sessionId` ties together the upload, identify, and feedback flow +- Image storage consent is a boolean to comply with data privacy requirements diff --git a/tasks/multi-image-user-feedback/02-multi-image-inference-pipeline.md b/tasks/multi-image-user-feedback/02-multi-image-inference-pipeline.md new file mode 100644 index 0000000..b4c0780 --- /dev/null +++ b/tasks/multi-image-user-feedback/02-multi-image-inference-pipeline.md @@ -0,0 +1,65 @@ +# 02. Multi-Image Ensemble & Species-Constrained Inference + +meta: +id: multi-image-user-feedback-02 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01] +tags: [inference, ml] + +objective: + +- Extend the inference pipeline to support multi-image ensemble inference (averaging features or logits from 2+ images). +- Add species-constrained softmax that renormalizes probabilities over only the disease classes belonging to a known species. + +deliverables: + +- `src/lib/ml/inference.ts` — updated with ensemble and constrained inference functions +- `src/lib/ml/confidence.ts` — updated with species-aware confidence calibration + +steps: + +1. In `src/lib/ml/inference.ts`, add: + - `runEnsembleInference(tensors: Float32Array[], topK?: number): Promise` — runs multiple images through the model, averages their logits, and returns top-K predictions. Averaging logits (before softmax) is preferred over averaging probabilities since it preserves confidence structure. + - `speciesConstrainedSoftmax(logits: Float32Array, speciesClasses: number[]): Float32Array` — given the full 11,818-class logits and a list of class indices belonging to the user-specified species, compute softmax over only those indices and return a renormalized probability vector (zero everywhere else). The model output dimension (11,818) should be a configurable constant. + - `runSpeciesConstrainedInference(tensor: Float32Array, speciesClassIndices: number[], topK?: number): Promise` — run inference then apply species-constrained softmax before extracting top-K. + - `runEnsembleSpeciesConstrained(tensors: Float32Array[], speciesClassIndices: number[], topK?: number): Promise` — ensemble then constrain. + +2. Export `CLASSIFIER_NUM_CLASSES` constant (11,818) and `SPECIES_CLASS_RANGES` (a map from species name → [startIndex, endIndex] in the model output) from a new constants file or from labels.ts. + +3. In `src/lib/ml/confidence.ts`, add: + - `calibrateSpeciesConfidence(rawProb: number, numDiseaseClasses: number): ConfidenceResult` — adjusts calibration factor based on how many disease classes the species has (fewer classes = higher effective confidence). + - `getEnsembleConfidence(predictions: RawPrediction[][]): ConfidenceResult` — aggregate confidence from multiple images. + +4. Create `src/lib/ml/species-class-ranges.ts` containing the mapping from species name → [class start index, class end index] in the 11,818-class model output. This is derived from the training dataset's `species_index.json` or `class_hierarchy.json`. + +5. Handle edge cases: + - If tensors array is empty → throw + - If tensor length doesn't match expected model input → throw validation error + - If species name not found in `SPECIES_CLASS_RANGES` → fall back to full softmax + +tests: + +- Unit: test logit averaging with 2 identical tensors → results should be identical to single inference +- Unit: test logit averaging with 2 different tensors → verify averaged output +- Unit: test species-constrained softmax — verify probabilities are zero outside the constrained indices +- Unit: test constrained softmax sums to ~1.0 within the species class range +- Unit: test ensemble + constrained combined pipeline + +acceptance_criteria: + +- `runEnsembleInference` accepts multiple tensors and returns averaged top-K predictions +- `speciesConstrainedSoftmax` zeros out all classes outside the species range +- `runSpeciesConstrainedInference` and `runEnsembleSpeciesConstrained` produce constrained results +- Confidence calibration accounts for number of disease classes in the species + +validation: + +- `npx tsc --noEmit` passes +- Unit tests pass with `npx vitest run src/lib/ml/ --reporter=verbose` + +notes: + +- The current mock model outputs 38 classes. These new functions target the 11,818-class model. +- Until the real model loads, ensemble/constrained functions should still work with mock data (just with fewer classes). +- The species-ranges file should be auto-generated from `data/organized/class_hierarchy.json` and checked into version control. diff --git a/tasks/multi-image-user-feedback/03-hierarchical-model-loader.md b/tasks/multi-image-user-feedback/03-hierarchical-model-loader.md new file mode 100644 index 0000000..7684ba6 --- /dev/null +++ b/tasks/multi-image-user-feedback/03-hierarchical-model-loader.md @@ -0,0 +1,73 @@ +# 03. Load Trained Swin-Tiny Model with Species/Disease Routing + +meta: +id: multi-image-user-feedback-03 +feature: multi-image-user-feedback +priority: P2 +depends_on: [] +tags: [ml, model-loader, inference] + +objective: + +- Create a new model loader backend that loads the trained Swin-Tiny checkpoint (`species_final_final.pt`) and routes through the species head and disease heads to produce 11,818-class logits. +- This task requires the PyTorch model to finish training on the Strix Halo machine and must be exported to the correct format before implementation. + +deliverables: + +- `src/lib/ml/hierarchical-model.ts` — new PlantDiseaseModel implementation for the Swin-Tiny model +- `scripts/export-model.js` — script to export the PyTorch checkpoint to TF.js format +- `public/models/plant-disease-classifier-v2/` — exported model directory (TF.js or ONNX) + +steps: + +1. Create `scripts/export-model.js`: + - Load the PyTorch checkpoint from `checkpoints/species_final/species_final_final.pt` + - Export to ONNX format with NCHW input shape [1, 3, 224, 224] + - Also export `species_index.json` and `class_hierarchy.json` alongside the model + - Output to `public/models/plant-disease-classifier-v2/` + +2. Create `src/lib/ml/hierarchical-model.ts`: + - Implement the `PlantDiseaseModel` interface + - Load the ONNX model via `onnxruntime-node` + - Load species/disease index files + - Implement `predict()`: + - Preprocess to 224×224 (Swin-Tiny input size, not 160) + - Run forward pass → get [1, 768] features → species logits → disease routing + - The model checkpoint is a single forward pass that already produces 11,818 logits from the combined species + disease heads + - Return the full 11,818-dimension logits array + - Implement `getStatus()` returning model metadata with `numClasses: 11818` + +3. Update `src/lib/ml/model-loader.ts`: + - Add detection for v2 model directory (`model-v2.json` or similar) + - Try loading v2 model first (if available), fall back to v1 then mock + - Export `MODEL_NUM_CLASSES` constant for use by other modules + - Export `getModelVersion()` to distinguish v1 (38-class) from v2 (11,818-class) + +4. Handle edge cases: + - No model checkpoint available → fall back through v1 → mock + - CUDA/ROCm not available for ONNX → use CPU backend + - Model version mismatch → clear error message + +tests: + +- Integration: export model from checkpoint and verify output shape is [1, 11818] +- Integration: load exported model and run inference on a test image +- Unit: model loader graceful fallback chain (v2 → v1 → mock) + +acceptance_criteria: + +- Exported model produces 11,818 logits from a 224×224 image +- Model loader loads v2 model when available, falls back gracefully when not +- All existing v1 model consumers continue to work unmodified (via version detection) + +validation: + +- `node scripts/export-model.js` produces model files +- `npx tsc --noEmit` passes +- POST to `/api/identify` returns predictions (may be limited if species→disease label mapping not yet complete) + +notes: + +- This task is **blocked on model training completion**. The task file is the implementation spec; actual work begins after `species_final_final.pt` exists. +- The ONNX export path is preferred for server-side inference (no Python runtime needed once exported). +- If ONNX export quality degrades the output, export to TF.js SavedModel format instead. diff --git a/tasks/multi-image-user-feedback/04-enhanced-identify-api-route.md b/tasks/multi-image-user-feedback/04-enhanced-identify-api-route.md new file mode 100644 index 0000000..4a472ac --- /dev/null +++ b/tasks/multi-image-user-feedback/04-enhanced-identify-api-route.md @@ -0,0 +1,80 @@ +# 04. Enhanced API Route for Multi-Image, Species-Aware Identification + +meta: +id: multi-image-user-feedback-04 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01, multi-image-user-feedback-02, multi-image-user-feedback-03] +tags: [api, backend] + +objective: + +- Update the `/api/identify` route to accept optional `options` (secondImageId, userSpecies), run ensemble inference when multiple images provided, apply species-constrained softmax when species is known, and return top-5 species+disease combo predictions with confidence metadata. + +deliverables: + +- `src/app/api/identify/route.ts` — updated route handler + +steps: + +1. Update the request parsing to accept `options?: IdentifyOptions` alongside `imageId`: + + ```typescript + const { imageId, options } = body; + const secondImageId = options?.secondImageId; + const userSpecies = options?.userSpecies; + ``` + +2. Update image loading to support optional second image: + - Load first image tensor (existing logic) + - If `secondImageId` provided, load and preprocess that image too + - Validate both images exist before inference + +3. Update inference logic to use the new pipeline: + - If 2 images provided → call `runEnsembleInference(tensors, topK=10)` + - If `userSpecies` provided → get species class range, call `runSpeciesConstrainedInference` or `runEnsembleSpeciesConstrained` + - If 2 images + species → use full ensemble+constrained + - If 1 image + no species → use existing single-inference path (backward compatible) + - Pass the `infoProvided` list to the response (which optional inputs were used) + +4. Generate top-5 species+disease combo predictions: + - After enrichment, construct `TopPrediction[]` from the top enriched predictions + - Each entry: `{ speciesName, diseaseName, diseaseId, confidence, rank }` + - Include species confidence when `userSpecies` was provided + +5. Add `speciesConfidence` to response: + - When user provides species, compute how much the constraint improves confidence vs unconstrained + - Return both constrained and unconstrained confidence for comparison + +6. Handle demo/mock mode: + - When no real model loaded, return mock top-5 combos with appropriate demo_mode flag + - Mock combos should be realistic (use knowledge base to generate plausible species/disease pairs) + +tests: + +- Integration: single image + no options → existing behavior unchanged +- Integration: single image + species → constrained results, all predictions belong to that species +- Integration: 2 images → ensemble results, confidence should differ from single image +- Integration: 2 images + species → fully constrained ensemble +- Integration: missing secondImageId returns 400 error +- Integration: demo mode returns mock data + +acceptance_criteria: + +- Existing single-image identify flow works unchanged when `options` omitted +- When secondImageId is provided, inference runs on both images +- When userSpecies is provided, only diseases of that species are returned +- Top-5 species/disease combos are included in response +- Confidence reflects whether 1 or 2 images were used +- 400 error for invalid/missing second image + +validation: + +- `npx tsc --noEmit` passes +- Existing identify API tests pass: `npx vitest run src/app/api/identify/` +- Manual test with curl sending multi-image request + +notes: + +- The `infoProvided` array helps the UI show what data was used for the diagnosis +- When userSpecies is given but no model can restrict to that species (e.g., mock mode), fall back to filtering results client-side by the plant name diff --git a/tasks/multi-image-user-feedback/05-upload-page-second-image-species.md b/tasks/multi-image-user-feedback/05-upload-page-second-image-species.md new file mode 100644 index 0000000..ed28807 --- /dev/null +++ b/tasks/multi-image-user-feedback/05-upload-page-second-image-species.md @@ -0,0 +1,99 @@ +# 05. Upload Page with Optional Second Image and Species Selector + +meta: +id: multi-image-user-feedback-05 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01] +tags: [ui, frontend, upload] + +objective: + +- Enhance the upload page to support an optional second image upload and a species search/select field. +- Show live confidence indicators that update as the user provides more information. + +deliverables: + +- `src/app/upload/page.tsx` — updated upload flow +- `src/components/ImageUpload.tsx` — updated to support multiple uploads +- `src/components/SpeciesSelector.tsx` — new species search/dropdown component +- `src/components/ConfidencePreview.tsx` — new inline confidence indicator + +steps: + +1. Create `src/components/SpeciesSelector.tsx`: + - Searchable dropdown with all 531 plant species from the knowledge base + - Uses `fuse.js` for fuzzy matching (lightweight, fast client-side search) + - Props: `value: string | null`, `onChange: (species: string | null) => void`, `disabled?: boolean` + - States: empty, searching, selected, clearable + - Shows common name and scientific name in results + - Keyboard-navigable: up/down arrows, enter to select, escape to close + - Mobile-friendly with full-screen overlay on small screens + +2. Create `src/components/ConfidencePreview.tsx`: + - Small inline bar showing estimated confidence based on info provided + - Props: `numImages: number`, `speciesProvided: boolean`, `className?: string` + - Dynamics: + - 1 image, no species → "Low confidence — add another photo or identify the plant" + - 1 image + species → "Medium confidence" + - 2 images, no species → "Medium confidence — getting clearer" + - 2 images + species → "High confidence — good data for diagnosis" + - Animated transitions between states + - Uses the same color scheme as ConfidenceBadge (green/amber/red) + +3. Update `src/components/ImageUpload.tsx`: + - Change from single-image to multi-image upload flow + - After first successful upload, show a "Add another photo (optional)" button + - Second upload uses the same ImageUpload internals but is secondary in visual weight + - Store both `imageId` responses for the identify request + - Add `uploadedImages: UploadResponse[]` tracking + - Expose method to clear all images + +4. Update `src/app/upload/page.tsx`: + - Add state for: firstImageId, secondImageId, selectedSpecies + - Add `SpeciesSelector` component below the upload zone(s) + - Add `ConfidencePreview` component showing live confidence estimate + - Add a "Continue to Diagnosis" button that becomes more prominent as more info is provided + - On submit: + - Build `IdentifyOptions` with secondImageId and/or userSpecies if provided + - Pass `options` in the identify API call (or query params to results page) + - Navigate to `/results/{firstImageId}?options={encodedOptions}` + +5. Handle edge cases: + - User adds second image, then removes it → gracefully falls back to single-image + - User selects species, then wants to change it → searchable select supports re-selection + - No species match found → user can type free-form (stored as-is) + - Upload of second image fails → component shows inline error but allows retry without blocking the first image + +6. Optimistic UI guidance: + - Add a small info panel below the confidence preview explaining _why_ more info helps + - Text: "Adding a second photo from a different angle helps our AI make a more accurate diagnosis. Identifying the plant species narrows down the possible diseases." + +tests: + +- Unit: SpeciesSelector renders, searches, selects, clears +- Unit: SpeciesSelector keyboard navigation works +- Unit: ConfidencePreview renders correct messages for each combination +- Unit: ImageUpload supports 2-image flow +- Integration: Full upload flow with 2 images + species → verify all data in request +- A11y: verify aria-labels, roles, keyboard navigation + +acceptance_criteria: + +- User can upload a second image (optional, after first succeeds) +- User can search and select a plant species from a dropdown +- Confidence preview bar updates dynamically as info is added +- "Continue to Diagnosis" button is prominent once at least 1 image is uploaded +- Navigate to results page with all options encoded + +validation: + +- `npx tsc --noEmit` passes +- Manual test: upload flow with 0/1/2 images and with/without species +- Responsive test: works on mobile viewport (375px width) + +notes: + +- The species list comes from the knowledge base (`/api/plants` endpoint) +- `fuse.js` is already lightweight (~15KB gzipped) and can be client-imported +- The options are passed as URL query params to the results page since we navigate before the identify API call diff --git a/tasks/multi-image-user-feedback/06-dynamic-results-dashboard.md b/tasks/multi-image-user-feedback/06-dynamic-results-dashboard.md new file mode 100644 index 0000000..a643b66 --- /dev/null +++ b/tasks/multi-image-user-feedback/06-dynamic-results-dashboard.md @@ -0,0 +1,94 @@ +# 06. Results Dashboard with Dynamic Confidence and Top-5 Display + +meta: +id: multi-image-user-feedback-06 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01, multi-image-user-feedback-04, multi-image-user-feedback-05] +tags: [ui, frontend, results] + +objective: + +- Enhance the results dashboard to display an info panel showing what data the user provided (1/2 images, species) and how it affected confidence. +- Show top-5 species/disease combination predictions as a compact card stack. +- Animate confidence transitions when the user lands on results. + +deliverables: + +- `src/components/ResultsDashboard.tsx` — updated dashboard +- `src/components/InfoProvidedBanner.tsx` — new component showing what info was used +- `src/components/TopCombinationsCard.tsx` — new component for top-5 species/disease combo list + +steps: + +1. Create `src/components/InfoProvidedBanner.tsx`: + - Display a banner/panel at the top of results showing: + - Number of images analyzed (1 or 2) + - Whether user identified the plant species (yes/no, with species name if yes) + - Icons/checkmarks for each piece of info + - Show a compact breakdown: "You provided: 📸 2 images · 🌿 Species: Tomato" + - Props: `{ numImages: number, userSpecies?: string | null }` + - Style: subtle background, small text, positioned between page header and results + - Animate in with a fade-slide effect + +2. Create `src/components/TopCombinationsCard.tsx`: + - Display the top-5 species/disease combination predictions from the API response + - Each row: rank badge, disease name, plant name, confidence bar + - Clicking a row expands it to show full disease info (reuses DiseaseCard internals) + - Props: `{ predictions: TopPrediction[], onSelect: (diseaseId: string) => void }` + - States: + - Loading: skeleton rows + - Empty: no combinations available + - Error: graceful message + - Populated: ranked list with horizontal confidence bars + - Confidence bar: colored (green/amber/red) horizontal bar with percentage label + - The top-5 is filterable: user can toggle between "All diseases" and "Constrained to your species" (when species was provided) + +3. Update `src/components/ResultsDashboard.tsx`: + - Accept new props: `numImages: number`, `infoProvided: string[]`, `userSpecies?: string`, `topCombinations?: TopPrediction[]` + - Add `InfoProvidedBanner` at the top of the results area + - Add `TopCombinationsCard` in the right sidebar (below image preview on desktop) + - When `infoProvided` includes species, show a tag/badge: "Species identified: Tomato" with a lock icon (implying the results are constrained to that species) + - When the response contains a species confidence, show a "How confidence changes with more info" mini-accordion: + - "With 1 image: 65% confidence" + - "With 2 images: 72% confidence" + - "With 2 images + species: 88% confidence" + - This educates the user on the value of providing more info + +4. Animate confidence transitions: + - When results load, confidence badges count up from 0 to their final percentage + - Use CSS `@keyframes` for the count-up animation + - Duration: ~600ms with ease-out curve + - Only animate on initial load, not on re-renders + +5. Handle edge cases: + - No top combinations (no species match) → show message: "No common patterns found" + - Single image, no species → hide the "how confidence changes" section (nothing to compare) + - Single image with species → show comparison vs without species (estimate) + - Demo mode → show realistic mock combos + +tests: + +- Unit: InfoProvidedBanner renders correct icons for 1/2 images and species presence +- Unit: TopCombinationsCard renders ranked list and toggles between constrained/all +- Unit: confidence count-up animation triggers on mount +- Integration: full results page with all new sections renders correctly + +acceptance_criteria: + +- InfoProvidedBanner shows how many images and whether species was identified +- TopCombinationsCard shows top-5 predictions with confidence bars +- Confidence values count up on page load +- When species info is available, a "how confidence changes" section is visible +- All existing results functionality (DiseaseCard, SymptomChecker, etc.) still works + +validation: + +- `npx tsc --noEmit` passes +- Manual test: navigate to results with 2 images + species → verify all UI sections +- Manual test: navigate with 1 image no species → verify simplified UI + +notes: + +- The top-5 combos come from the identify API response's `topSpeciesDisease` field +- Confidence comparison values are estimated when the model hasn't been run with/without the constraint — the API provides both constrained and unconstrained confidence diff --git a/tasks/multi-image-user-feedback/07-post-evaluation-feedback-component.md b/tasks/multi-image-user-feedback/07-post-evaluation-feedback-component.md new file mode 100644 index 0000000..5c47674 --- /dev/null +++ b/tasks/multi-image-user-feedback/07-post-evaluation-feedback-component.md @@ -0,0 +1,111 @@ +# 07. Post-Diagnosis Feedback Component (Accuracy / Unsure / Store Consent) + +meta: +id: multi-image-user-feedback-07 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01, multi-image-user-feedback-06] +tags: [ui, frontend, feedback] + +objective: + +- Create a feedback panel that appears after the diagnosis results, asking the user to rate accuracy (✓ / ✗ / ?) and optionally consent to storing their images for model retraining. + +deliverables: + +- `src/components/PostDiagnosisFeedback.tsx` — new feedback component +- `src/components/ResultsDashboard.tsx` — updated to include feedback panel + +steps: + +1. Create `src/components/PostDiagnosisFeedback.tsx`: + + Component structure (vertically stacked in a card): + + ``` + ┌─────────────────────────────────────────────┐ + │ 💬 How accurate was this diagnosis? │ + │ │ + │ [ ✅ Correct ] [ ❌ Incorrect ] [ ❓ Unsure ] │ + │ │ + │ ── (if Incorrect or Unsure selected) ── │ + │ What did you expect? (optional) │ + │ [_____________________________] text input │ + │ │ + │ ──────────────────────────────────────────── │ + │ │ + │ ☐ Allow us to store these images to │ + │ improve future diagnoses? │ + │ (Your privacy matters — images stored │ + │ securely and never shared) │ + │ │ + │ [ Submit Feedback ] → sent to /api/feedback │ + │ │ + │ ───── (after submission) ───── │ + │ ✓ Thank you! Your feedback helps us improve. │ + └─────────────────────────────────────────────┘ + ``` + + Props: `{ sessionId: string, imageIds: string[], predictedDiseaseId: string, userSpecies?: string, modelVersion: string, onSubmit?: () => void }` + + States: + - **Pending**: not yet rated, three large buttons (✓/✗/?) + - **Rated**: accuracy selected, showing optional text input + consent checkbox + - **Submitting**: loading spinner on submit button + - **Submitted**: success message with thank-you text + - **Error**: submission failed, retry button + + Implementation details: + - Accuracy buttons are large and touch-friendly (min 48px tap target) + - Selected button fills with its color: green (✓), red (✗), amber (?) + - Text input is an optional free-text field for user comments + - Consent checkbox has a brief privacy notice below it + - Submit button disabled until accuracy is rated + - On submit, POST to `/api/feedback` with `DiagnosisFeedback` body + - Animated transitions between states + +2. Update `src/components/ResultsDashboard.tsx`: + - Import and render `PostDiagnosisFeedback` at the bottom of the results area + - Pass sessionId (generated from first imageId), imageIds, predictedDiseaseId, userSpecies + - Show feedback panel after all prediction cards + - If no predictions at all, still show feedback (they may want to tell us the model was wrong) + +3. Handle edge cases: + - Feedback submission fails → show inline error with retry + - User refreshes page → already-submitted state persists if submission completed (could use sessionStorage) + - Consent unchecked → still submit feedback (just with consent=false) + - No predictions returned → show feedback anyway with "No disease identified" context + +tests: + +- Unit: all four states render correctly (pending/rated/submitting/submitted) +- Unit: accuracy selection toggles correctly (only one selected at a time) +- Unit: submit button disabled until accuracy is rated +- Unit: consent checkbox unchecked by default +- Unit: text input only shown when accuracy is "incorrect" or "unsure" +- Unit: submission calls /api/feedback with correct payload shape +- Integration: feedback flow from rating to submission to success + +acceptance_criteria: + +- Three accuracy rating buttons are always visible after results +- Rating is required before submission +- Optional text input appears for "Incorrect" or "Unsure" ratings +- Consent checkbox allows opting in to image storage +- Submit sends correct payload to /api/feedback +- Success message shown after submission +- Error state with retry if submission fails + +validation: + +- `npx tsc --noEmit` passes +- Manual test: rate accuracy, type notes, toggle consent, submit +- Manual test: verify API receives correct data +- A11y: verify all interactive elements have accessible labels + +notes: + +- The `sessionId` ties together upload → identify → feedback for the same session +- Privacy notice text should be reviewed for legal compliance +- Consider adding a "Share with the community" option in a future iteration +- Debounce the submit button to prevent double-submission diff --git a/tasks/multi-image-user-feedback/08-feedback-api-endpoint.md b/tasks/multi-image-user-feedback/08-feedback-api-endpoint.md new file mode 100644 index 0000000..40be185 --- /dev/null +++ b/tasks/multi-image-user-feedback/08-feedback-api-endpoint.md @@ -0,0 +1,107 @@ +# 08. Feedback API Endpoint for Accuracy Ratings and Storage Consent + +meta: +id: multi-image-user-feedback-08 +feature: multi-image-user-feedback +priority: P1 +depends_on: [multi-image-user-feedback-01] +tags: [api, backend, database] + +objective: + +- Create a POST endpoint at `/api/feedback` that accepts diagnosis feedback submissions (accuracy rating, notes, image storage consent) and persists them to the database. + +deliverables: + +- `src/app/api/feedback/route.ts` — new API route +- `src/app/api/feedback/feedback.test.ts` — test file + +steps: + +1. Create `src/app/api/feedback/route.ts`: + + Route: `POST /api/feedback` + + Accepts JSON body matching `FeedbackRequest`: + + ```typescript + { + sessionId: string; + imageIds: string[]; + userSpecies?: string; + predictedDiseaseId: string; + accuracyRating: "correct" | "incorrect" | "unsure"; + consentToStoreImages: boolean; + userCorrectedSpecies?: string; + notes?: string; + } + ``` + + Handler logic: + - Parse and validate request body + - Generate UUID for `id` + - Get `modelVersion` from model loader's `getStatus()` + - Set `createdAt` to current timestamp + - Insert into `diagnosisFeedback` table via Drizzle + - Return `FeedbackResponse`: `{ success: true, id: string }` + - Handle validation errors with 400 status + - Handle DB errors with 500 status + + Validation rules: + - `sessionId` — required, non-empty string + - `imageIds` — required, array of non-empty strings, min length 1 + - `accuracyRating` — required, must be one of "correct", "incorrect", "unsure" + - `consentToStoreImages` — required, boolean + - `userSpecies` — optional string + - `userCorrectedSpecies` — optional string, only meaningful when accuracy is not "correct" + - `notes` — optional string, max 500 characters (with error message if exceeded) + +2. Create `src/lib/api/feedback.ts` — client-side helper: + - `submitFeedback(data: FeedbackRequest): Promise` + - POST to `/api/feedback` with JSON body + - 15-second timeout + - Handle network errors gracefully + +3. Handle edge cases: + - Invalid JSON body → 400 with descriptive error + - Missing required fields → 400 listing missing fields + - Invalid accuracyRating value → 400 with allowed values + - Database unreachable → 500 with error message + - Duplicate sessionId → allowed (user can submit multiple times for different predictions) + +4. CORS and caching: + - Add `Cache-Control: no-store` header + - No authentication required (public endpoint for feedback) + +tests: + +- Unit: valid feedback submission returns 200 with success +- Unit: missing required fields return 400 +- Unit: invalid accuracyRating returns 400 +- Unit: notes over 500 chars returns 400 +- Unit: empty imageIds array returns 400 +- Unit: client helper `submitFeedback()` makes correct fetch call +- Unit: client helper handles network error gracefully +- Integration: submit feedback and verify it exists in database + +acceptance_criteria: + +- POST /api/feedback accepts valid feedback and stores it +- Invalid requests return appropriate 400 errors with descriptive messages +- Database stores all fields correctly +- Client helper function is usable from any feedback component +- Endpoint returns `{ success: true, id }` on success + +validation: + +- `npx tsc --noEmit` passes +- Unit tests pass: `npx vitest run src/app/api/feedback/` +- Manual test: `curl -X POST http://localhost:3000/api/feedback -H 'Content-Type: application/json' -d '{"sessionId":"test","imageIds":["img1"],"predictedDiseaseId":"early-blight","accuracyRating":"correct","consentToStoreImages":false}'` +- Verify stored data with direct DB query + +notes: + +- No auth needed for MVP — feedback is public and anonymous +- imageIds reference images in the uploads directory; no automatic cleanup +- A future task could add a review/admin dashboard for browsing feedback entries +- Rate limiting could be added later if needed (by sessionId or IP) diff --git a/tasks/multi-image-user-feedback/README.md b/tasks/multi-image-user-feedback/README.md new file mode 100644 index 0000000..f21f296 --- /dev/null +++ b/tasks/multi-image-user-feedback/README.md @@ -0,0 +1,40 @@ +# Multi-Image Upload & User Feedback + +Objective: Allow users to upload a second optional image, manually identify the plant species, see dynamic confidence updates, top-5 predictions, and provide post-diagnosis feedback (accuracy rating + storage consent). + +Status legend: [ ] todo, [~] in-progress, [x] done + +## Tasks + +- [ ] 01 — Extend types and add feedback DB schema → `01-api-types-and-schema.md` +- [ ] 02 — Multi-image ensemble & species-constrained inference → `02-multi-image-inference-pipeline.md` +- [ ] 03 — Load trained Swin-Tiny model with species/disease routing → `03-hierarchical-model-loader.md` +- [ ] 04 — Enhanced API route for multi-image, species-aware identification → `04-enhanced-identify-api-route.md` +- [ ] 05 — Upload page with optional second image and species selector → `05-upload-page-second-image-species.md` +- [ ] 06 — Results dashboard with dynamic confidence and top-5 display → `06-dynamic-results-dashboard.md` +- [ ] 07 — Post-diagnosis feedback component (accuracy / unsure / store consent) → `07-post-evaluation-feedback-component.md` +- [ ] 08 — Feedback API endpoint for accuracy ratings and storage consent → `08-feedback-api-endpoint.md` + +## Dependencies + +``` +01 ← 02 (types needed by inference) +01 ← 04 (types needed by API route) +01 ← 05 (types needed by upload page) +01 ← 06 (types needed by results dashboard) +01 ← 07 (types needed by feedback component) +01 ← 08 (schema needed by feedback API) +02 ← 04 (inference pipeline needed by API route) +03 ← 04 (model loader needed by API route) +05 ← 06 (upload outcomes fed into results) +06 ← 07 (results shown before post-eval) +04 ← 06 (API response needed by results dashboard) +``` + +## Exit Criteria + +- User can optionally upload a second image and/or enter a species name +- Confidence scores dynamically reflect the amount of information provided +- Top-5 species/disease combo predictions are displayed +- After diagnosis, user can rate accuracy (✓/✗/?) and opt in to image storage for training +- Feedback is persisted to the database