From 34855eff550c381f7912ffffc0d7d4058adc9bd1 Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Fri, 12 Jun 2026 13:20:33 -0400 Subject: [PATCH] task to get this here done --- scripts/fill-training-dataset.ts | 267 ++++++++++----- .../01-dataset-reorganization.md | 169 ++++++++++ .../02-hierarchical-training.md | 309 ++++++++++++++++++ .../03-export-quantization.md | 165 ++++++++++ .../04-server-inference.md | 211 ++++++++++++ .../05-browser-hybrid.md | 208 ++++++++++++ tasks/hierarchical-model-upgrade/README.md | 63 ++++ 7 files changed, 1307 insertions(+), 85 deletions(-) create mode 100644 tasks/hierarchical-model-upgrade/01-dataset-reorganization.md create mode 100644 tasks/hierarchical-model-upgrade/02-hierarchical-training.md create mode 100644 tasks/hierarchical-model-upgrade/03-export-quantization.md create mode 100644 tasks/hierarchical-model-upgrade/04-server-inference.md create mode 100644 tasks/hierarchical-model-upgrade/05-browser-hybrid.md create mode 100644 tasks/hierarchical-model-upgrade/README.md diff --git a/scripts/fill-training-dataset.ts b/scripts/fill-training-dataset.ts index 5257da2..0d7bfbf 100644 --- a/scripts/fill-training-dataset.ts +++ b/scripts/fill-training-dataset.ts @@ -22,6 +22,7 @@ import "dotenv/config"; import { readFileSync, readdirSync, writeFileSync, existsSync, mkdirSync } from "fs"; import { resolve, extname } from "path"; +import { Agent, setGlobalDispatcher } from "undici"; // Load .env.development for DB creds const envPath = resolve(__dirname, "../.env.development"); @@ -41,7 +42,7 @@ try { } catch {} import { getDb, closeDb } from "@/lib/db/index"; -import { diseases } from "@/lib/db/schema"; +import { plants, diseases } from "@/lib/db/schema"; // ─── Config ───────────────────────────────────────────────────────────────── @@ -49,17 +50,18 @@ const DATASET_DIR = resolve(__dirname, "../data/dataset"); const SEEN_CACHE_FILE = resolve(DATASET_DIR, ".fill-seen-urls.json"); /** Target images per disease */ -const TARGET_PER_DISEASE = 200; +const TARGET_PER_DISEASE = 100; -/** Target images for the "healthy" class */ -const TARGET_HEALTHY = 400; +/** Target images per plant for the "healthy" class */ +const TARGET_HEALTHY_PER_PLANT = 400; /** * How many diseases to process in parallel. - * Each disease is I/O-bound (HTTP requests), so high concurrency is safe. - * The global DDG rate limiter prevents us from overwhelming DuckDuckGo. + * Reduced from 50 to 5 to prevent overwhelming undici's connection pool. + * Each disease fires multiple concurrent requests, so high concurrency causes + * connection exhaustion and undici crashes. */ -const DISEASE_CONCURRENCY = 50; +const DISEASE_CONCURRENCY = 5; /** * Max DDG requests per second (shared across all concurrent diseases). @@ -70,8 +72,11 @@ const DISEASE_CONCURRENCY = 50; */ const DDG_RATE_LIMIT_RPS = 6; -/** Max concurrent image downloads per disease */ -const CONCURRENT_DOWNLOADS = 50; +/** Max concurrent image downloads per disease. + * Reduced from 50 to 10 to prevent connection pool exhaustion. + * With 5 diseases × 10 downloads = 50 concurrent downloads (manageable). + */ +const CONCURRENT_DOWNLOADS = 10; /** Minimum image size in bytes to accept */ const MIN_IMAGE_SIZE = 10_000; // 10KB @@ -86,9 +91,16 @@ const ALLOWED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]; const UA = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_4 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.4 Mobile/15E148 Safari/604.1"; -/** Healthy class directory name */ +/** Healthy class parent directory name */ const HEALTHY_CLASS = "healthy"; +/** + * How many plants to process in parallel for healthy image collection. + * Each plant runs the same fillClass pipeline (3 sources), so keep this + * modest to avoid connection pool exhaustion. + */ +const MAX_HEALTHY_CONCURRENCY = 5; + /** How often (in diseases processed) to flush the seen-URLs cache to disk */ const SEEN_CACHE_FLUSH_INTERVAL = 20; @@ -98,9 +110,6 @@ const SEEN_CACHE_FLUSH_INTERVAL = 20; * the seen-URLs cache accumulates across runs. */ const MAX_DDG_PAGES = 5; -/** Healthy source queries limit */ -const MAX_HEALTHY_QUERIES = 20; - // ─── Types ────────────────────────────────────────────────────────────────── interface DuckDuckGoImageResult { @@ -169,6 +178,45 @@ const ddgLimiter = new TokenBucket(DDG_RATE_LIMIT_RPS); // ─── Helpers ──────────────────────────────────────────────────────────────── +/** + * Custom undici agent with connection pooling to prevent overwhelming the network stack. + * Limits concurrent connections per host to prevent undici crashes (the "onResponseError" bug). + * setGlobalDispatcher() makes ALL fetch() calls use this agent automatically. + */ +const fetchAgent = new Agent({ + connections: 50, // Max 50 concurrent connections total + connect: { + timeout: 10_000, // 10s connection timeout + }, + bodyTimeout: 30_000, // 30s body timeout + headersTimeout: 15_000, // 15s headers timeout + keepAliveTimeout: 30_000, // 30s keep-alive + keepAliveMaxTimeout: 60_000, // 60s max keep-alive +}); +setGlobalDispatcher(fetchAgent); + +/** + * Wrapper around global fetch() with retry + exponential backoff for transient errors. + * The connection pooling is handled by the global undici dispatcher (set above). + */ +async function safeFetch(url: string | URL, init?: RequestInit, retries = 2): Promise { + let lastError: Error | undefined; + for (let attempt = 0; attempt <= retries; attempt++) { + try { + return await fetch(url, init); + } catch (err: unknown) { + lastError = err instanceof Error ? err : new Error(String(err)); + // Don't retry on abort/timeout (user-initiated) + if (err instanceof Error && err.name === "AbortError") throw err; + if (attempt < retries) { + // Exponential backoff: 500ms, 2000ms + await sleep(500 * (attempt + 1) * (attempt + 1)); + } + } + } + throw lastError ?? new Error(`fetch failed after ${retries + 1} attempts`); +} + function sleep(ms: number): Promise { return new Promise((resolve) => setTimeout(resolve, ms)); } @@ -240,7 +288,7 @@ async function getVqdToken(query: string): Promise { const url = `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`; - const res = await fetch(url, { + const res = await safeFetch(url, { headers: { "User-Agent": UA, Accept: "text/html" }, signal: AbortSignal.timeout(15_000), }); @@ -267,7 +315,7 @@ async function searchImagesDuckDuckGo( query, )}&vqd=${vqd}&o=json&p=${page}&f=,,,`; - const res = await fetch(url, { + const res = await safeFetch(url, { headers: { "User-Agent": UA, Accept: "application/json", @@ -289,7 +337,7 @@ async function searchImagesDuckDuckGo( const freshVqd = await getVqdToken(query); await ddgLimiter.acquire(); const retryUrl = url.replace(/vqd=[^&]+/, `vqd=${freshVqd}`); - const retryRes = await fetch(retryUrl, { + const retryRes = await safeFetch(retryUrl, { headers: { "User-Agent": UA, Accept: "application/json", @@ -410,7 +458,7 @@ async function searchImagesInaturalist( `&order_by=observed_on&order=desc`; try { - const res = await fetch(apiUrl, { + const res = await safeFetch(apiUrl, { headers: { "User-Agent": UA, Accept: "application/json" }, signal: AbortSignal.timeout(15_000), }); @@ -462,7 +510,7 @@ async function searchImagesCommons( const url = `https://commons.wikimedia.org/w/api.php?${params}`; try { - const res = await fetch(url, { + const res = await safeFetch(url, { headers: { "User-Agent": UA }, signal: AbortSignal.timeout(10_000), }); @@ -500,7 +548,7 @@ async function searchImagesCommons( async function downloadImage(url: string, destPath: string): Promise { try { - const res = await fetch(url, { + const res = await safeFetch(url, { headers: { "User-Agent": UA, Accept: "image/webp,image/png,image/jpeg,*/*" }, signal: AbortSignal.timeout(8_000), }); @@ -667,16 +715,16 @@ async function fillClass( interface ScanResult { /** Disease id → how many images currently on disk */ diseaseCounts: Map; - /** How many healthy images on disk */ - healthyCount: number; + /** Plant id → how many healthy images currently on disk */ + healthyCounts: Map; } function scanDataset(): ScanResult { const diseaseCounts = new Map(); - let healthyCount = 0; + const healthyCounts = new Map(); if (!existsSync(DATASET_DIR)) { - return { diseaseCounts, healthyCount: 0 }; + return { diseaseCounts, healthyCounts }; } const entries = readdirSync(DATASET_DIR, { withFileTypes: true }); @@ -686,7 +734,16 @@ function scanDataset(): ScanResult { if (entry.name.startsWith(".")) continue; if (entry.name === HEALTHY_CLASS) { - healthyCount = countImagesInDir(resolve(DATASET_DIR, entry.name)); + // Scan per-plant subdirectories under healthy/ + const healthyDir = resolve(DATASET_DIR, entry.name); + const plantDirs = readdirSync(healthyDir, { withFileTypes: true }); + for (const pd of plantDirs) { + if (!pd.isDirectory() || pd.name.startsWith(".")) continue; + const count = countImagesInDir(resolve(healthyDir, pd.name)); + if (count > 0) { + healthyCounts.set(pd.name, count); + } + } } else { const count = countImagesInDir(resolve(DATASET_DIR, entry.name)); if (count > 0) { @@ -695,7 +752,7 @@ function scanDataset(): ScanResult { } } - return { diseaseCounts, healthyCount }; + return { diseaseCounts, healthyCounts }; } // ─── CLI Flags ────────────────────────────────────────────────────────────── @@ -722,11 +779,14 @@ async function main() { // ── Step 1: Scan what we already have ──────────────────────────────────── console.log("\nScanning existing dataset..."); - const { diseaseCounts, healthyCount } = scanDataset(); - console.log(` Found ${diseaseCounts.size} disease directories, ${healthyCount} healthy images`); + const { diseaseCounts, healthyCounts } = scanDataset(); + const totalHealthyImages = [...healthyCounts.values()].reduce((s, c) => s + c, 0); + console.log( + ` Found ${diseaseCounts.size} disease directories, ${healthyCounts.size} plants with healthy images (${totalHealthyImages} total)`, + ); - // ── Step 2: Load disease info from DB ──────────────────────────────────── - console.log("\nLoading disease info from database..."); + // ── Step 2: Load disease info and plant info from DB ───────────────────── + console.log("\nLoading data from database..."); const db = getDb(); const allDiseases = await db @@ -746,6 +806,15 @@ async function main() { } console.log(` Loaded ${diseaseInfo.size} unique diseases from DB`); + // Load all plants from DB for healthy class generation + const allPlants = await db + .select({ + id: plants.id, + commonName: plants.commonName, + }) + .from(plants); + console.log(` Loaded ${allPlants.length} plants from DB`); + // ── Step 3: Build deficit list ────────────────────────────────────────── const deficits: DiseaseInfo[] = []; @@ -764,14 +833,24 @@ async function main() { // direction when the front of the queue keeps hitting dead URLs) if (flags.reverse) deficits.reverse(); - const healthyDeficit = TARGET_HEALTHY - healthyCount; + // Build per-plant healthy deficits + const healthyDeficits: Array<{ id: string; have: number; needed: number }> = []; + for (const plant of allPlants) { + const have = healthyCounts.get(plant.id) ?? 0; + const needed = TARGET_HEALTHY_PER_PLANT - have; + if (needed > 0) { + healthyDeficits.push({ id: plant.id, have, needed }); + } + } + const totalHealthyDeficit = healthyDeficits.reduce((s, d) => s + d.needed, 0); console.log(`\n${"=".repeat(60)}`); console.log("DEFICIT REPORT"); console.log(`${"=".repeat(60)}`); console.log(` Diseases needing images: ${deficits.length}/${diseaseInfo.size}`); console.log(` Total images missing: ${deficits.reduce((s, d) => s + d.needed, 0)}`); - console.log(` Healthy deficit: ${Math.max(0, healthyDeficit)}`); + console.log(` Plants needing healthy: ${healthyDeficits.length}/${allPlants.length}`); + console.log(` Total healthy images missing: ${totalHealthyDeficit}`); console.log(` Parallelism: ${DISEASE_CONCURRENCY} diseases at once`); console.log(` DDG rate limit: ${DDG_RATE_LIMIT_RPS} req/s (shared)`); console.log( @@ -779,7 +858,7 @@ async function main() { ); console.log(`${"=".repeat(60)}`); - if (deficits.length === 0 && healthyDeficit <= 0) { + if (deficits.length === 0 && healthyDeficits.length === 0) { console.log("\n ✓ Nothing to do — all targets met!\n"); await closeDb(); return; @@ -788,7 +867,6 @@ async function main() { // ── Step 4: Load seen-URLs cache ──────────────────────────────────────── const seenUrlsCache = loadSeenUrlsCache(); let totalDownloaded = 0; - let totalFailed = 0; let diseasesProcessed = 0; const startTime = Date.now(); @@ -875,64 +953,78 @@ async function main() { } } - // ── Step 6: Fill healthy deficit ──────────────────────────────────────── - if (healthyDeficit > 0) { + // ── Step 6: Fill healthy deficits per plant ──────────────────────────── + if (healthyDeficits.length > 0) { console.log("\n" + "─".repeat(60)); - console.log(`FILLING HEALTHY CLASS (target: ${TARGET_HEALTHY})`); + console.log( + `FILLING ${healthyDeficits.length} PLANTS\' HEALTHY CLASSES` + + ` (target: ${TARGET_HEALTHY_PER_PLANT} each)`, + ); console.log("─".repeat(60)); - const healthyDir = resolve(DATASET_DIR, HEALTHY_CLASS); - mkdirSync(healthyDir, { recursive: true }); + let healthyProcessed = 0; - // Collect all unique plants from the disease info - const allPlants = [...new Set(diseaseInfo.values())].map((d) => d.plantId); - const allHealthyQueries: string[] = []; - for (const plant of allPlants) { - allHealthyQueries.push(...buildHealthyQueries(plant)); - } + // Process in parallel batches + for (let i = 0; i < healthyDeficits.length; i += MAX_HEALTHY_CONCURRENCY) { + const batch = healthyDeficits.slice(i, i + MAX_HEALTHY_CONCURRENCY); + const batchNum = Math.floor(i / MAX_HEALTHY_CONCURRENCY) + 1; + const totalBatches = Math.ceil(healthyDeficits.length / MAX_HEALTHY_CONCURRENCY); - const healthySeen = new Set(seenUrlsCache[HEALTHY_CLASS] ?? []); - const healthyNeeded = TARGET_HEALTHY - countImagesInDir(healthyDir); + console.log(`\n[Batch ${batchNum}/${totalBatches}] Processing ${batch.length} plants...`); - // Run all 3 sources in parallel for the healthy class too - const [ddgUrls, inatUrls, commonsUrls] = await Promise.allSettled([ - collectImagesDuckDuckGo( - allHealthyQueries.slice(0, MAX_HEALTHY_QUERIES), - healthyNeeded, - healthySeen, - ), - searchImagesInaturalist(allHealthyQueries[0], healthyNeeded, healthySeen), - searchImagesCommons(allHealthyQueries[0], healthyNeeded, healthySeen), - ]); + const STAGGER_MS = 200; + const batchResults = await Promise.allSettled( + batch.map((p, idx) => + (async () => { + if (idx > 0) await sleep(idx * STAGGER_MS); - const allUrls: string[] = []; - for (const settled of [ddgUrls, inatUrls, commonsUrls]) { - if (settled.status === "fulfilled") { - allUrls.push(...settled.value.urls); - } - } + const classDir = resolve(DATASET_DIR, HEALTHY_CLASS, p.id); + const queries = buildHealthyQueries(p.id); + const seen = new Set(); - if (allUrls.length > 0) { - console.log(`\n Downloading ${allUrls.length} healthy images...`); - const startIdx = countImagesInDir(healthyDir); - const { downloaded, failed } = await downloadBatch(allUrls, healthyDir, startIdx); + console.log(` [healthy/${p.id}] have ${p.have}, need ${p.needed} more`); - const newTotal = countImagesInDir(healthyDir); - const gained = newTotal - healthyCount; - totalDownloaded += gained; - totalFailed += failed; + const gained = await fillClass(`healthy/${p.id}`, queries, p.needed, classDir, seen); - console.log( - ` ${downloaded > 0 ? "✓" : "✗"} Got ${downloaded} images.` + - ` Total healthy: ${newTotal}/${TARGET_HEALTHY} (${gained} new)`, + // Update seen-URLs cache for this plant's healthy images + const cacheKey = `${HEALTHY_CLASS}/${p.id}`; + const existing = seenUrlsCache[cacheKey] ?? []; + const merged = [...new Set([...existing, ...Array.from(seen)])]; + seenUrlsCache[cacheKey] = merged.slice(-500); + return gained; + })(), + ), ); - } else { - console.log(`\n ✗ No healthy images found`); - } - // Update seen-URLs cache - seenUrlsCache[HEALTHY_CLASS] = Array.from(healthySeen); - saveSeenUrlsCache(seenUrlsCache); + for (const result of batchResults) { + if (result.status === "fulfilled") { + totalDownloaded += result.value; + } else { + console.error(` ✗ Plant healthy fill failed: ${result.reason}`); + } + } + + healthyProcessed += batch.length; + + // Flush seen-URLs cache to disk periodically + if ( + healthyProcessed % SEEN_CACHE_FLUSH_INTERVAL < batch.length || + i + batch.length >= healthyDeficits.length + ) { + saveSeenUrlsCache(seenUrlsCache); + } + + const elapsed = Math.round((Date.now() - startTime) / 1000); + const rate = healthyProcessed / Math.max(1, elapsed); + const remaining = healthyDeficits.length - healthyProcessed; + const eta = remaining / Math.max(0.01, rate); + console.log( + ` [Batch ${batchNum}/${totalBatches}] checkpoint — ` + + `${totalDownloaded} downloaded (healthy), ` + + `${healthyProcessed}/${healthyDeficits.length} plants (${rate.toFixed(1)}/s, ` + + `ETA: ${Math.round(eta)}s)`, + ); + } } // ── Summary ────────────────────────────────────────────────────────────── @@ -946,16 +1038,21 @@ async function main() { const atTarget = [...finalScan.diseaseCounts.values()].filter( (c) => c >= TARGET_PER_DISEASE, ).length; + const healthyAtTarget = [...finalScan.healthyCounts.values()].filter( + (c) => c >= TARGET_HEALTHY_PER_PLANT, + ).length; + const totalHealthyFinal = [...finalScan.healthyCounts.values()].reduce((s, c) => s + c, 0); console.log("\n" + "=".repeat(60)); console.log(" ✅ FILL COMPLETE"); console.log("=".repeat(60)); - console.log(` Time: ${hrs}h ${mins % 60}m`); - console.log(` Diseases at target: ${atTarget}/${diseaseInfo.size}`); - console.log(` Total images: ${totalHave}`); - console.log(` Healthy images: ${finalScan.healthyCount}/${TARGET_HEALTHY}`); - console.log(` New downloads: ${totalDownloaded}`); - console.log(` Dataset dir: ${DATASET_DIR}/`); + console.log(` Time: ${hrs}h ${mins % 60}m`); + console.log(` Diseases at target: ${atTarget}/${diseaseInfo.size}`); + console.log(` Total disease images: ${totalHave}`); + console.log(` Plants at healthy target: ${healthyAtTarget}/${allPlants.length}`); + console.log(` Total healthy images: ${totalHealthyFinal}`); + console.log(` New downloads: ${totalDownloaded}`); + console.log(` Dataset dir: ${DATASET_DIR}/`); await closeDb(); console.log("=".repeat(60)); diff --git a/tasks/hierarchical-model-upgrade/01-dataset-reorganization.md b/tasks/hierarchical-model-upgrade/01-dataset-reorganization.md new file mode 100644 index 0000000..d2e3437 --- /dev/null +++ b/tasks/hierarchical-model-upgrade/01-dataset-reorganization.md @@ -0,0 +1,169 @@ +# Phase 1 — Dataset Reorganization + +**Blocked by**: Nothing +**Blocks**: Phase 2 (training) +**Est. time**: 2-3 days +**Machine**: Strix Halo (fast I/O, 128GB RAM for in-memory processing) + +## Objective + +Reorganize the flat directory structure (`data/dataset/plant-disease-name/`) into a proper hierarchical layout (`data/organized/species/disease/`) with train/val splits and metadata files. + +## Current State + +- `data/dataset/` — 11,499 flat directories, each named `{plant}-{disease}` +- Files mixed: `.jpg`, `.jpeg`, `.png`, `.webp` per directory +- Total: ~1.47M images, 64-244 images per class (well-balanced) +- **Total size: ~450 GB** +- **SSD available**: 8TB NVMe (7,300 MB/s read, 6,300 MB/s write — PCIe 5.0) + +## Deliverables + +``` +data/ +├── organized/ +│ ├── train/ # 85% of images +│ │ ├── {species_1}/ +│ │ │ ├── healthy/ +│ │ │ ├── {disease_a}/ +│ │ │ └── {disease_b}/ +│ │ ├── {species_2}/ +│ │ └── ... +│ ├── val/ # 15% of images +│ │ └── ... (mirrors train structure) +│ ├── species_index.json # Maps species → [disease IDs] +│ ├── class_hierarchy.json # Full mapping + metadata +│ └── dataset_stats.json # Counts per class, splits +``` + +## Steps + +### 1.1 Parse directory names → (species, disease) pairs + +**Problem**: Directory names like `acorn-squash-powdery-mildew` use an inconsistent separator (hyphen). Need to reliably split plant name from disease name. + +**Approach**: Use `src/data/diseases.json` as ground truth. Try matching each directory name against known disease ID suffixes (sorted longest-first). The remainder is the plant name. + +**Fallback**: For unmatched dirs, build a plant suffix list from `src/data/plants.json` and try prefix matching. Log any truly unmatched dirs for manual review. + +**Script**: `scripts/organize-dataset.py` + +```python +# Pseudocode for the matching algorithm: +disease_ids = sorted([d["id"] for d in diseases], key=len, reverse=True) +plant_names = [p["id"] for p in plants] # or extract from dir prefixes + +for dir_name in dataset_dirs: + matched_disease = next(d for d in disease_ids if dir_name.endswith(d)) + plant = dir_name[:-(len(matched_disease)+1)] # +1 for hyphen + hierarchy[plant].append(matched_disease) +``` + +### 1.2 Split into train/val (85/15) + +Use stratified splitting per class to preserve class distribution. + +- For each disease-plant class, randomly assign 85% to train, 15% to val +- Copy files (or symlink) to new directory structure +- Verify no data leakage (same image in both splits) + +### 1.3 Build metadata files + +```json +// species_index.json +{ + "tomato": ["healthy", "early-blight", "late-blight", "bacterial-spot", ...], + "acorn-squash": ["healthy", "powdery-mildew", "downy-mildew", ...], + ... +} + +// dataset_stats.json +{ + "total_images": 1465818, + "total_species": 320, + "total_classes": 11499, + "images_per_class": { "min": 64, "max": 244, "mean": 127 }, + "train_images": 1245945, + "val_images": 219873, + "species_disease_counts": { + "tomato": { "early-blight": 156, "late-blight": 142, ... } + } +} +``` + +### 1.4 Data quality checks + +### 1.4 Image normalization & compression (before splitting) + +**450GB is unnecessarily large for 224px training.** Many source images are high-resolution (e.g., 4000×3000 from phone cameras), but the model only sees 224×224 crops. Resizing to a reasonable max dimension BEFORE training saves massive I/O and enables faster epochs. + +**Strategy**: Resize all images to **max dimension of 512px** (preserving aspect ratio), convert to **JPEG quality 90**. + +| Approach | Est. Size | Pros | Cons | +| ------------------------------- | ------------- | -------------------------------------------------- | ---------------------------------- | +| **Keep originals** | 450 GB | No quality loss | Slow loading, huge storage | +| **Resize 1024px max, JPEG 90** | ~120 GB | Good for future higher-res models | Still somewhat large | +| **Resize 512px max, JPEG 90 ✓** | **~60-80 GB** | **Fast loading, enough detail for 224px training** | Can't go back to full res | +| **Resize 256px max, JPEG 95** | ~30 GB | Fastest loading | Too small if retrain at higher res | + +**Recommendation**: Resize to 512px max, JPEG q90. This: + +- Reduces storage from 450GB → ~70GB (fits in RAM for caching) +- Preserves enough detail for 224×224 RandomResizedCrop augmentation +- JPEG is hardware-accelerated (libjpeg-turbo) — fastest decode path +- Single format (no more .png/.webp mixed loading) + +```python +# resize_and_convert.py +from PIL import Image +import os +from joblib import Parallel, delayed + +MAX_SIZE = 512 +QUALITY = 90 + +def process_image(src_path, dst_path): + img = Image.open(src_path) + # Resize so max dimension = MAX_SIZE, preserving aspect ratio + w, h = img.size + if max(w, h) > MAX_SIZE: + ratio = MAX_SIZE / max(w, h) + img = img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS) + # Convert to RGB (handles RGBA PNGs) + if img.mode != 'RGB': + img = img.convert('RGB') + # Save as JPEG + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + img.save(dst_path, 'JPEG', quality=QUALITY, optimize=True) + +# Run in parallel (Strix Halo has many cores) +Parallel(n_jobs=16)( + delayed(process_image)(src, dst) + for src, dst in image_pairs +) +``` + +**Time estimate on Strix Halo**: ~2-3 hours to resize + convert 1.47M images with 16 parallel workers. Each image takes ~5-10ms with PIL+LANCZOS. + +### 1.5 Data quality checks + +- **Label noise**: Run confidence learning (CleanLab) on a sample to estimate mislabel rate. Web-scraped datasets typically have 8-15% label noise. +- **Duplicate detection**: Check for near-duplicate images (perceptual hashing + Hamming distance) within each class. +- **Format consistency**: Ensure all images decode successfully; remove corrupted files. +- **Background bias**: Verify that no single background dominates a class (subset and eyeball a random grid per class). + +## Edge Cases & Gotchas + +- **Multi-word plant names**: "acorn-squash", "fiddle-leaf-fig", "chili-pepper" — the disease suffix must match the end of the string, not a substring in the plant name. Sorting disease IDs by length (longest first) handles this. +- **Disease-less "healthy" dirs**: Need to ensure "healthy" is in the disease list as a valid class (index 0 in current model). Some dirs may be `{plant}-healthy`. +- **Cross-platform path length**: Some species+disease combos produce long paths. Use relative symlinks or shorten names if needed on Windows. +- **Original files preserved**: The existing `data/dataset/` structure stays untouched; `data/organized/` is a copy. + +## Verification + +- [ ] `data/organized/train/` has same total image count as original (minus val split) +- [ ] Every class has at least 50 training images +- [ ] `species_index.json` covers all 11,499 classes +- [ ] No files in both train/ and val/ (no overlap) +- [ ] All images readable (no corrupted files) +- [ ] Train/val split ratios consistent across all classes (±2%) diff --git a/tasks/hierarchical-model-upgrade/02-hierarchical-training.md b/tasks/hierarchical-model-upgrade/02-hierarchical-training.md new file mode 100644 index 0000000..e6c22ef --- /dev/null +++ b/tasks/hierarchical-model-upgrade/02-hierarchical-training.md @@ -0,0 +1,309 @@ +# Phase 2 — Hierarchical Model Training + +**Blocked by**: Phase 1 (dataset reorganization) +**Blocks**: Phase 3 (export) +**Est. time**: 3-5 days on Strix Halo (ROCm), or 4-6 days on RTX 3090 (CUDA) +**Machine**: Strix Halo preferred (128GB unified memory + 8TB NVMe at 7,300 MB/s read — SSD is fast enough to stream entire dataset in ~62s) + +## Objective + +Train a hierarchical Swin-Tiny model with two classification heads: + +1. **Species head** (~320 classes) — identifies the plant +2. **Disease heads** (one per species, 30-300 classes each) — identifies the disease + +## Architecture + +``` +Input Image (224×224×3) + │ + ▼ +┌──────────────────────┐ +│ Swin-Tiny Backbone │ ← pretrained on ImageNet-21K +│ (timm library) │ optional: fine-tune on iNaturalist +│ output: 768-dim │ +└──────────┬───────────┘ + │ + ┌──────┴──────┐ + ▼ ▼ +┌─────────┐ ┌───────────────┐ +│ Species │ │ Disease Head │ +│ Head │ │ (routed by │ +│ 320 cls │ │ species ID) │ +└────┬────┘ └───────┬───────┘ + │ │ + ▼ ▼ + Species ID Disease ID +``` + +## Environment Setup + +```bash +# On Strix Halo (ROCm) +python3 -m venv .hierarchical-venv +source .hierarchical-venv/bin/activate + +# ROCm PyTorch (install from https://pytorch.org/get-started/locally/) +# ROCm 6.x + PyTorch 2.5+ +pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2 + +# Training libs +pip install pytorch-lightning timm transformers wandb +pip install albumentations opencv-python pillow + +# Data loading (for SSD-optimized streaming) +pip install webdataset fsspec # optional benchmarks +``` + +**Alternative (RTX 3090 CUDA path)**: + +```bash +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +``` + +## Training Protocol + +### Stage A: Species Classifier (2 days) + +| Step | Epochs | LR | Batch Size | Details | +| -------------- | ------ | ----------- | ------------------------ | ----------------------------------------------- | +| Head warmup | 5 | 1e-3 | 512 (Strix) / 256 (3090) | Backbone frozen, train only species head | +| Full fine-tune | 20 | 1e-4 → 1e-6 | 512 / 256 | Unfreeze backbone, cosine LR schedule | +| Stage final | 5 | 5e-6 | 512 / 256 | Discriminative LR: backbone layers 0.1× head LR | + +**Loss**: Focal Loss (γ=2.0, α=0.25) — handles any class imbalance in species distribution. + +**Augmentation — Image Jittering, Degradation & Robustness**: + +Real-world plant photos vary dramatically: different cameras, lighting conditions, angles, weather, focus quality, and compression artifacts. **Augmentation is not optional — it's essential for generalization.** The more varied your augmentation, the more robust your model will be when deployed. + +**Three tiers of augmentation**, all applied on-the-fly (never pre-generated): + +#### Tier 1 — Core Geometric & Photometric (applied to every image) + +These simulate the most common real-world variations: + +| Augmentation | Parameter | Simulates | +| ------------------------ | ----------------------------------------------------- | -------------------------------------------------- | +| RandomResizedCrop | scale=(0.6, 1.0), ratio=(0.75, 1.33) | Different shooting distances, zoom levels, framing | +| HorizontalFlip | p=0.5 | Different leaf orientations (left/right symmetry) | +| Rotate | limit=45°, p=0.5 | Off-angle photos, tilted camera | +| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 | Different lighting — sunny, overcast, shade, dusk | +| RandomBrightnessContrast | brightness_limit=0.2, contrast_limit=0.2, p=0.5 | Exposure variations from auto-exposure cameras | + +#### Tier 2 — Degradation & Quality Simulation (applied to ~30% of images) + +These make the model robust to poor-quality inputs that real users will upload: + +| Augmentation | Parameter | Simulates | +| ---------------- | ----------------------------------------- | --------------------------------------------------------- | +| GaussianBlur | blur_limit=(3, 7), p=0.2 | Out-of-focus photos, motion blur | +| GaussianNoise | var_limit=(10, 50), p=0.15 | Low-light sensor noise, phone camera noise | +| ImageCompression | quality_lower=60, quality_upper=95, p=0.2 | JPEG artifacts from compression, social media re-encoding | +| RandomGrayscale | p=0.05 | Monochrome cameras, infrared plant imaging | +| RandomShadow | shadow_roi=(0, 1, 0, 1), p=0.15 | Leaves in shadow of other leaves/structures outdoors | + +#### Tier 3 — Advanced Regularization (applied at batch level) + +These are cutting-edge techniques that significantly improve generalization on fine-grained classification: + +| Technique | Parameter | Effect | +| ------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| **MixUp** | α=0.2 | Blends two random images + labels linearly. Forces the model to learn smooth decision boundaries. Proven +4.8% improvement on rare plant diseases. | +| **CutMix** | α=1.0 | Replaces a random patch of one image with another. Forces the model to focus on local lesion features rather than overall leaf shape. | +| **RandAugment** | N=2, M=9 | Auto-selected augmentation policy. 14 operations randomly chosen (shear, translate, rotate, contrast, etc.). N=2 ops per image, magnitude 9 (on 0-10 scale). | +| **Label Smoothing** | ε=0.1 | Prevents overconfidence on training classes, improves calibration on unseen diseases. | + +**Implementation (albumentations)**: + +```python +import albumentations as A +from albumentations.pytorch import ToTensorV2 +import kornia.augmentation as K # for GPU-based MixUp/CutMix + +# Core spatial + photometric (Tier 1+2) +train_transform = A.Compose([ + A.RandomResizedCrop(224, 224, scale=(0.6, 1.0), ratio=(0.75, 1.33)), + A.HorizontalFlip(p=0.5), + A.Rotate(limit=45, p=0.5), + A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8), + # Degradation (Tier 2) — only some images get these + A.OneOf([ + A.GaussianBlur(blur_limit=(3, 7), p=1.0), + A.GaussianNoise(var_limit=(10, 50), p=1.0), + A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0), + ], p=0.3), + A.ImageCompression(quality_lower=60, quality_upper=95, p=0.2), + A.RandomShadow(shadow_roi=(0, 0.5, 0.5, 1), num_shadows_lower=1, num_shadows_upper=2, p=0.15), + A.RandomGrayscale(p=0.05), + # Normalize + A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ToTensorV2(), +]) + +# Validation — minimal, deterministic +val_transform = A.Compose([ + A.Resize(224, 224), + A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ToTensorV2(), +]) +``` + +**MixUp/CutMix (implemented in training loop, applied on GPU)**: + +```python +def mixup_criterion(criterion, pred, y_a, y_b, lam): + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + +for images, labels in dataloader: + images, labels = images.to(device), labels.to(device) + + if use_mixup and np.random.random() < 0.5: + # MixUp: blend images and labels + lam = np.random.beta(0.2, 0.2) + indices = torch.randperm(images.size(0)).to(device) + mixed_images = lam * images + (1 - lam) * images[indices] + + logits = model(mixed_images) + loss = mixup_criterion(criterion, logits, labels, labels[indices], lam) + else: + logits = model(images) + loss = criterion(logits, labels) +``` + +## SSD Data Loading Strategy + +**Important**: The full dataset is ~70GB (after Phase 1 resizing), which exceeds the 128GB RAM when accounting for OS, GPU memory, workspace, and model weights. However, **your 8TB NVMe at 7,300 MB/s read changes everything.** + +| Metric | Value | +| ----------------------------------------- | --------------- | +| Dataset size (after resize) | ~70 GB | +| NVMe read speed | 7,300 MB/s | +| Sequential read time (full dataset) | **~10 seconds** | +| Random read (1000 random files, 4KB each) | ~0.5ms seek | + +**Key insight**: The GPU consumes batches slower than the SSD can deliver them. With `num_workers=8`, each worker reads ~35 images/s from random positions. At 7,300 MB/s sequential, the SSD can serve 150,000+ images/s. The bottleneck is **JPEG decode + augmentation**, not disk I/O. + +**Recommended DataLoader configuration**: + +```python +dataloader = DataLoader( + dataset, + batch_size=256, + shuffle=True, + # SSD-optimized settings: + num_workers=8, # 8 parallel readers — enough to saturate GPU + prefetch_factor=4, # Each worker prefetches 4 batches ahead + pin_memory=True, # Faster CPU→GPU transfer via DMA + persistent_workers=True, # Keep workers alive between epochs (avoid fork overhead) + drop_last=True, # Drop incomplete final batch for consistent batch norm +) +``` + +**Why not load everything into RAM?** + +- 128GB total memory — after OS (~8GB), GPU reserved (~4-8GB ROCm), model weights + optimizer states (~4GB), augmentation workspace (~2GB), you have ~100GB free +- 70GB dataset would barely fit, but leaves no room for caching augmentation results or handling spikes +- Better approach: let the NVMe + DataLoader pipeline stream data. At 7,300 MB/s, reading a batch of 256 images (~50MB) takes ~7ms. Meanwhile, the GPU takes ~200ms to process that batch. **The disk is 30× faster than the GPU — you will never be I/O bound.** + +**Optional: Use WebDataset for maximum throughput** + +WebDataset shards the dataset into large tar files (~1GB each), which are sequentially read. This eliminates random seek overhead entirely — ideal when running at massive scale. For your setup it's optional (raw files on NVMe are already fast enough), but worth considering if you scale to multi-GPU: + +```bash +pip install webdataset +``` + +```python +import webdataset as wds + +urls = "data/organized/train/shard-{000000..000099}.tar" +dataset = wds.WebDataset(urls).shuffle(10000).decode("pil").to_tuple("jpg", "cls").map(augment) +``` + +**Profiling check**: During training, monitor GPU utilization: + +- `nvidia-smi` / `rocm-smi` — GPU-Util should be >90% +- If <70%, GPU is waiting for data → increase `num_workers` or `prefetch_factor` +- If >95%, data pipeline is keeping up → optimal + +### Stage B: Disease Classifiers (2-3 days) + +| Step | Epochs | LR | Details | +| ----------------- | ------ | ---- | ------------------------------------------------------- | +| All disease heads | 15 | 1e-3 | Backbone frozen, train all disease heads simultaneously | +| Rare-class boost | 5 | 5e-4 | Oversample classes with <80 images | +| End-to-end | 10 | 1e-5 | Unfreeze backbone, joint species + disease loss | + +**Key design**: Disease heads are simple linear layers (768 → num_diseases_for_species). Since they share the backbone, inference is efficient — one forward pass through Swin-Tiny, then route the 768-dim feature vector to the correct head. + +**Class balancing**: Use weighted sampler for disease heads — classes with <80 images get 3× sampling weight, classes with <50 images get 5×. + +**Loss weighting**: `L_total = L_species + 0.7 * L_disease` — species loss has higher weight since disease prediction depends on correct species ID. + +## Model Checkpointing + +``` +checkpoints/ +├── species_only/ # Stage A checkpoints +│ ├── epoch=05-val_loss=0.42.ckpt +│ ├── epoch=15-val_loss=0.18.ckpt +│ └── epoch=25-best.ckpt +├── disease_heads/ # Stage B initial disease heads +│ ├── disease_heads_epoch=15.pt +│ └── disease_heads_final.pt +└── hierarchical_full/ # End-to-end + ├── epoch=05.ckpt + └── epoch=10-best.ckpt +``` + +Save every checkpoint with species accuracy, macro F1, and per-disease F1 for the tail classes. + +## Expected Training Time (RTX 3090 baseline) + +| Stage | Epochs | Time/Epoch | Total | +| ----------------------- | ------ | ---------- | ---------- | +| Species head warmup | 5 | ~18 min | 1.5 hr | +| Species full fine-tune | 20 | ~45 min | 15 hr | +| Species fine-tune final | 5 | ~45 min | 3.75 hr | +| Disease heads | 15 | ~30 min | 7.5 hr | +| Disease rare-class | 5 | ~35 min | 3 hr | +| End-to-end | 10 | ~50 min | 8.5 hr | +| **Total** | **60** | | **~39 hr** | + +On **Strix Halo with NVMe + ROCm** the time/epoch should be **significantly faster** due to: + +- 7,300 MB/s NVMe: data loads faster than GPU can consume it (zero I/O wait) +- Larger batch sizes (512 vs 256): fewer iterations per epoch +- ROCm 6.x has strong PyTorch performance on AMD GPUs +- 128GB RAM allows large prefetch buffer for seamless streaming + +Expect **~20-28 hours** total on Strix Halo. + +## Evaluation Metrics + +| Metric | Target | Measurement | +| -------------------------------- | ------ | --------------------------------------- | ---------------- | +| Species top-1 accuracy | ≥95% | Fraction of correct species predictions | +| Disease top-1 accuracy | ≥88% | Across all species-conditioned heads | +| Disease top-3 accuracy | ≥94% | Model correct if disease is in top 3 | +| Macro F1 (rare diseases) | ≥80% | Weighted average across tail classes | +| Species→Disease cascade accuracy | ≥90% | P(correct species) × P(correct disease | correct species) | + +## Edge Cases & Gotchas + +- **GPU memory on RTX 3090 (24GB)**: Swin-Tiny at 224px with batch size 256 + mixed precision should fit. If not, reduce to 128 or use gradient accumulation (accumulate 2 steps). +- **Strix Halo ROCm quirks**: `torch.compile()` may have issues on ROCm 6.2 — test without it first. Some `timm` model ops may need ROCm kernel fallbacks; test the forward pass before starting training. +- **Checkpoint compatibility**: Save in pure PyTorch format (`.pt`), not Lightning-specific, so they're loadable outside Lightning for export. +- **Disease head memory**: 320 separate linear layers sounds large, but each is 768×N_diseases (avg ~300 → 230K params). Total disease head params: ~70M (vs 28M for backbone). This is fine — compute is dominated by the backbone. +- **Loss divergence on rare diseases**: Monitor individual disease loss curves; if a tail class diverges, reduce its learning rate or use gradient clipping (max_norm=1.0). + +## Verification + +- [ ] Species classifier ≥95% top-1 on val set +- [ ] Disease top-3 accuracy ≥94% on val set +- [ ] Confusion matrix shows no systematic species misclassifications +- [ ] Per-species disease classifiers all converge (no NaN losses) +- [ ] Tail classes (≤80 images) have F1 ≥70% +- [ ] Model can be loaded from checkpoint and run inference in PyTorch +- [ ] No OOM errors during training diff --git a/tasks/hierarchical-model-upgrade/03-export-quantization.md b/tasks/hierarchical-model-upgrade/03-export-quantization.md new file mode 100644 index 0000000..2ba8fd3 --- /dev/null +++ b/tasks/hierarchical-model-upgrade/03-export-quantization.md @@ -0,0 +1,165 @@ +# Phase 3 — ONNX Export & Quantization + +**Blocked by**: Phase 2 (trained model) +**Blocks**: Phase 4 (server inference) +**Est. time**: 1-2 days +**Machine**: Any (RTX 3090 recommended for ONNX GPU validation) + +## Objective + +Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment. + +## Deliverables + +``` +public/models/ +├── swin-species.onnx # FP16 species model (3.2 MB) +├── swin-species-int8.onnx # INT8 quantized species model (1.1 MB) +├── disease-heads/ # One ONNX per species +│ ├── tomato-int8.onnx +│ ├── acorn-squash-int8.onnx +│ └── ... +├── disease-heads-list.json # Maps species ID → ONNX file path +├── ood-detector.pkl # Mahalanobis parameters for OOD +└── onnx-metadata.json # Input/output shapes, versions +``` + +## Steps + +### 3.1 Export backbone + species head as single ONNX + +```python +import torch +import onnx +from pathlib import Path + +model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt") +model.eval() + +# Export end-to-end species model (backbone + species head) +dummy = torch.randn(1, 3, 224, 224) +torch.onnx.export( + model, # Combined forward: backbone + species_head + dummy, + "public/models/swin-species.onnx", + input_names=["input"], + output_names=["species_logits", "embedding"], + dynamic_axes={ + "input": {0: "batch_size"}, + "species_logits": {0: "batch_size"}, + "embedding": {0: "batch_size"}, + }, + opset_version=17, + do_constant_folding=True, +) +``` + +**Key**: Export the 768-dim `embedding` as a second output — the server needs it to route to the correct disease head. + +### 3.2 Export disease heads individually + +Each disease head is a simple `nn.Linear(768, N_diseases)`. Export as a mini-ONNX that takes the embedding and returns disease logits: + +```python +for species_name, head in model.disease_heads.items(): + dummy_embed = torch.randn(1, 768) + torch.onnx.export( + head, dummy_embed, + f"public/models/disease-heads/{species_name}.onnx", + input_names=["embedding"], + output_names=["disease_logits"], + dynamic_axes={"embedding": {0: "batch_size"}}, + opset_version=17, + ) +``` + +**Total**: ~320 small ONNX files, each ~50-200 KB. + +### 3.3 INT8 Quantization + +Use ONNX Runtime's quantization tooling: + +```python +from onnxruntime.quantization import quantize_dynamic, QuantType + +# Quantize species model +quantize_dynamic( + "public/models/swin-species.onnx", + "public/models/swin-species-int8.onnx", + weight_type=QuantType.QInt8, +) + +# Quantize disease heads (batch) +for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")): + quantize_dynamic( + str(onnx_path), + str(onnx_path.with_suffix("-int8.onnx")), + weight_type=QuantType.QInt8, + ) +``` + +**Accuracy impact**: INT8 quantization typically causes <1% accuracy drop when using dynamic quantization on the linear/embedding layers. The Swin-Tiny attention layers are less affected than CNN layers. + +### 3.4 OOD Detector + +Train a Mahalanobis distance-based OOD detector on the training set embeddings: + +```python +import numpy as np +from scipy.spatial.distance import mahalanobis + +# Collect embeddings from training set +embeddings = [] +for batch in val_dataloader: + with torch.no_grad(): + _, emb = model(batch["image"]) + embeddings.append(emb.numpy()) +embeddings = np.vstack(embeddings) + +# Fit multivariate Gaussian +mean = np.mean(embeddings, axis=0) +cov = np.cov(embeddings, rowvar=False) +inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0])) + +# Save for inference +import pickle +with open("public/models/ood-detector.pkl", "wb") as f: + pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": 95.0}, f) +``` + +The threshold (95th percentile of training set Mahalanobis distances) rejects non-plant images. If a test image has a distance > threshold, reject it as OOD. + +### 3.5 Accuracy verification + +Before committing to ONNX, verify against PyTorch: + +```python +import onnxruntime as ort + +# Compare PyTorch vs ONNX outputs +pytorch_out = model(sample_image) +ort_out = ort.InferenceSession("swin-species-int8.onnx").run( + ["species_logits"], {"input": sample_image.numpy()} +) + +max_diff = np.max(np.abs(pytorch_out.numpy() - ort_out[0])) +assert max_diff < 0.01, f"ONNX mismatch: {max_diff}" +``` + +## Edge Cases & Gotchas + +- **ONNX opset compatibility**: Some `timm` model ops (like `roll` in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19. +- **Dynamic axes**: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions. +- **Disease head routing**: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly. +- **Strix Halo ROCm + ONNX**: ONNX Runtime supports ROCm via DirectML or MIGraphX backends. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both. +- **Disease head file count**: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference). + +## Verification + +- [ ] ONNX species model output matches PyTorch output (max diff < 0.01) +- [ ] INT8 accuracy within 1% of FP16 on val set (sample 10K images) +- [ ] ONNX model loads in ONNX Runtime without errors +- [ ] All 320+ disease heads export successfully +- [ ] OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision +- [ ] ONNX model size < 5MB (INT8) for species, < 200KB per disease head +- [ ] Inference on CPU (Strix Halo) < 200ms for species + disease combined diff --git a/tasks/hierarchical-model-upgrade/04-server-inference.md b/tasks/hierarchical-model-upgrade/04-server-inference.md new file mode 100644 index 0000000..5e6f34f --- /dev/null +++ b/tasks/hierarchical-model-upgrade/04-server-inference.md @@ -0,0 +1,211 @@ +# Phase 4 — Server Inference Pipeline + +**Blocked by**: Phase 3 (exported ONNX models) +**Blocks**: Phase 5 (hybrid integration) +**Est. time**: 2-3 days +**Machine**: Strix Halo (will serve inference in production) + +## Objective + +Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results. + +## Architecture + +``` +POST /api/identify + │ + ▼ +┌──────────────────────┐ +│ 1. Preprocess │ Load image → resize to 224×224 → NCHW tensor +│ (sharp + buffer) │ Normalize with ImageNet stats +└──────────┬───────────┘ + ▼ +┌──────────────────────┐ +│ 2. OOD Detection │ Extract embedding from Swin-Tiny (if species model loaded) +│ (Mahalanobis) │ Compute Mahalanobis distance → reject if > threshold +└──────────┬───────────┘ + ▼ +┌──────────────────────┐ +│ 3. Species Inference │ Run swin-species-int8.onnx +│ (ONNX Runtime) │ softmax over 320 species logits +│ │ Return top-1 species + embedding vector +└──────────┬───────────┘ + ▼ +┌──────────────────────┐ +│ 4. Disease Routing │ Look up disease head ONNX for predicted species +│ (species→head) │ Feed embedding through disease head +│ │ softmax over species-conditional disease logits +└──────────┬───────────┘ + ▼ +┌──────────────────────┐ +│ 5. Enrichment │ Map class indices → disease/plant objects from +│ (knowledge base) │ src/data/diseases.json and src/data/plants.json +│ │ Return top-K with treatment info +└──────────┬───────────┘ + ▼ + JSON Response +``` + +## File Structure + +``` +src/ +├── lib/ +│ ├── server/ +│ │ ├── inference-server.ts ← Main orchestration +│ │ ├── onnx-loader.ts ← ONNX runtime session manager +│ │ ├── ood-detector.ts ← Mahalanobis OOD detection +│ │ ├── species-classifier.ts ← Species ONNX inference +│ │ ├── disease-classifier.ts ← Disease head routing + inference +│ │ └── image-preprocessor.ts ← Sharp-based preprocessing +│ └── ml/ +│ ├── inference.ts ← Existing browser inference (kept as-is) +│ └── ... +└── app/ + ├── api/ + │ ├── identify/route.ts ← Existing endpoint (keep for backward compat) + │ └── identify-v2/route.ts ← New server-side endpoint + └── ... +``` + +## Key Implementation Details + +### 4.1 ONNX Session Manager + +```typescript +// src/lib/server/onnx-loader.ts +import ort from "onnxruntime-node"; + +const sessions = new Map(); + +export async function getOrCreateSession(path: string): Promise { + if (!sessions.has(path)) { + sessions.set( + path, + await ort.InferenceSession.create(path, { + executionProviders: ["cpu"], // or ['rocm', 'cpu'] on Strix Halo + graphOptimizationLevel: "all", + }), + ); + } + return sessions.get(path)!; +} +``` + +**Execution provider**: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If ROCm-specific providers (MIGraphX, DirectML) are available on Strix Halo, test GPU execution for the species model (the compute-heavy part). + +### 4.2 Lazy Loading Strategy + +Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them: + +```typescript +const diseaseHeadCache = new Map(); + +async function getDiseaseHead(speciesName: string) { + if (!diseaseHeadCache.has(speciesName)) { + const path = `public/models/disease-heads/${speciesName}-int8.onnx`; + diseaseHeadCache.set(speciesName, await createSession(path)); + } + return diseaseHeadCache.get(speciesName)!; +} +``` + +### 4.3 Main Inference Pipeline + +```typescript +// src/lib/server/inference-server.ts +export async function serverIdentify(imageBuffer: Buffer): Promise { + const start = performance.now(); + + // 1. Preprocess + const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224] + + // 2. OOD detection (quick, using embedding from species model) + const oodResult = await oodDetect(tensor); + if (!oodResult.isPlant) { + return { + error: "No plant detected", + confidence: 1 - oodResult.mahalanobisDistance / oodResult.threshold, + inferenceTimeMs: Math.round(performance.now() - start), + }; + } + + // 3. Species inference + const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx"); + const speciesOutput = await speciesSession.run({ + input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]), + }); + const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array); + const speciesProbs = softmax(speciesLogits); + const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0]; + const embedding = speciesOutput.embedding.data as Float32Array; + + // 4. Disease inference (routed by species) + const speciesName = speciesIndex[topSpeciesIdx]; + const diseaseSession = await getDiseaseHead(speciesName); + const diseaseOutput = await diseaseSession.run({ + embedding: new ort.Tensor("float32", embedding, [1, 768]), + }); + const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array); + const diseaseProbs = softmax(diseaseLogits); + const topDiseases = topK(diseaseProbs, 5); + + // 5. Enrichment + const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases); + + return { + species: { id: speciesName, confidence: topSpeciesProb }, + diseases: enriched, + oodScore: oodResult.mahalanobisDistance, + inferenceTimeMs: Math.round(performance.now() - start), + }; +} +``` + +### 4.4 API Route + +```typescript +// src/app/api/identify-v2/route.ts +export async function POST(req: Request) { + const formData = await req.formData(); + const image = formData.get("image") as File; + + if (!image || !image.type.startsWith("image/")) { + return Response.json({ error: "Invalid image" }, { status: 400 }); + } + + const buffer = Buffer.from(await image.arrayBuffer()); + const result = await serverIdentify(buffer); + + return Response.json(result); +} +``` + +### 4.5 Caching Strategy + +- **Model sessions**: Cache ONNX sessions in memory (warm on first request per deployment) +- **Disease heads**: Cache top-50 most common species' disease heads (LRU eviction) +- **Image preprocessing results**: Do NOT cache — each image is unique +- **Response caching**: Optionally cache identical responses for 5 minutes (hash of image buffer, for repeated uploads of same image) + +## Edge Cases & Gotchas + +- **Cold start latency**: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot. +- **Disease head not found**: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return species-only result. +- **Large images**: Client may upload 12MP photos. Resize to 224×224 _before_ feeding to ONNX (sharp is fast for this). Set a 10MB upload limit. +- **Concurrent requests**: ONNX Runtime sessions are thread-safe. Use a connection pool or queue for the species model (1 session handles concurrent `run()` calls). +- **Memory**: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. Species model is ~1.1MB (INT8). +- **Error handling**: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode. + +## Verification + +- [ ] `POST /api/identify-v2` returns valid JSON with species + disease predictions +- [ ] Cold start (first ever request): < 3 seconds (model loading) +- [ ] Warm requests: < 200ms total (OOD + species + disease + enrichment) +- [ ] OOD detection correctly rejects non-plant images (rocks, buildings, animals) +- [ ] OOD detection correctly accepts plant images (false rejection rate < 1%) +- [ ] All 320+ species → disease head routes resolve correctly +- [ ] Large image (12MP) → preprocessed to 224×224 without OOM +- [ ] Concurrent 10 requests handled without errors or slowdown +- [ ] Degraded mode works if ONNX model fails (falls back to existing TF.js) +- [ ] Health endpoint reports model status, last inference time, error count diff --git a/tasks/hierarchical-model-upgrade/05-browser-hybrid.md b/tasks/hierarchical-model-upgrade/05-browser-hybrid.md new file mode 100644 index 0000000..8ee45eb --- /dev/null +++ b/tasks/hierarchical-model-upgrade/05-browser-hybrid.md @@ -0,0 +1,208 @@ +# Phase 5 — Browser Model & Hybrid Integration + +**Blocked by**: Phase 4 (server inference pipeline) +**Est. time**: 2-3 days +**Machine**: Any (development on Strix Halo or M3 Pro) + +## Objective + +Train a lightweight browser-compatible model (TF.js) and implement the hybrid routing logic: fast first pass in-browser, server fallback when confidence is low. + +## Hybrid Flow + +``` +User uploads image + │ + ▼ +┌──────────────────────┐ +│ Browser: │ +│ EfficientNet-Lite │ ← ~5MB TF.js model in browser +│ (TF.js) │ Predicts species + top-5 diseases +│ │ +│ Species confidence? │ +│ ┌────┴────┐ │ +│ │ ≥90% │ <90% │ +│ └────┬────┘ │ +│ │ │ +│ Show result │ +│ (instant) │ │ +└────────────┼────────┘ + │ (background if >90%, + │ foreground if <90%) + ▼ +┌──────────────────────┐ +│ Server: │ +│ Full Swin-Tiny │ ← Only when browser is uncertain +│ (ONNX Runtime) │ or user requests "detailed analysis" +│ │ +│ Returns enriched │ +│ results with full │ +│ treatment info │ +└──────────────────────┘ +``` + +## Steps + +### 5.1 Train lightweight browser model + +Use the hierarchical training data to train a **EfficientNet-Lite0** model that outputs both species and disease predictions: + +```python +import timm +import tensorflow as tf # For TF.js export + +# Train in PyTorch first (for accuracy), then convert +model = timm.create_model('efficientnet_lite0', pretrained=True) +# Add: species head (320) + disease head (11,499 flat) +# Or use hierarchical with just top-50 diseases per species + +# Training: 10 epochs frozen backbone, 10 epochs fine-tune +# Target: <5MB model size, runs in <100ms on mobile device +``` + +**Export to TF.js**: + +```bash +# Convert PyTorch → ONNX → TF.js +python -m tf2onnx.convert --pytorch-model browser_model.pt --output browser_model.onnx +tensorflowjs_converter --input_format=tf_saved_model browser_model/ browser_tfjs/ +``` + +**Model size target**: < 5MB (EfficientNet-Lite0 is ~4.7MB with INT8 quantization). + +### 5.2 Browser inference integration + +```typescript +// src/lib/ml/inference.ts — Updated with hybrid routing + +export type InferenceSource = "browser" | "server"; +export type InferenceMode = "quick" | "detailed"; + +export async function identifyPlant( + image: HTMLImageElement | File, + mode: InferenceMode = "quick", +): Promise { + // 1. Run browser model (always, it's fast) + const browserResult = await runBrowserInference(image); + + // 2. Decide: is this confident enough? + if (mode === "quick" && browserResult.topConfidence >= 0.9) { + // Browser alone is sufficient + return { + ...browserResult, + source: "browser", + inferenceTimeMs: browserResult.inferenceTimeMs, + }; + } + + // 3. Fall back to server for detailed analysis + const serverResult = await runServerInference(image); + + return { + ...serverResult, + source: "server", + browserConfidence: browserResult.topConfidence, + serverConfidence: serverResult.topConfidence, + }; +} + +async function runBrowserInference(image: HTMLImageElement): Promise { + const model = await getBrowserModel(); // Lazy load EfficientNet-Lite + const tensor = await preprocessBrowser(image); // TF.js preprocessing + const output = await model.predict(tensor); + return parseOutput(output); +} +``` + +### 5.3 UI integration + +```typescript +// src/components/ImageUpload.tsx — Updated + +function ImageUpload() { + const [result, setResult] = useState(null); + const [mode, setMode] = useState('quick'); + const [source, setSource] = useState(null); + + async function handleUpload(image: File) { + // Run browser model (instant) + const browserResult = await identifyPlant(image, 'quick'); + setResult(browserResult); + setSource(browserResult.source); + + // If server was called in background, show loading indicator + if (browserResult.source === 'server') { + // Show "Getting detailed analysis..." spinner + } + } + + return ( +
+ + {result && ( +
+ + +
+ )} +
+ ); +} +``` + +### 5.4 User-facing indication + +Show a subtle badge indicating which model made the prediction: + +| Source | Badge | UX | +| ------------------- | -------------------- | ------------------------------------- | +| Browser (high conf) | ✅ Instant ID | Green badge, "Analyzed on device" | +| Server (full model) | 🧠 Detailed Analysis | Blue badge, "Deep analysis" | +| Server (fallback) | 🔄 Upgraded | Yellow badge, "Upgraded for accuracy" | + +### 5.5 Progressive enhancement + +The system should degrade gracefully: + +| Scenario | Behavior | +| ---------------------------------- | --------------------------------------------------------------------- | +| Offline | Browser model only (may be less accurate for unusual diseases) | +| Slow network | Browser model shows results immediately, server updates in background | +| Server down | Browser model alone, with note: "Limited to quick analysis" | +| New disease (not in browser model) | Server model handles it, browser shows "could be unusual" | +| No camera / file | Error message, "Upload an image to identify" | + +## Edge Cases & Gotchas + +- **Model loading race**: If the browser model hasn't loaded yet, show a loading spinner rather than falling through to server. Lazy-load the model on page mount. +- **Discrepancy between browser and server**: If browser and server disagree on the top prediction, show both with confidence bars. The server model is authoritative. +- **Retina / high-DPI images**: TF.js may handle these differently from ONNX. Ensure preprocessing (resize, normalize) produces identical tensors. +- **Cache busting**: When the model is updated, increment a version hash in the URL to avoid stale cached models. +- **Memory**: EfficientNet-Lite takes ~5MB in memory. Older phones may struggle; add a cleanup step after inference (`model.dispose()`). + +## Performance Targets + +| Metric | Target | +| ------------------------------- | -------------------------------- | +| Browser model load time (warm) | < 1s | +| Browser model inference | < 100ms | +| Server model inference (warm) | < 200ms | +| Hybrid fast path (browser only) | < 200ms total | +| Hybrid server path | < 1.5s total (including network) | +| Model file size (browser) | < 5MB | + +## Verification + +- [ ] Browser model loads in Chrome, Firefox, Safari (desktop + mobile) +- [ ] Browser model inference completes in < 100ms on mid-range phone +- [ ] Hybrid routing works: conf ≥90% → browser result, conf <90% → server result +- [ ] Server fallback fires within 200ms of browser model completing +- [ ] UI shows source badge ("Instant ID" vs "Deep Analysis") +- [ ] Offline mode: browser model works without network +- [ ] Server degraded: system still works with browser model only +- [ ] No memory leaks on repeated inferences (10+ images in succession) +- [ ] Identical image produces same top prediction on browser and server (within margin) +- [ ] All existing tests pass with hybrid pipeline diff --git a/tasks/hierarchical-model-upgrade/README.md b/tasks/hierarchical-model-upgrade/README.md new file mode 100644 index 0000000..f08581d --- /dev/null +++ b/tasks/hierarchical-model-upgrade/README.md @@ -0,0 +1,63 @@ +# Hierarchical Model Architecture Upgrade + +**Scale**: 1.47M images across 11,499 disease-plant classes +**Goal**: Replace flat MobileNetV2 (38-class PlantVillage) with hierarchical Swin-Tiny (species → disease) +**Deployment**: Hybrid — lightweight browser model (TF.js) + full server model (ONNX Runtime) + +## Hardware + +| Machine | Role | Specs | +| -------------- | ------------------------------- | ---------------------------------------- | +| **Strix Halo** | Primary training + inference | AI 395+ MAX (ROCm), 128GB unified memory | +| **RTX 3090** | Secondary training / CUDA path | 24GB VRAM | +| **M3 Pro** | Development only (work machine) | — | + +**Key advantage**: Strix Halo's 128GB unified memory allows loading the entire 1.5M image dataset into RAM and training with extremely large effective batch sizes — the GPU accesses the full 128GB pool, no VRAM ceiling. + +## Status Legend + +``` +[ ] not started [~] in progress [x] done [-] skipped +``` + +## Task Map + +``` +Phase 1 ──→ Phase 2 ──→ Phase 3 ──→ Phase 4 ──→ Phase 5 +Dataset Model Model Server Integration +Reorg Training Export Inference + Testing + & Quant. Pipeline +``` + +## Phases + +- [ ] [Phase 1 — Dataset Reorganization](01-dataset-reorganization.md) + Parse 11,499 flat directories into hierarchical species→disease structure, create train/val splits, build species index. +- [ ] [Phase 2 — Hierarchical Model Training](02-hierarchical-training.md) + Train Swin-Tiny backbone + species head + disease heads using PyTorch + ROCm on Strix Halo. +- [ ] [Phase 3 — ONNX Export & Quantization](03-export-quantization.md) + Export trained models to ONNX, apply INT8 quantization, verify accuracy. +- [ ] [Phase 4 — Server Inference Pipeline](04-server-inference.md) + Build server-side inference API with ONNX Runtime, OOD detection, species routing. +- [ ] [Phase 5 — Browser Model & Hybrid Integration](05-browser-hybrid.md) + Lightweight TF.js model for client, hybrid confidence-based routing, full integration. + +## Dependencies + +``` +01 (dataset) ──→ 02 (training) ──→ 03 (export) ──→ 04 (server) + │ + └──→ 05 (browser + hybrid) +``` + +## Exit Criteria + +- [ ] Species classifier achieves ≥95% top-1 accuracy on held-out val set +- [ ] Disease classifiers achieve ≥90% top-3 accuracy per species +- [ ] ONNX INT8 models infer in <200ms on CPU, <50ms on GPU +- [ ] Browser TF.js model loads and runs in <100ms on mid-range devices +- [ ] Hybrid routing works: high-confidence results served instantly from browser +- [ ] Server fallback fires automatically when browser confidence is low +- [ ] OOD detection rejects non-plant images with ≥99% precision +- [ ] Full integration: upload → result in <500ms (browser) or <1s (server) +- [ ] Existing app functionality preserved (all routes, pages, API endpoints)