task to get this here done
This commit is contained in:
@@ -22,6 +22,7 @@
|
|||||||
import "dotenv/config";
|
import "dotenv/config";
|
||||||
import { readFileSync, readdirSync, writeFileSync, existsSync, mkdirSync } from "fs";
|
import { readFileSync, readdirSync, writeFileSync, existsSync, mkdirSync } from "fs";
|
||||||
import { resolve, extname } from "path";
|
import { resolve, extname } from "path";
|
||||||
|
import { Agent, setGlobalDispatcher } from "undici";
|
||||||
|
|
||||||
// Load .env.development for DB creds
|
// Load .env.development for DB creds
|
||||||
const envPath = resolve(__dirname, "../.env.development");
|
const envPath = resolve(__dirname, "../.env.development");
|
||||||
@@ -41,7 +42,7 @@ try {
|
|||||||
} catch {}
|
} catch {}
|
||||||
|
|
||||||
import { getDb, closeDb } from "@/lib/db/index";
|
import { getDb, closeDb } from "@/lib/db/index";
|
||||||
import { diseases } from "@/lib/db/schema";
|
import { plants, diseases } from "@/lib/db/schema";
|
||||||
|
|
||||||
// ─── Config ─────────────────────────────────────────────────────────────────
|
// ─── Config ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -49,17 +50,18 @@ const DATASET_DIR = resolve(__dirname, "../data/dataset");
|
|||||||
const SEEN_CACHE_FILE = resolve(DATASET_DIR, ".fill-seen-urls.json");
|
const SEEN_CACHE_FILE = resolve(DATASET_DIR, ".fill-seen-urls.json");
|
||||||
|
|
||||||
/** Target images per disease */
|
/** Target images per disease */
|
||||||
const TARGET_PER_DISEASE = 200;
|
const TARGET_PER_DISEASE = 100;
|
||||||
|
|
||||||
/** Target images for the "healthy" class */
|
/** Target images per plant for the "healthy" class */
|
||||||
const TARGET_HEALTHY = 400;
|
const TARGET_HEALTHY_PER_PLANT = 400;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* How many diseases to process in parallel.
|
* How many diseases to process in parallel.
|
||||||
* Each disease is I/O-bound (HTTP requests), so high concurrency is safe.
|
* Reduced from 50 to 5 to prevent overwhelming undici's connection pool.
|
||||||
* The global DDG rate limiter prevents us from overwhelming DuckDuckGo.
|
* Each disease fires multiple concurrent requests, so high concurrency causes
|
||||||
|
* connection exhaustion and undici crashes.
|
||||||
*/
|
*/
|
||||||
const DISEASE_CONCURRENCY = 50;
|
const DISEASE_CONCURRENCY = 5;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Max DDG requests per second (shared across all concurrent diseases).
|
* Max DDG requests per second (shared across all concurrent diseases).
|
||||||
@@ -70,8 +72,11 @@ const DISEASE_CONCURRENCY = 50;
|
|||||||
*/
|
*/
|
||||||
const DDG_RATE_LIMIT_RPS = 6;
|
const DDG_RATE_LIMIT_RPS = 6;
|
||||||
|
|
||||||
/** Max concurrent image downloads per disease */
|
/** Max concurrent image downloads per disease.
|
||||||
const CONCURRENT_DOWNLOADS = 50;
|
* Reduced from 50 to 10 to prevent connection pool exhaustion.
|
||||||
|
* With 5 diseases × 10 downloads = 50 concurrent downloads (manageable).
|
||||||
|
*/
|
||||||
|
const CONCURRENT_DOWNLOADS = 10;
|
||||||
|
|
||||||
/** Minimum image size in bytes to accept */
|
/** Minimum image size in bytes to accept */
|
||||||
const MIN_IMAGE_SIZE = 10_000; // 10KB
|
const MIN_IMAGE_SIZE = 10_000; // 10KB
|
||||||
@@ -86,9 +91,16 @@ const ALLOWED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"];
|
|||||||
const UA =
|
const UA =
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_4 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.4 Mobile/15E148 Safari/604.1";
|
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_4 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.4 Mobile/15E148 Safari/604.1";
|
||||||
|
|
||||||
/** Healthy class directory name */
|
/** Healthy class parent directory name */
|
||||||
const HEALTHY_CLASS = "healthy";
|
const HEALTHY_CLASS = "healthy";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* How many plants to process in parallel for healthy image collection.
|
||||||
|
* Each plant runs the same fillClass pipeline (3 sources), so keep this
|
||||||
|
* modest to avoid connection pool exhaustion.
|
||||||
|
*/
|
||||||
|
const MAX_HEALTHY_CONCURRENCY = 5;
|
||||||
|
|
||||||
/** How often (in diseases processed) to flush the seen-URLs cache to disk */
|
/** How often (in diseases processed) to flush the seen-URLs cache to disk */
|
||||||
const SEEN_CACHE_FLUSH_INTERVAL = 20;
|
const SEEN_CACHE_FLUSH_INTERVAL = 20;
|
||||||
|
|
||||||
@@ -98,9 +110,6 @@ const SEEN_CACHE_FLUSH_INTERVAL = 20;
|
|||||||
* the seen-URLs cache accumulates across runs. */
|
* the seen-URLs cache accumulates across runs. */
|
||||||
const MAX_DDG_PAGES = 5;
|
const MAX_DDG_PAGES = 5;
|
||||||
|
|
||||||
/** Healthy source queries limit */
|
|
||||||
const MAX_HEALTHY_QUERIES = 20;
|
|
||||||
|
|
||||||
// ─── Types ──────────────────────────────────────────────────────────────────
|
// ─── Types ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
interface DuckDuckGoImageResult {
|
interface DuckDuckGoImageResult {
|
||||||
@@ -169,6 +178,45 @@ const ddgLimiter = new TokenBucket(DDG_RATE_LIMIT_RPS);
|
|||||||
|
|
||||||
// ─── Helpers ────────────────────────────────────────────────────────────────
|
// ─── Helpers ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom undici agent with connection pooling to prevent overwhelming the network stack.
|
||||||
|
* Limits concurrent connections per host to prevent undici crashes (the "onResponseError" bug).
|
||||||
|
* setGlobalDispatcher() makes ALL fetch() calls use this agent automatically.
|
||||||
|
*/
|
||||||
|
const fetchAgent = new Agent({
|
||||||
|
connections: 50, // Max 50 concurrent connections total
|
||||||
|
connect: {
|
||||||
|
timeout: 10_000, // 10s connection timeout
|
||||||
|
},
|
||||||
|
bodyTimeout: 30_000, // 30s body timeout
|
||||||
|
headersTimeout: 15_000, // 15s headers timeout
|
||||||
|
keepAliveTimeout: 30_000, // 30s keep-alive
|
||||||
|
keepAliveMaxTimeout: 60_000, // 60s max keep-alive
|
||||||
|
});
|
||||||
|
setGlobalDispatcher(fetchAgent);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper around global fetch() with retry + exponential backoff for transient errors.
|
||||||
|
* The connection pooling is handled by the global undici dispatcher (set above).
|
||||||
|
*/
|
||||||
|
async function safeFetch(url: string | URL, init?: RequestInit, retries = 2): Promise<Response> {
|
||||||
|
let lastError: Error | undefined;
|
||||||
|
for (let attempt = 0; attempt <= retries; attempt++) {
|
||||||
|
try {
|
||||||
|
return await fetch(url, init);
|
||||||
|
} catch (err: unknown) {
|
||||||
|
lastError = err instanceof Error ? err : new Error(String(err));
|
||||||
|
// Don't retry on abort/timeout (user-initiated)
|
||||||
|
if (err instanceof Error && err.name === "AbortError") throw err;
|
||||||
|
if (attempt < retries) {
|
||||||
|
// Exponential backoff: 500ms, 2000ms
|
||||||
|
await sleep(500 * (attempt + 1) * (attempt + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw lastError ?? new Error(`fetch failed after ${retries + 1} attempts`);
|
||||||
|
}
|
||||||
|
|
||||||
function sleep(ms: number): Promise<void> {
|
function sleep(ms: number): Promise<void> {
|
||||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
}
|
}
|
||||||
@@ -240,7 +288,7 @@ async function getVqdToken(query: string): Promise<string> {
|
|||||||
|
|
||||||
const url = `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`;
|
const url = `https://duckduckgo.com/?q=${encodeURIComponent(query)}&t=h_&iax=images&ia=images`;
|
||||||
|
|
||||||
const res = await fetch(url, {
|
const res = await safeFetch(url, {
|
||||||
headers: { "User-Agent": UA, Accept: "text/html" },
|
headers: { "User-Agent": UA, Accept: "text/html" },
|
||||||
signal: AbortSignal.timeout(15_000),
|
signal: AbortSignal.timeout(15_000),
|
||||||
});
|
});
|
||||||
@@ -267,7 +315,7 @@ async function searchImagesDuckDuckGo(
|
|||||||
query,
|
query,
|
||||||
)}&vqd=${vqd}&o=json&p=${page}&f=,,,`;
|
)}&vqd=${vqd}&o=json&p=${page}&f=,,,`;
|
||||||
|
|
||||||
const res = await fetch(url, {
|
const res = await safeFetch(url, {
|
||||||
headers: {
|
headers: {
|
||||||
"User-Agent": UA,
|
"User-Agent": UA,
|
||||||
Accept: "application/json",
|
Accept: "application/json",
|
||||||
@@ -289,7 +337,7 @@ async function searchImagesDuckDuckGo(
|
|||||||
const freshVqd = await getVqdToken(query);
|
const freshVqd = await getVqdToken(query);
|
||||||
await ddgLimiter.acquire();
|
await ddgLimiter.acquire();
|
||||||
const retryUrl = url.replace(/vqd=[^&]+/, `vqd=${freshVqd}`);
|
const retryUrl = url.replace(/vqd=[^&]+/, `vqd=${freshVqd}`);
|
||||||
const retryRes = await fetch(retryUrl, {
|
const retryRes = await safeFetch(retryUrl, {
|
||||||
headers: {
|
headers: {
|
||||||
"User-Agent": UA,
|
"User-Agent": UA,
|
||||||
Accept: "application/json",
|
Accept: "application/json",
|
||||||
@@ -410,7 +458,7 @@ async function searchImagesInaturalist(
|
|||||||
`&order_by=observed_on&order=desc`;
|
`&order_by=observed_on&order=desc`;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await fetch(apiUrl, {
|
const res = await safeFetch(apiUrl, {
|
||||||
headers: { "User-Agent": UA, Accept: "application/json" },
|
headers: { "User-Agent": UA, Accept: "application/json" },
|
||||||
signal: AbortSignal.timeout(15_000),
|
signal: AbortSignal.timeout(15_000),
|
||||||
});
|
});
|
||||||
@@ -462,7 +510,7 @@ async function searchImagesCommons(
|
|||||||
const url = `https://commons.wikimedia.org/w/api.php?${params}`;
|
const url = `https://commons.wikimedia.org/w/api.php?${params}`;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await fetch(url, {
|
const res = await safeFetch(url, {
|
||||||
headers: { "User-Agent": UA },
|
headers: { "User-Agent": UA },
|
||||||
signal: AbortSignal.timeout(10_000),
|
signal: AbortSignal.timeout(10_000),
|
||||||
});
|
});
|
||||||
@@ -500,7 +548,7 @@ async function searchImagesCommons(
|
|||||||
|
|
||||||
async function downloadImage(url: string, destPath: string): Promise<boolean> {
|
async function downloadImage(url: string, destPath: string): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
const res = await fetch(url, {
|
const res = await safeFetch(url, {
|
||||||
headers: { "User-Agent": UA, Accept: "image/webp,image/png,image/jpeg,*/*" },
|
headers: { "User-Agent": UA, Accept: "image/webp,image/png,image/jpeg,*/*" },
|
||||||
signal: AbortSignal.timeout(8_000),
|
signal: AbortSignal.timeout(8_000),
|
||||||
});
|
});
|
||||||
@@ -667,16 +715,16 @@ async function fillClass(
|
|||||||
interface ScanResult {
|
interface ScanResult {
|
||||||
/** Disease id → how many images currently on disk */
|
/** Disease id → how many images currently on disk */
|
||||||
diseaseCounts: Map<string, number>;
|
diseaseCounts: Map<string, number>;
|
||||||
/** How many healthy images on disk */
|
/** Plant id → how many healthy images currently on disk */
|
||||||
healthyCount: number;
|
healthyCounts: Map<string, number>;
|
||||||
}
|
}
|
||||||
|
|
||||||
function scanDataset(): ScanResult {
|
function scanDataset(): ScanResult {
|
||||||
const diseaseCounts = new Map<string, number>();
|
const diseaseCounts = new Map<string, number>();
|
||||||
let healthyCount = 0;
|
const healthyCounts = new Map<string, number>();
|
||||||
|
|
||||||
if (!existsSync(DATASET_DIR)) {
|
if (!existsSync(DATASET_DIR)) {
|
||||||
return { diseaseCounts, healthyCount: 0 };
|
return { diseaseCounts, healthyCounts };
|
||||||
}
|
}
|
||||||
|
|
||||||
const entries = readdirSync(DATASET_DIR, { withFileTypes: true });
|
const entries = readdirSync(DATASET_DIR, { withFileTypes: true });
|
||||||
@@ -686,7 +734,16 @@ function scanDataset(): ScanResult {
|
|||||||
if (entry.name.startsWith(".")) continue;
|
if (entry.name.startsWith(".")) continue;
|
||||||
|
|
||||||
if (entry.name === HEALTHY_CLASS) {
|
if (entry.name === HEALTHY_CLASS) {
|
||||||
healthyCount = countImagesInDir(resolve(DATASET_DIR, entry.name));
|
// Scan per-plant subdirectories under healthy/
|
||||||
|
const healthyDir = resolve(DATASET_DIR, entry.name);
|
||||||
|
const plantDirs = readdirSync(healthyDir, { withFileTypes: true });
|
||||||
|
for (const pd of plantDirs) {
|
||||||
|
if (!pd.isDirectory() || pd.name.startsWith(".")) continue;
|
||||||
|
const count = countImagesInDir(resolve(healthyDir, pd.name));
|
||||||
|
if (count > 0) {
|
||||||
|
healthyCounts.set(pd.name, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
const count = countImagesInDir(resolve(DATASET_DIR, entry.name));
|
const count = countImagesInDir(resolve(DATASET_DIR, entry.name));
|
||||||
if (count > 0) {
|
if (count > 0) {
|
||||||
@@ -695,7 +752,7 @@ function scanDataset(): ScanResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { diseaseCounts, healthyCount };
|
return { diseaseCounts, healthyCounts };
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── CLI Flags ──────────────────────────────────────────────────────────────
|
// ─── CLI Flags ──────────────────────────────────────────────────────────────
|
||||||
@@ -722,11 +779,14 @@ async function main() {
|
|||||||
|
|
||||||
// ── Step 1: Scan what we already have ────────────────────────────────────
|
// ── Step 1: Scan what we already have ────────────────────────────────────
|
||||||
console.log("\nScanning existing dataset...");
|
console.log("\nScanning existing dataset...");
|
||||||
const { diseaseCounts, healthyCount } = scanDataset();
|
const { diseaseCounts, healthyCounts } = scanDataset();
|
||||||
console.log(` Found ${diseaseCounts.size} disease directories, ${healthyCount} healthy images`);
|
const totalHealthyImages = [...healthyCounts.values()].reduce((s, c) => s + c, 0);
|
||||||
|
console.log(
|
||||||
|
` Found ${diseaseCounts.size} disease directories, ${healthyCounts.size} plants with healthy images (${totalHealthyImages} total)`,
|
||||||
|
);
|
||||||
|
|
||||||
// ── Step 2: Load disease info from DB ────────────────────────────────────
|
// ── Step 2: Load disease info and plant info from DB ─────────────────────
|
||||||
console.log("\nLoading disease info from database...");
|
console.log("\nLoading data from database...");
|
||||||
const db = getDb();
|
const db = getDb();
|
||||||
|
|
||||||
const allDiseases = await db
|
const allDiseases = await db
|
||||||
@@ -746,6 +806,15 @@ async function main() {
|
|||||||
}
|
}
|
||||||
console.log(` Loaded ${diseaseInfo.size} unique diseases from DB`);
|
console.log(` Loaded ${diseaseInfo.size} unique diseases from DB`);
|
||||||
|
|
||||||
|
// Load all plants from DB for healthy class generation
|
||||||
|
const allPlants = await db
|
||||||
|
.select({
|
||||||
|
id: plants.id,
|
||||||
|
commonName: plants.commonName,
|
||||||
|
})
|
||||||
|
.from(plants);
|
||||||
|
console.log(` Loaded ${allPlants.length} plants from DB`);
|
||||||
|
|
||||||
// ── Step 3: Build deficit list ──────────────────────────────────────────
|
// ── Step 3: Build deficit list ──────────────────────────────────────────
|
||||||
const deficits: DiseaseInfo[] = [];
|
const deficits: DiseaseInfo[] = [];
|
||||||
|
|
||||||
@@ -764,14 +833,24 @@ async function main() {
|
|||||||
// direction when the front of the queue keeps hitting dead URLs)
|
// direction when the front of the queue keeps hitting dead URLs)
|
||||||
if (flags.reverse) deficits.reverse();
|
if (flags.reverse) deficits.reverse();
|
||||||
|
|
||||||
const healthyDeficit = TARGET_HEALTHY - healthyCount;
|
// Build per-plant healthy deficits
|
||||||
|
const healthyDeficits: Array<{ id: string; have: number; needed: number }> = [];
|
||||||
|
for (const plant of allPlants) {
|
||||||
|
const have = healthyCounts.get(plant.id) ?? 0;
|
||||||
|
const needed = TARGET_HEALTHY_PER_PLANT - have;
|
||||||
|
if (needed > 0) {
|
||||||
|
healthyDeficits.push({ id: plant.id, have, needed });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const totalHealthyDeficit = healthyDeficits.reduce((s, d) => s + d.needed, 0);
|
||||||
|
|
||||||
console.log(`\n${"=".repeat(60)}`);
|
console.log(`\n${"=".repeat(60)}`);
|
||||||
console.log("DEFICIT REPORT");
|
console.log("DEFICIT REPORT");
|
||||||
console.log(`${"=".repeat(60)}`);
|
console.log(`${"=".repeat(60)}`);
|
||||||
console.log(` Diseases needing images: ${deficits.length}/${diseaseInfo.size}`);
|
console.log(` Diseases needing images: ${deficits.length}/${diseaseInfo.size}`);
|
||||||
console.log(` Total images missing: ${deficits.reduce((s, d) => s + d.needed, 0)}`);
|
console.log(` Total images missing: ${deficits.reduce((s, d) => s + d.needed, 0)}`);
|
||||||
console.log(` Healthy deficit: ${Math.max(0, healthyDeficit)}`);
|
console.log(` Plants needing healthy: ${healthyDeficits.length}/${allPlants.length}`);
|
||||||
|
console.log(` Total healthy images missing: ${totalHealthyDeficit}`);
|
||||||
console.log(` Parallelism: ${DISEASE_CONCURRENCY} diseases at once`);
|
console.log(` Parallelism: ${DISEASE_CONCURRENCY} diseases at once`);
|
||||||
console.log(` DDG rate limit: ${DDG_RATE_LIMIT_RPS} req/s (shared)`);
|
console.log(` DDG rate limit: ${DDG_RATE_LIMIT_RPS} req/s (shared)`);
|
||||||
console.log(
|
console.log(
|
||||||
@@ -779,7 +858,7 @@ async function main() {
|
|||||||
);
|
);
|
||||||
console.log(`${"=".repeat(60)}`);
|
console.log(`${"=".repeat(60)}`);
|
||||||
|
|
||||||
if (deficits.length === 0 && healthyDeficit <= 0) {
|
if (deficits.length === 0 && healthyDeficits.length === 0) {
|
||||||
console.log("\n ✓ Nothing to do — all targets met!\n");
|
console.log("\n ✓ Nothing to do — all targets met!\n");
|
||||||
await closeDb();
|
await closeDb();
|
||||||
return;
|
return;
|
||||||
@@ -788,7 +867,6 @@ async function main() {
|
|||||||
// ── Step 4: Load seen-URLs cache ────────────────────────────────────────
|
// ── Step 4: Load seen-URLs cache ────────────────────────────────────────
|
||||||
const seenUrlsCache = loadSeenUrlsCache();
|
const seenUrlsCache = loadSeenUrlsCache();
|
||||||
let totalDownloaded = 0;
|
let totalDownloaded = 0;
|
||||||
let totalFailed = 0;
|
|
||||||
let diseasesProcessed = 0;
|
let diseasesProcessed = 0;
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
@@ -875,64 +953,78 @@ async function main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Step 6: Fill healthy deficit ────────────────────────────────────────
|
// ── Step 6: Fill healthy deficits per plant ────────────────────────────
|
||||||
if (healthyDeficit > 0) {
|
if (healthyDeficits.length > 0) {
|
||||||
console.log("\n" + "─".repeat(60));
|
console.log("\n" + "─".repeat(60));
|
||||||
console.log(`FILLING HEALTHY CLASS (target: ${TARGET_HEALTHY})`);
|
console.log(
|
||||||
|
`FILLING ${healthyDeficits.length} PLANTS\' HEALTHY CLASSES` +
|
||||||
|
` (target: ${TARGET_HEALTHY_PER_PLANT} each)`,
|
||||||
|
);
|
||||||
console.log("─".repeat(60));
|
console.log("─".repeat(60));
|
||||||
|
|
||||||
const healthyDir = resolve(DATASET_DIR, HEALTHY_CLASS);
|
let healthyProcessed = 0;
|
||||||
mkdirSync(healthyDir, { recursive: true });
|
|
||||||
|
|
||||||
// Collect all unique plants from the disease info
|
// Process in parallel batches
|
||||||
const allPlants = [...new Set(diseaseInfo.values())].map((d) => d.plantId);
|
for (let i = 0; i < healthyDeficits.length; i += MAX_HEALTHY_CONCURRENCY) {
|
||||||
const allHealthyQueries: string[] = [];
|
const batch = healthyDeficits.slice(i, i + MAX_HEALTHY_CONCURRENCY);
|
||||||
for (const plant of allPlants) {
|
const batchNum = Math.floor(i / MAX_HEALTHY_CONCURRENCY) + 1;
|
||||||
allHealthyQueries.push(...buildHealthyQueries(plant));
|
const totalBatches = Math.ceil(healthyDeficits.length / MAX_HEALTHY_CONCURRENCY);
|
||||||
}
|
|
||||||
|
|
||||||
const healthySeen = new Set<string>(seenUrlsCache[HEALTHY_CLASS] ?? []);
|
console.log(`\n[Batch ${batchNum}/${totalBatches}] Processing ${batch.length} plants...`);
|
||||||
const healthyNeeded = TARGET_HEALTHY - countImagesInDir(healthyDir);
|
|
||||||
|
|
||||||
// Run all 3 sources in parallel for the healthy class too
|
const STAGGER_MS = 200;
|
||||||
const [ddgUrls, inatUrls, commonsUrls] = await Promise.allSettled([
|
const batchResults = await Promise.allSettled(
|
||||||
collectImagesDuckDuckGo(
|
batch.map((p, idx) =>
|
||||||
allHealthyQueries.slice(0, MAX_HEALTHY_QUERIES),
|
(async () => {
|
||||||
healthyNeeded,
|
if (idx > 0) await sleep(idx * STAGGER_MS);
|
||||||
healthySeen,
|
|
||||||
),
|
|
||||||
searchImagesInaturalist(allHealthyQueries[0], healthyNeeded, healthySeen),
|
|
||||||
searchImagesCommons(allHealthyQueries[0], healthyNeeded, healthySeen),
|
|
||||||
]);
|
|
||||||
|
|
||||||
const allUrls: string[] = [];
|
const classDir = resolve(DATASET_DIR, HEALTHY_CLASS, p.id);
|
||||||
for (const settled of [ddgUrls, inatUrls, commonsUrls]) {
|
const queries = buildHealthyQueries(p.id);
|
||||||
if (settled.status === "fulfilled") {
|
const seen = new Set<string>();
|
||||||
allUrls.push(...settled.value.urls);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (allUrls.length > 0) {
|
console.log(` [healthy/${p.id}] have ${p.have}, need ${p.needed} more`);
|
||||||
console.log(`\n Downloading ${allUrls.length} healthy images...`);
|
|
||||||
const startIdx = countImagesInDir(healthyDir);
|
|
||||||
const { downloaded, failed } = await downloadBatch(allUrls, healthyDir, startIdx);
|
|
||||||
|
|
||||||
const newTotal = countImagesInDir(healthyDir);
|
const gained = await fillClass(`healthy/${p.id}`, queries, p.needed, classDir, seen);
|
||||||
const gained = newTotal - healthyCount;
|
|
||||||
totalDownloaded += gained;
|
|
||||||
totalFailed += failed;
|
|
||||||
|
|
||||||
console.log(
|
// Update seen-URLs cache for this plant's healthy images
|
||||||
` ${downloaded > 0 ? "✓" : "✗"} Got ${downloaded} images.` +
|
const cacheKey = `${HEALTHY_CLASS}/${p.id}`;
|
||||||
` Total healthy: ${newTotal}/${TARGET_HEALTHY} (${gained} new)`,
|
const existing = seenUrlsCache[cacheKey] ?? [];
|
||||||
|
const merged = [...new Set([...existing, ...Array.from(seen)])];
|
||||||
|
seenUrlsCache[cacheKey] = merged.slice(-500);
|
||||||
|
return gained;
|
||||||
|
})(),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
} else {
|
|
||||||
console.log(`\n ✗ No healthy images found`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update seen-URLs cache
|
for (const result of batchResults) {
|
||||||
seenUrlsCache[HEALTHY_CLASS] = Array.from(healthySeen);
|
if (result.status === "fulfilled") {
|
||||||
saveSeenUrlsCache(seenUrlsCache);
|
totalDownloaded += result.value;
|
||||||
|
} else {
|
||||||
|
console.error(` ✗ Plant healthy fill failed: ${result.reason}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
healthyProcessed += batch.length;
|
||||||
|
|
||||||
|
// Flush seen-URLs cache to disk periodically
|
||||||
|
if (
|
||||||
|
healthyProcessed % SEEN_CACHE_FLUSH_INTERVAL < batch.length ||
|
||||||
|
i + batch.length >= healthyDeficits.length
|
||||||
|
) {
|
||||||
|
saveSeenUrlsCache(seenUrlsCache);
|
||||||
|
}
|
||||||
|
|
||||||
|
const elapsed = Math.round((Date.now() - startTime) / 1000);
|
||||||
|
const rate = healthyProcessed / Math.max(1, elapsed);
|
||||||
|
const remaining = healthyDeficits.length - healthyProcessed;
|
||||||
|
const eta = remaining / Math.max(0.01, rate);
|
||||||
|
console.log(
|
||||||
|
` [Batch ${batchNum}/${totalBatches}] checkpoint — ` +
|
||||||
|
`${totalDownloaded} downloaded (healthy), ` +
|
||||||
|
`${healthyProcessed}/${healthyDeficits.length} plants (${rate.toFixed(1)}/s, ` +
|
||||||
|
`ETA: ${Math.round(eta)}s)`,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Summary ──────────────────────────────────────────────────────────────
|
// ── Summary ──────────────────────────────────────────────────────────────
|
||||||
@@ -946,16 +1038,21 @@ async function main() {
|
|||||||
const atTarget = [...finalScan.diseaseCounts.values()].filter(
|
const atTarget = [...finalScan.diseaseCounts.values()].filter(
|
||||||
(c) => c >= TARGET_PER_DISEASE,
|
(c) => c >= TARGET_PER_DISEASE,
|
||||||
).length;
|
).length;
|
||||||
|
const healthyAtTarget = [...finalScan.healthyCounts.values()].filter(
|
||||||
|
(c) => c >= TARGET_HEALTHY_PER_PLANT,
|
||||||
|
).length;
|
||||||
|
const totalHealthyFinal = [...finalScan.healthyCounts.values()].reduce((s, c) => s + c, 0);
|
||||||
|
|
||||||
console.log("\n" + "=".repeat(60));
|
console.log("\n" + "=".repeat(60));
|
||||||
console.log(" ✅ FILL COMPLETE");
|
console.log(" ✅ FILL COMPLETE");
|
||||||
console.log("=".repeat(60));
|
console.log("=".repeat(60));
|
||||||
console.log(` Time: ${hrs}h ${mins % 60}m`);
|
console.log(` Time: ${hrs}h ${mins % 60}m`);
|
||||||
console.log(` Diseases at target: ${atTarget}/${diseaseInfo.size}`);
|
console.log(` Diseases at target: ${atTarget}/${diseaseInfo.size}`);
|
||||||
console.log(` Total images: ${totalHave}`);
|
console.log(` Total disease images: ${totalHave}`);
|
||||||
console.log(` Healthy images: ${finalScan.healthyCount}/${TARGET_HEALTHY}`);
|
console.log(` Plants at healthy target: ${healthyAtTarget}/${allPlants.length}`);
|
||||||
console.log(` New downloads: ${totalDownloaded}`);
|
console.log(` Total healthy images: ${totalHealthyFinal}`);
|
||||||
console.log(` Dataset dir: ${DATASET_DIR}/`);
|
console.log(` New downloads: ${totalDownloaded}`);
|
||||||
|
console.log(` Dataset dir: ${DATASET_DIR}/`);
|
||||||
|
|
||||||
await closeDb();
|
await closeDb();
|
||||||
console.log("=".repeat(60));
|
console.log("=".repeat(60));
|
||||||
|
|||||||
169
tasks/hierarchical-model-upgrade/01-dataset-reorganization.md
Normal file
169
tasks/hierarchical-model-upgrade/01-dataset-reorganization.md
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# Phase 1 — Dataset Reorganization
|
||||||
|
|
||||||
|
**Blocked by**: Nothing
|
||||||
|
**Blocks**: Phase 2 (training)
|
||||||
|
**Est. time**: 2-3 days
|
||||||
|
**Machine**: Strix Halo (fast I/O, 128GB RAM for in-memory processing)
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Reorganize the flat directory structure (`data/dataset/plant-disease-name/`) into a proper hierarchical layout (`data/organized/species/disease/`) with train/val splits and metadata files.
|
||||||
|
|
||||||
|
## Current State
|
||||||
|
|
||||||
|
- `data/dataset/` — 11,499 flat directories, each named `{plant}-{disease}`
|
||||||
|
- Files mixed: `.jpg`, `.jpeg`, `.png`, `.webp` per directory
|
||||||
|
- Total: ~1.47M images, 64-244 images per class (well-balanced)
|
||||||
|
- **Total size: ~450 GB**
|
||||||
|
- **SSD available**: 8TB NVMe (7,300 MB/s read, 6,300 MB/s write — PCIe 5.0)
|
||||||
|
|
||||||
|
## Deliverables
|
||||||
|
|
||||||
|
```
|
||||||
|
data/
|
||||||
|
├── organized/
|
||||||
|
│ ├── train/ # 85% of images
|
||||||
|
│ │ ├── {species_1}/
|
||||||
|
│ │ │ ├── healthy/
|
||||||
|
│ │ │ ├── {disease_a}/
|
||||||
|
│ │ │ └── {disease_b}/
|
||||||
|
│ │ ├── {species_2}/
|
||||||
|
│ │ └── ...
|
||||||
|
│ ├── val/ # 15% of images
|
||||||
|
│ │ └── ... (mirrors train structure)
|
||||||
|
│ ├── species_index.json # Maps species → [disease IDs]
|
||||||
|
│ ├── class_hierarchy.json # Full mapping + metadata
|
||||||
|
│ └── dataset_stats.json # Counts per class, splits
|
||||||
|
```
|
||||||
|
|
||||||
|
## Steps
|
||||||
|
|
||||||
|
### 1.1 Parse directory names → (species, disease) pairs
|
||||||
|
|
||||||
|
**Problem**: Directory names like `acorn-squash-powdery-mildew` use an inconsistent separator (hyphen). Need to reliably split plant name from disease name.
|
||||||
|
|
||||||
|
**Approach**: Use `src/data/diseases.json` as ground truth. Try matching each directory name against known disease ID suffixes (sorted longest-first). The remainder is the plant name.
|
||||||
|
|
||||||
|
**Fallback**: For unmatched dirs, build a plant suffix list from `src/data/plants.json` and try prefix matching. Log any truly unmatched dirs for manual review.
|
||||||
|
|
||||||
|
**Script**: `scripts/organize-dataset.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Pseudocode for the matching algorithm:
|
||||||
|
disease_ids = sorted([d["id"] for d in diseases], key=len, reverse=True)
|
||||||
|
plant_names = [p["id"] for p in plants] # or extract from dir prefixes
|
||||||
|
|
||||||
|
for dir_name in dataset_dirs:
|
||||||
|
matched_disease = next(d for d in disease_ids if dir_name.endswith(d))
|
||||||
|
plant = dir_name[:-(len(matched_disease)+1)] # +1 for hyphen
|
||||||
|
hierarchy[plant].append(matched_disease)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.2 Split into train/val (85/15)
|
||||||
|
|
||||||
|
Use stratified splitting per class to preserve class distribution.
|
||||||
|
|
||||||
|
- For each disease-plant class, randomly assign 85% to train, 15% to val
|
||||||
|
- Copy files (or symlink) to new directory structure
|
||||||
|
- Verify no data leakage (same image in both splits)
|
||||||
|
|
||||||
|
### 1.3 Build metadata files
|
||||||
|
|
||||||
|
```json
|
||||||
|
// species_index.json
|
||||||
|
{
|
||||||
|
"tomato": ["healthy", "early-blight", "late-blight", "bacterial-spot", ...],
|
||||||
|
"acorn-squash": ["healthy", "powdery-mildew", "downy-mildew", ...],
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
// dataset_stats.json
|
||||||
|
{
|
||||||
|
"total_images": 1465818,
|
||||||
|
"total_species": 320,
|
||||||
|
"total_classes": 11499,
|
||||||
|
"images_per_class": { "min": 64, "max": 244, "mean": 127 },
|
||||||
|
"train_images": 1245945,
|
||||||
|
"val_images": 219873,
|
||||||
|
"species_disease_counts": {
|
||||||
|
"tomato": { "early-blight": 156, "late-blight": 142, ... }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.4 Data quality checks
|
||||||
|
|
||||||
|
### 1.4 Image normalization & compression (before splitting)
|
||||||
|
|
||||||
|
**450GB is unnecessarily large for 224px training.** Many source images are high-resolution (e.g., 4000×3000 from phone cameras), but the model only sees 224×224 crops. Resizing to a reasonable max dimension BEFORE training saves massive I/O and enables faster epochs.
|
||||||
|
|
||||||
|
**Strategy**: Resize all images to **max dimension of 512px** (preserving aspect ratio), convert to **JPEG quality 90**.
|
||||||
|
|
||||||
|
| Approach | Est. Size | Pros | Cons |
|
||||||
|
| ------------------------------- | ------------- | -------------------------------------------------- | ---------------------------------- |
|
||||||
|
| **Keep originals** | 450 GB | No quality loss | Slow loading, huge storage |
|
||||||
|
| **Resize 1024px max, JPEG 90** | ~120 GB | Good for future higher-res models | Still somewhat large |
|
||||||
|
| **Resize 512px max, JPEG 90 ✓** | **~60-80 GB** | **Fast loading, enough detail for 224px training** | Can't go back to full res |
|
||||||
|
| **Resize 256px max, JPEG 95** | ~30 GB | Fastest loading | Too small if retrain at higher res |
|
||||||
|
|
||||||
|
**Recommendation**: Resize to 512px max, JPEG q90. This:
|
||||||
|
|
||||||
|
- Reduces storage from 450GB → ~70GB (fits in RAM for caching)
|
||||||
|
- Preserves enough detail for 224×224 RandomResizedCrop augmentation
|
||||||
|
- JPEG is hardware-accelerated (libjpeg-turbo) — fastest decode path
|
||||||
|
- Single format (no more .png/.webp mixed loading)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# resize_and_convert.py
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
|
||||||
|
MAX_SIZE = 512
|
||||||
|
QUALITY = 90
|
||||||
|
|
||||||
|
def process_image(src_path, dst_path):
|
||||||
|
img = Image.open(src_path)
|
||||||
|
# Resize so max dimension = MAX_SIZE, preserving aspect ratio
|
||||||
|
w, h = img.size
|
||||||
|
if max(w, h) > MAX_SIZE:
|
||||||
|
ratio = MAX_SIZE / max(w, h)
|
||||||
|
img = img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
|
||||||
|
# Convert to RGB (handles RGBA PNGs)
|
||||||
|
if img.mode != 'RGB':
|
||||||
|
img = img.convert('RGB')
|
||||||
|
# Save as JPEG
|
||||||
|
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
|
||||||
|
img.save(dst_path, 'JPEG', quality=QUALITY, optimize=True)
|
||||||
|
|
||||||
|
# Run in parallel (Strix Halo has many cores)
|
||||||
|
Parallel(n_jobs=16)(
|
||||||
|
delayed(process_image)(src, dst)
|
||||||
|
for src, dst in image_pairs
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Time estimate on Strix Halo**: ~2-3 hours to resize + convert 1.47M images with 16 parallel workers. Each image takes ~5-10ms with PIL+LANCZOS.
|
||||||
|
|
||||||
|
### 1.5 Data quality checks
|
||||||
|
|
||||||
|
- **Label noise**: Run confidence learning (CleanLab) on a sample to estimate mislabel rate. Web-scraped datasets typically have 8-15% label noise.
|
||||||
|
- **Duplicate detection**: Check for near-duplicate images (perceptual hashing + Hamming distance) within each class.
|
||||||
|
- **Format consistency**: Ensure all images decode successfully; remove corrupted files.
|
||||||
|
- **Background bias**: Verify that no single background dominates a class (subset and eyeball a random grid per class).
|
||||||
|
|
||||||
|
## Edge Cases & Gotchas
|
||||||
|
|
||||||
|
- **Multi-word plant names**: "acorn-squash", "fiddle-leaf-fig", "chili-pepper" — the disease suffix must match the end of the string, not a substring in the plant name. Sorting disease IDs by length (longest first) handles this.
|
||||||
|
- **Disease-less "healthy" dirs**: Need to ensure "healthy" is in the disease list as a valid class (index 0 in current model). Some dirs may be `{plant}-healthy`.
|
||||||
|
- **Cross-platform path length**: Some species+disease combos produce long paths. Use relative symlinks or shorten names if needed on Windows.
|
||||||
|
- **Original files preserved**: The existing `data/dataset/` structure stays untouched; `data/organized/` is a copy.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
- [ ] `data/organized/train/` has same total image count as original (minus val split)
|
||||||
|
- [ ] Every class has at least 50 training images
|
||||||
|
- [ ] `species_index.json` covers all 11,499 classes
|
||||||
|
- [ ] No files in both train/ and val/ (no overlap)
|
||||||
|
- [ ] All images readable (no corrupted files)
|
||||||
|
- [ ] Train/val split ratios consistent across all classes (±2%)
|
||||||
309
tasks/hierarchical-model-upgrade/02-hierarchical-training.md
Normal file
309
tasks/hierarchical-model-upgrade/02-hierarchical-training.md
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
# Phase 2 — Hierarchical Model Training
|
||||||
|
|
||||||
|
**Blocked by**: Phase 1 (dataset reorganization)
|
||||||
|
**Blocks**: Phase 3 (export)
|
||||||
|
**Est. time**: 3-5 days on Strix Halo (ROCm), or 4-6 days on RTX 3090 (CUDA)
|
||||||
|
**Machine**: Strix Halo preferred (128GB unified memory + 8TB NVMe at 7,300 MB/s read — SSD is fast enough to stream entire dataset in ~62s)
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Train a hierarchical Swin-Tiny model with two classification heads:
|
||||||
|
|
||||||
|
1. **Species head** (~320 classes) — identifies the plant
|
||||||
|
2. **Disease heads** (one per species, 30-300 classes each) — identifies the disease
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Input Image (224×224×3)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ Swin-Tiny Backbone │ ← pretrained on ImageNet-21K
|
||||||
|
│ (timm library) │ optional: fine-tune on iNaturalist
|
||||||
|
│ output: 768-dim │
|
||||||
|
└──────────┬───────────┘
|
||||||
|
│
|
||||||
|
┌──────┴──────┐
|
||||||
|
▼ ▼
|
||||||
|
┌─────────┐ ┌───────────────┐
|
||||||
|
│ Species │ │ Disease Head │
|
||||||
|
│ Head │ │ (routed by │
|
||||||
|
│ 320 cls │ │ species ID) │
|
||||||
|
└────┬────┘ └───────┬───────┘
|
||||||
|
│ │
|
||||||
|
▼ ▼
|
||||||
|
Species ID Disease ID
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# On Strix Halo (ROCm)
|
||||||
|
python3 -m venv .hierarchical-venv
|
||||||
|
source .hierarchical-venv/bin/activate
|
||||||
|
|
||||||
|
# ROCm PyTorch (install from https://pytorch.org/get-started/locally/)
|
||||||
|
# ROCm 6.x + PyTorch 2.5+
|
||||||
|
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2
|
||||||
|
|
||||||
|
# Training libs
|
||||||
|
pip install pytorch-lightning timm transformers wandb
|
||||||
|
pip install albumentations opencv-python pillow
|
||||||
|
|
||||||
|
# Data loading (for SSD-optimized streaming)
|
||||||
|
pip install webdataset fsspec # optional benchmarks
|
||||||
|
```
|
||||||
|
|
||||||
|
**Alternative (RTX 3090 CUDA path)**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Protocol
|
||||||
|
|
||||||
|
### Stage A: Species Classifier (2 days)
|
||||||
|
|
||||||
|
| Step | Epochs | LR | Batch Size | Details |
|
||||||
|
| -------------- | ------ | ----------- | ------------------------ | ----------------------------------------------- |
|
||||||
|
| Head warmup | 5 | 1e-3 | 512 (Strix) / 256 (3090) | Backbone frozen, train only species head |
|
||||||
|
| Full fine-tune | 20 | 1e-4 → 1e-6 | 512 / 256 | Unfreeze backbone, cosine LR schedule |
|
||||||
|
| Stage final | 5 | 5e-6 | 512 / 256 | Discriminative LR: backbone layers 0.1× head LR |
|
||||||
|
|
||||||
|
**Loss**: Focal Loss (γ=2.0, α=0.25) — handles any class imbalance in species distribution.
|
||||||
|
|
||||||
|
**Augmentation — Image Jittering, Degradation & Robustness**:
|
||||||
|
|
||||||
|
Real-world plant photos vary dramatically: different cameras, lighting conditions, angles, weather, focus quality, and compression artifacts. **Augmentation is not optional — it's essential for generalization.** The more varied your augmentation, the more robust your model will be when deployed.
|
||||||
|
|
||||||
|
**Three tiers of augmentation**, all applied on-the-fly (never pre-generated):
|
||||||
|
|
||||||
|
#### Tier 1 — Core Geometric & Photometric (applied to every image)
|
||||||
|
|
||||||
|
These simulate the most common real-world variations:
|
||||||
|
|
||||||
|
| Augmentation | Parameter | Simulates |
|
||||||
|
| ------------------------ | ----------------------------------------------------- | -------------------------------------------------- |
|
||||||
|
| RandomResizedCrop | scale=(0.6, 1.0), ratio=(0.75, 1.33) | Different shooting distances, zoom levels, framing |
|
||||||
|
| HorizontalFlip | p=0.5 | Different leaf orientations (left/right symmetry) |
|
||||||
|
| Rotate | limit=45°, p=0.5 | Off-angle photos, tilted camera |
|
||||||
|
| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 | Different lighting — sunny, overcast, shade, dusk |
|
||||||
|
| RandomBrightnessContrast | brightness_limit=0.2, contrast_limit=0.2, p=0.5 | Exposure variations from auto-exposure cameras |
|
||||||
|
|
||||||
|
#### Tier 2 — Degradation & Quality Simulation (applied to ~30% of images)
|
||||||
|
|
||||||
|
These make the model robust to poor-quality inputs that real users will upload:
|
||||||
|
|
||||||
|
| Augmentation | Parameter | Simulates |
|
||||||
|
| ---------------- | ----------------------------------------- | --------------------------------------------------------- |
|
||||||
|
| GaussianBlur | blur_limit=(3, 7), p=0.2 | Out-of-focus photos, motion blur |
|
||||||
|
| GaussianNoise | var_limit=(10, 50), p=0.15 | Low-light sensor noise, phone camera noise |
|
||||||
|
| ImageCompression | quality_lower=60, quality_upper=95, p=0.2 | JPEG artifacts from compression, social media re-encoding |
|
||||||
|
| RandomGrayscale | p=0.05 | Monochrome cameras, infrared plant imaging |
|
||||||
|
| RandomShadow | shadow_roi=(0, 1, 0, 1), p=0.15 | Leaves in shadow of other leaves/structures outdoors |
|
||||||
|
|
||||||
|
#### Tier 3 — Advanced Regularization (applied at batch level)
|
||||||
|
|
||||||
|
These are cutting-edge techniques that significantly improve generalization on fine-grained classification:
|
||||||
|
|
||||||
|
| Technique | Parameter | Effect |
|
||||||
|
| ------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| **MixUp** | α=0.2 | Blends two random images + labels linearly. Forces the model to learn smooth decision boundaries. Proven +4.8% improvement on rare plant diseases. |
|
||||||
|
| **CutMix** | α=1.0 | Replaces a random patch of one image with another. Forces the model to focus on local lesion features rather than overall leaf shape. |
|
||||||
|
| **RandAugment** | N=2, M=9 | Auto-selected augmentation policy. 14 operations randomly chosen (shear, translate, rotate, contrast, etc.). N=2 ops per image, magnitude 9 (on 0-10 scale). |
|
||||||
|
| **Label Smoothing** | ε=0.1 | Prevents overconfidence on training classes, improves calibration on unseen diseases. |
|
||||||
|
|
||||||
|
**Implementation (albumentations)**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import albumentations as A
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
import kornia.augmentation as K # for GPU-based MixUp/CutMix
|
||||||
|
|
||||||
|
# Core spatial + photometric (Tier 1+2)
|
||||||
|
train_transform = A.Compose([
|
||||||
|
A.RandomResizedCrop(224, 224, scale=(0.6, 1.0), ratio=(0.75, 1.33)),
|
||||||
|
A.HorizontalFlip(p=0.5),
|
||||||
|
A.Rotate(limit=45, p=0.5),
|
||||||
|
A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
|
||||||
|
# Degradation (Tier 2) — only some images get these
|
||||||
|
A.OneOf([
|
||||||
|
A.GaussianBlur(blur_limit=(3, 7), p=1.0),
|
||||||
|
A.GaussianNoise(var_limit=(10, 50), p=1.0),
|
||||||
|
A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
|
||||||
|
], p=0.3),
|
||||||
|
A.ImageCompression(quality_lower=60, quality_upper=95, p=0.2),
|
||||||
|
A.RandomShadow(shadow_roi=(0, 0.5, 0.5, 1), num_shadows_lower=1, num_shadows_upper=2, p=0.15),
|
||||||
|
A.RandomGrayscale(p=0.05),
|
||||||
|
# Normalize
|
||||||
|
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
||||||
|
ToTensorV2(),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Validation — minimal, deterministic
|
||||||
|
val_transform = A.Compose([
|
||||||
|
A.Resize(224, 224),
|
||||||
|
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
||||||
|
ToTensorV2(),
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
**MixUp/CutMix (implemented in training loop, applied on GPU)**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def mixup_criterion(criterion, pred, y_a, y_b, lam):
|
||||||
|
return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
|
||||||
|
|
||||||
|
for images, labels in dataloader:
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
|
||||||
|
if use_mixup and np.random.random() < 0.5:
|
||||||
|
# MixUp: blend images and labels
|
||||||
|
lam = np.random.beta(0.2, 0.2)
|
||||||
|
indices = torch.randperm(images.size(0)).to(device)
|
||||||
|
mixed_images = lam * images + (1 - lam) * images[indices]
|
||||||
|
|
||||||
|
logits = model(mixed_images)
|
||||||
|
loss = mixup_criterion(criterion, logits, labels, labels[indices], lam)
|
||||||
|
else:
|
||||||
|
logits = model(images)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
```
|
||||||
|
|
||||||
|
## SSD Data Loading Strategy
|
||||||
|
|
||||||
|
**Important**: The full dataset is ~70GB (after Phase 1 resizing), which exceeds the 128GB RAM when accounting for OS, GPU memory, workspace, and model weights. However, **your 8TB NVMe at 7,300 MB/s read changes everything.**
|
||||||
|
|
||||||
|
| Metric | Value |
|
||||||
|
| ----------------------------------------- | --------------- |
|
||||||
|
| Dataset size (after resize) | ~70 GB |
|
||||||
|
| NVMe read speed | 7,300 MB/s |
|
||||||
|
| Sequential read time (full dataset) | **~10 seconds** |
|
||||||
|
| Random read (1000 random files, 4KB each) | ~0.5ms seek |
|
||||||
|
|
||||||
|
**Key insight**: The GPU consumes batches slower than the SSD can deliver them. With `num_workers=8`, each worker reads ~35 images/s from random positions. At 7,300 MB/s sequential, the SSD can serve 150,000+ images/s. The bottleneck is **JPEG decode + augmentation**, not disk I/O.
|
||||||
|
|
||||||
|
**Recommended DataLoader configuration**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=256,
|
||||||
|
shuffle=True,
|
||||||
|
# SSD-optimized settings:
|
||||||
|
num_workers=8, # 8 parallel readers — enough to saturate GPU
|
||||||
|
prefetch_factor=4, # Each worker prefetches 4 batches ahead
|
||||||
|
pin_memory=True, # Faster CPU→GPU transfer via DMA
|
||||||
|
persistent_workers=True, # Keep workers alive between epochs (avoid fork overhead)
|
||||||
|
drop_last=True, # Drop incomplete final batch for consistent batch norm
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why not load everything into RAM?**
|
||||||
|
|
||||||
|
- 128GB total memory — after OS (~8GB), GPU reserved (~4-8GB ROCm), model weights + optimizer states (~4GB), augmentation workspace (~2GB), you have ~100GB free
|
||||||
|
- 70GB dataset would barely fit, but leaves no room for caching augmentation results or handling spikes
|
||||||
|
- Better approach: let the NVMe + DataLoader pipeline stream data. At 7,300 MB/s, reading a batch of 256 images (~50MB) takes ~7ms. Meanwhile, the GPU takes ~200ms to process that batch. **The disk is 30× faster than the GPU — you will never be I/O bound.**
|
||||||
|
|
||||||
|
**Optional: Use WebDataset for maximum throughput**
|
||||||
|
|
||||||
|
WebDataset shards the dataset into large tar files (~1GB each), which are sequentially read. This eliminates random seek overhead entirely — ideal when running at massive scale. For your setup it's optional (raw files on NVMe are already fast enough), but worth considering if you scale to multi-GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install webdataset
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
import webdataset as wds
|
||||||
|
|
||||||
|
urls = "data/organized/train/shard-{000000..000099}.tar"
|
||||||
|
dataset = wds.WebDataset(urls).shuffle(10000).decode("pil").to_tuple("jpg", "cls").map(augment)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Profiling check**: During training, monitor GPU utilization:
|
||||||
|
|
||||||
|
- `nvidia-smi` / `rocm-smi` — GPU-Util should be >90%
|
||||||
|
- If <70%, GPU is waiting for data → increase `num_workers` or `prefetch_factor`
|
||||||
|
- If >95%, data pipeline is keeping up → optimal
|
||||||
|
|
||||||
|
### Stage B: Disease Classifiers (2-3 days)
|
||||||
|
|
||||||
|
| Step | Epochs | LR | Details |
|
||||||
|
| ----------------- | ------ | ---- | ------------------------------------------------------- |
|
||||||
|
| All disease heads | 15 | 1e-3 | Backbone frozen, train all disease heads simultaneously |
|
||||||
|
| Rare-class boost | 5 | 5e-4 | Oversample classes with <80 images |
|
||||||
|
| End-to-end | 10 | 1e-5 | Unfreeze backbone, joint species + disease loss |
|
||||||
|
|
||||||
|
**Key design**: Disease heads are simple linear layers (768 → num_diseases_for_species). Since they share the backbone, inference is efficient — one forward pass through Swin-Tiny, then route the 768-dim feature vector to the correct head.
|
||||||
|
|
||||||
|
**Class balancing**: Use weighted sampler for disease heads — classes with <80 images get 3× sampling weight, classes with <50 images get 5×.
|
||||||
|
|
||||||
|
**Loss weighting**: `L_total = L_species + 0.7 * L_disease` — species loss has higher weight since disease prediction depends on correct species ID.
|
||||||
|
|
||||||
|
## Model Checkpointing
|
||||||
|
|
||||||
|
```
|
||||||
|
checkpoints/
|
||||||
|
├── species_only/ # Stage A checkpoints
|
||||||
|
│ ├── epoch=05-val_loss=0.42.ckpt
|
||||||
|
│ ├── epoch=15-val_loss=0.18.ckpt
|
||||||
|
│ └── epoch=25-best.ckpt
|
||||||
|
├── disease_heads/ # Stage B initial disease heads
|
||||||
|
│ ├── disease_heads_epoch=15.pt
|
||||||
|
│ └── disease_heads_final.pt
|
||||||
|
└── hierarchical_full/ # End-to-end
|
||||||
|
├── epoch=05.ckpt
|
||||||
|
└── epoch=10-best.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
Save every checkpoint with species accuracy, macro F1, and per-disease F1 for the tail classes.
|
||||||
|
|
||||||
|
## Expected Training Time (RTX 3090 baseline)
|
||||||
|
|
||||||
|
| Stage | Epochs | Time/Epoch | Total |
|
||||||
|
| ----------------------- | ------ | ---------- | ---------- |
|
||||||
|
| Species head warmup | 5 | ~18 min | 1.5 hr |
|
||||||
|
| Species full fine-tune | 20 | ~45 min | 15 hr |
|
||||||
|
| Species fine-tune final | 5 | ~45 min | 3.75 hr |
|
||||||
|
| Disease heads | 15 | ~30 min | 7.5 hr |
|
||||||
|
| Disease rare-class | 5 | ~35 min | 3 hr |
|
||||||
|
| End-to-end | 10 | ~50 min | 8.5 hr |
|
||||||
|
| **Total** | **60** | | **~39 hr** |
|
||||||
|
|
||||||
|
On **Strix Halo with NVMe + ROCm** the time/epoch should be **significantly faster** due to:
|
||||||
|
|
||||||
|
- 7,300 MB/s NVMe: data loads faster than GPU can consume it (zero I/O wait)
|
||||||
|
- Larger batch sizes (512 vs 256): fewer iterations per epoch
|
||||||
|
- ROCm 6.x has strong PyTorch performance on AMD GPUs
|
||||||
|
- 128GB RAM allows large prefetch buffer for seamless streaming
|
||||||
|
|
||||||
|
Expect **~20-28 hours** total on Strix Halo.
|
||||||
|
|
||||||
|
## Evaluation Metrics
|
||||||
|
|
||||||
|
| Metric | Target | Measurement |
|
||||||
|
| -------------------------------- | ------ | --------------------------------------- | ---------------- |
|
||||||
|
| Species top-1 accuracy | ≥95% | Fraction of correct species predictions |
|
||||||
|
| Disease top-1 accuracy | ≥88% | Across all species-conditioned heads |
|
||||||
|
| Disease top-3 accuracy | ≥94% | Model correct if disease is in top 3 |
|
||||||
|
| Macro F1 (rare diseases) | ≥80% | Weighted average across tail classes |
|
||||||
|
| Species→Disease cascade accuracy | ≥90% | P(correct species) × P(correct disease | correct species) |
|
||||||
|
|
||||||
|
## Edge Cases & Gotchas
|
||||||
|
|
||||||
|
- **GPU memory on RTX 3090 (24GB)**: Swin-Tiny at 224px with batch size 256 + mixed precision should fit. If not, reduce to 128 or use gradient accumulation (accumulate 2 steps).
|
||||||
|
- **Strix Halo ROCm quirks**: `torch.compile()` may have issues on ROCm 6.2 — test without it first. Some `timm` model ops may need ROCm kernel fallbacks; test the forward pass before starting training.
|
||||||
|
- **Checkpoint compatibility**: Save in pure PyTorch format (`.pt`), not Lightning-specific, so they're loadable outside Lightning for export.
|
||||||
|
- **Disease head memory**: 320 separate linear layers sounds large, but each is 768×N_diseases (avg ~300 → 230K params). Total disease head params: ~70M (vs 28M for backbone). This is fine — compute is dominated by the backbone.
|
||||||
|
- **Loss divergence on rare diseases**: Monitor individual disease loss curves; if a tail class diverges, reduce its learning rate or use gradient clipping (max_norm=1.0).
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
- [ ] Species classifier ≥95% top-1 on val set
|
||||||
|
- [ ] Disease top-3 accuracy ≥94% on val set
|
||||||
|
- [ ] Confusion matrix shows no systematic species misclassifications
|
||||||
|
- [ ] Per-species disease classifiers all converge (no NaN losses)
|
||||||
|
- [ ] Tail classes (≤80 images) have F1 ≥70%
|
||||||
|
- [ ] Model can be loaded from checkpoint and run inference in PyTorch
|
||||||
|
- [ ] No OOM errors during training
|
||||||
165
tasks/hierarchical-model-upgrade/03-export-quantization.md
Normal file
165
tasks/hierarchical-model-upgrade/03-export-quantization.md
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
# Phase 3 — ONNX Export & Quantization
|
||||||
|
|
||||||
|
**Blocked by**: Phase 2 (trained model)
|
||||||
|
**Blocks**: Phase 4 (server inference)
|
||||||
|
**Est. time**: 1-2 days
|
||||||
|
**Machine**: Any (RTX 3090 recommended for ONNX GPU validation)
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.
|
||||||
|
|
||||||
|
## Deliverables
|
||||||
|
|
||||||
|
```
|
||||||
|
public/models/
|
||||||
|
├── swin-species.onnx # FP16 species model (3.2 MB)
|
||||||
|
├── swin-species-int8.onnx # INT8 quantized species model (1.1 MB)
|
||||||
|
├── disease-heads/ # One ONNX per species
|
||||||
|
│ ├── tomato-int8.onnx
|
||||||
|
│ ├── acorn-squash-int8.onnx
|
||||||
|
│ └── ...
|
||||||
|
├── disease-heads-list.json # Maps species ID → ONNX file path
|
||||||
|
├── ood-detector.pkl # Mahalanobis parameters for OOD
|
||||||
|
└── onnx-metadata.json # Input/output shapes, versions
|
||||||
|
```
|
||||||
|
|
||||||
|
## Steps
|
||||||
|
|
||||||
|
### 3.1 Export backbone + species head as single ONNX
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import onnx
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Export end-to-end species model (backbone + species head)
|
||||||
|
dummy = torch.randn(1, 3, 224, 224)
|
||||||
|
torch.onnx.export(
|
||||||
|
model, # Combined forward: backbone + species_head
|
||||||
|
dummy,
|
||||||
|
"public/models/swin-species.onnx",
|
||||||
|
input_names=["input"],
|
||||||
|
output_names=["species_logits", "embedding"],
|
||||||
|
dynamic_axes={
|
||||||
|
"input": {0: "batch_size"},
|
||||||
|
"species_logits": {0: "batch_size"},
|
||||||
|
"embedding": {0: "batch_size"},
|
||||||
|
},
|
||||||
|
opset_version=17,
|
||||||
|
do_constant_folding=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key**: Export the 768-dim `embedding` as a second output — the server needs it to route to the correct disease head.
|
||||||
|
|
||||||
|
### 3.2 Export disease heads individually
|
||||||
|
|
||||||
|
Each disease head is a simple `nn.Linear(768, N_diseases)`. Export as a mini-ONNX that takes the embedding and returns disease logits:
|
||||||
|
|
||||||
|
```python
|
||||||
|
for species_name, head in model.disease_heads.items():
|
||||||
|
dummy_embed = torch.randn(1, 768)
|
||||||
|
torch.onnx.export(
|
||||||
|
head, dummy_embed,
|
||||||
|
f"public/models/disease-heads/{species_name}.onnx",
|
||||||
|
input_names=["embedding"],
|
||||||
|
output_names=["disease_logits"],
|
||||||
|
dynamic_axes={"embedding": {0: "batch_size"}},
|
||||||
|
opset_version=17,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Total**: ~320 small ONNX files, each ~50-200 KB.
|
||||||
|
|
||||||
|
### 3.3 INT8 Quantization
|
||||||
|
|
||||||
|
Use ONNX Runtime's quantization tooling:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from onnxruntime.quantization import quantize_dynamic, QuantType
|
||||||
|
|
||||||
|
# Quantize species model
|
||||||
|
quantize_dynamic(
|
||||||
|
"public/models/swin-species.onnx",
|
||||||
|
"public/models/swin-species-int8.onnx",
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quantize disease heads (batch)
|
||||||
|
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
|
||||||
|
quantize_dynamic(
|
||||||
|
str(onnx_path),
|
||||||
|
str(onnx_path.with_suffix("-int8.onnx")),
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Accuracy impact**: INT8 quantization typically causes <1% accuracy drop when using dynamic quantization on the linear/embedding layers. The Swin-Tiny attention layers are less affected than CNN layers.
|
||||||
|
|
||||||
|
### 3.4 OOD Detector
|
||||||
|
|
||||||
|
Train a Mahalanobis distance-based OOD detector on the training set embeddings:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import numpy as np
|
||||||
|
from scipy.spatial.distance import mahalanobis
|
||||||
|
|
||||||
|
# Collect embeddings from training set
|
||||||
|
embeddings = []
|
||||||
|
for batch in val_dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
_, emb = model(batch["image"])
|
||||||
|
embeddings.append(emb.numpy())
|
||||||
|
embeddings = np.vstack(embeddings)
|
||||||
|
|
||||||
|
# Fit multivariate Gaussian
|
||||||
|
mean = np.mean(embeddings, axis=0)
|
||||||
|
cov = np.cov(embeddings, rowvar=False)
|
||||||
|
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
|
||||||
|
|
||||||
|
# Save for inference
|
||||||
|
import pickle
|
||||||
|
with open("public/models/ood-detector.pkl", "wb") as f:
|
||||||
|
pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": 95.0}, f)
|
||||||
|
```
|
||||||
|
|
||||||
|
The threshold (95th percentile of training set Mahalanobis distances) rejects non-plant images. If a test image has a distance > threshold, reject it as OOD.
|
||||||
|
|
||||||
|
### 3.5 Accuracy verification
|
||||||
|
|
||||||
|
Before committing to ONNX, verify against PyTorch:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
# Compare PyTorch vs ONNX outputs
|
||||||
|
pytorch_out = model(sample_image)
|
||||||
|
ort_out = ort.InferenceSession("swin-species-int8.onnx").run(
|
||||||
|
["species_logits"], {"input": sample_image.numpy()}
|
||||||
|
)
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(pytorch_out.numpy() - ort_out[0]))
|
||||||
|
assert max_diff < 0.01, f"ONNX mismatch: {max_diff}"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Edge Cases & Gotchas
|
||||||
|
|
||||||
|
- **ONNX opset compatibility**: Some `timm` model ops (like `roll` in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19.
|
||||||
|
- **Dynamic axes**: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions.
|
||||||
|
- **Disease head routing**: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
|
||||||
|
- **Strix Halo ROCm + ONNX**: ONNX Runtime supports ROCm via DirectML or MIGraphX backends. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both.
|
||||||
|
- **Disease head file count**: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference).
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
- [ ] ONNX species model output matches PyTorch output (max diff < 0.01)
|
||||||
|
- [ ] INT8 accuracy within 1% of FP16 on val set (sample 10K images)
|
||||||
|
- [ ] ONNX model loads in ONNX Runtime without errors
|
||||||
|
- [ ] All 320+ disease heads export successfully
|
||||||
|
- [ ] OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
|
||||||
|
- [ ] ONNX model size < 5MB (INT8) for species, < 200KB per disease head
|
||||||
|
- [ ] Inference on CPU (Strix Halo) < 200ms for species + disease combined
|
||||||
211
tasks/hierarchical-model-upgrade/04-server-inference.md
Normal file
211
tasks/hierarchical-model-upgrade/04-server-inference.md
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
# Phase 4 — Server Inference Pipeline
|
||||||
|
|
||||||
|
**Blocked by**: Phase 3 (exported ONNX models)
|
||||||
|
**Blocks**: Phase 5 (hybrid integration)
|
||||||
|
**Est. time**: 2-3 days
|
||||||
|
**Machine**: Strix Halo (will serve inference in production)
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/identify
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ 1. Preprocess │ Load image → resize to 224×224 → NCHW tensor
|
||||||
|
│ (sharp + buffer) │ Normalize with ImageNet stats
|
||||||
|
└──────────┬───────────┘
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ 2. OOD Detection │ Extract embedding from Swin-Tiny (if species model loaded)
|
||||||
|
│ (Mahalanobis) │ Compute Mahalanobis distance → reject if > threshold
|
||||||
|
└──────────┬───────────┘
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ 3. Species Inference │ Run swin-species-int8.onnx
|
||||||
|
│ (ONNX Runtime) │ softmax over 320 species logits
|
||||||
|
│ │ Return top-1 species + embedding vector
|
||||||
|
└──────────┬───────────┘
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ 4. Disease Routing │ Look up disease head ONNX for predicted species
|
||||||
|
│ (species→head) │ Feed embedding through disease head
|
||||||
|
│ │ softmax over species-conditional disease logits
|
||||||
|
└──────────┬───────────┘
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ 5. Enrichment │ Map class indices → disease/plant objects from
|
||||||
|
│ (knowledge base) │ src/data/diseases.json and src/data/plants.json
|
||||||
|
│ │ Return top-K with treatment info
|
||||||
|
└──────────┬───────────┘
|
||||||
|
▼
|
||||||
|
JSON Response
|
||||||
|
```
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── lib/
|
||||||
|
│ ├── server/
|
||||||
|
│ │ ├── inference-server.ts ← Main orchestration
|
||||||
|
│ │ ├── onnx-loader.ts ← ONNX runtime session manager
|
||||||
|
│ │ ├── ood-detector.ts ← Mahalanobis OOD detection
|
||||||
|
│ │ ├── species-classifier.ts ← Species ONNX inference
|
||||||
|
│ │ ├── disease-classifier.ts ← Disease head routing + inference
|
||||||
|
│ │ └── image-preprocessor.ts ← Sharp-based preprocessing
|
||||||
|
│ └── ml/
|
||||||
|
│ ├── inference.ts ← Existing browser inference (kept as-is)
|
||||||
|
│ └── ...
|
||||||
|
└── app/
|
||||||
|
├── api/
|
||||||
|
│ ├── identify/route.ts ← Existing endpoint (keep for backward compat)
|
||||||
|
│ └── identify-v2/route.ts ← New server-side endpoint
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Implementation Details
|
||||||
|
|
||||||
|
### 4.1 ONNX Session Manager
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// src/lib/server/onnx-loader.ts
|
||||||
|
import ort from "onnxruntime-node";
|
||||||
|
|
||||||
|
const sessions = new Map<string, ort.InferenceSession>();
|
||||||
|
|
||||||
|
export async function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
|
||||||
|
if (!sessions.has(path)) {
|
||||||
|
sessions.set(
|
||||||
|
path,
|
||||||
|
await ort.InferenceSession.create(path, {
|
||||||
|
executionProviders: ["cpu"], // or ['rocm', 'cpu'] on Strix Halo
|
||||||
|
graphOptimizationLevel: "all",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return sessions.get(path)!;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Execution provider**: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If ROCm-specific providers (MIGraphX, DirectML) are available on Strix Halo, test GPU execution for the species model (the compute-heavy part).
|
||||||
|
|
||||||
|
### 4.2 Lazy Loading Strategy
|
||||||
|
|
||||||
|
Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const diseaseHeadCache = new Map<string, ort.InferenceSession>();
|
||||||
|
|
||||||
|
async function getDiseaseHead(speciesName: string) {
|
||||||
|
if (!diseaseHeadCache.has(speciesName)) {
|
||||||
|
const path = `public/models/disease-heads/${speciesName}-int8.onnx`;
|
||||||
|
diseaseHeadCache.set(speciesName, await createSession(path));
|
||||||
|
}
|
||||||
|
return diseaseHeadCache.get(speciesName)!;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 Main Inference Pipeline
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// src/lib/server/inference-server.ts
|
||||||
|
export async function serverIdentify(imageBuffer: Buffer): Promise<InferenceResult> {
|
||||||
|
const start = performance.now();
|
||||||
|
|
||||||
|
// 1. Preprocess
|
||||||
|
const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224]
|
||||||
|
|
||||||
|
// 2. OOD detection (quick, using embedding from species model)
|
||||||
|
const oodResult = await oodDetect(tensor);
|
||||||
|
if (!oodResult.isPlant) {
|
||||||
|
return {
|
||||||
|
error: "No plant detected",
|
||||||
|
confidence: 1 - oodResult.mahalanobisDistance / oodResult.threshold,
|
||||||
|
inferenceTimeMs: Math.round(performance.now() - start),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Species inference
|
||||||
|
const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx");
|
||||||
|
const speciesOutput = await speciesSession.run({
|
||||||
|
input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
|
||||||
|
});
|
||||||
|
const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array);
|
||||||
|
const speciesProbs = softmax(speciesLogits);
|
||||||
|
const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0];
|
||||||
|
const embedding = speciesOutput.embedding.data as Float32Array;
|
||||||
|
|
||||||
|
// 4. Disease inference (routed by species)
|
||||||
|
const speciesName = speciesIndex[topSpeciesIdx];
|
||||||
|
const diseaseSession = await getDiseaseHead(speciesName);
|
||||||
|
const diseaseOutput = await diseaseSession.run({
|
||||||
|
embedding: new ort.Tensor("float32", embedding, [1, 768]),
|
||||||
|
});
|
||||||
|
const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array);
|
||||||
|
const diseaseProbs = softmax(diseaseLogits);
|
||||||
|
const topDiseases = topK(diseaseProbs, 5);
|
||||||
|
|
||||||
|
// 5. Enrichment
|
||||||
|
const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);
|
||||||
|
|
||||||
|
return {
|
||||||
|
species: { id: speciesName, confidence: topSpeciesProb },
|
||||||
|
diseases: enriched,
|
||||||
|
oodScore: oodResult.mahalanobisDistance,
|
||||||
|
inferenceTimeMs: Math.round(performance.now() - start),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 API Route
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// src/app/api/identify-v2/route.ts
|
||||||
|
export async function POST(req: Request) {
|
||||||
|
const formData = await req.formData();
|
||||||
|
const image = formData.get("image") as File;
|
||||||
|
|
||||||
|
if (!image || !image.type.startsWith("image/")) {
|
||||||
|
return Response.json({ error: "Invalid image" }, { status: 400 });
|
||||||
|
}
|
||||||
|
|
||||||
|
const buffer = Buffer.from(await image.arrayBuffer());
|
||||||
|
const result = await serverIdentify(buffer);
|
||||||
|
|
||||||
|
return Response.json(result);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 Caching Strategy
|
||||||
|
|
||||||
|
- **Model sessions**: Cache ONNX sessions in memory (warm on first request per deployment)
|
||||||
|
- **Disease heads**: Cache top-50 most common species' disease heads (LRU eviction)
|
||||||
|
- **Image preprocessing results**: Do NOT cache — each image is unique
|
||||||
|
- **Response caching**: Optionally cache identical responses for 5 minutes (hash of image buffer, for repeated uploads of same image)
|
||||||
|
|
||||||
|
## Edge Cases & Gotchas
|
||||||
|
|
||||||
|
- **Cold start latency**: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
|
||||||
|
- **Disease head not found**: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return species-only result.
|
||||||
|
- **Large images**: Client may upload 12MP photos. Resize to 224×224 _before_ feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
|
||||||
|
- **Concurrent requests**: ONNX Runtime sessions are thread-safe. Use a connection pool or queue for the species model (1 session handles concurrent `run()` calls).
|
||||||
|
- **Memory**: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. Species model is ~1.1MB (INT8).
|
||||||
|
- **Error handling**: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
- [ ] `POST /api/identify-v2` returns valid JSON with species + disease predictions
|
||||||
|
- [ ] Cold start (first ever request): < 3 seconds (model loading)
|
||||||
|
- [ ] Warm requests: < 200ms total (OOD + species + disease + enrichment)
|
||||||
|
- [ ] OOD detection correctly rejects non-plant images (rocks, buildings, animals)
|
||||||
|
- [ ] OOD detection correctly accepts plant images (false rejection rate < 1%)
|
||||||
|
- [ ] All 320+ species → disease head routes resolve correctly
|
||||||
|
- [ ] Large image (12MP) → preprocessed to 224×224 without OOM
|
||||||
|
- [ ] Concurrent 10 requests handled without errors or slowdown
|
||||||
|
- [ ] Degraded mode works if ONNX model fails (falls back to existing TF.js)
|
||||||
|
- [ ] Health endpoint reports model status, last inference time, error count
|
||||||
208
tasks/hierarchical-model-upgrade/05-browser-hybrid.md
Normal file
208
tasks/hierarchical-model-upgrade/05-browser-hybrid.md
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# Phase 5 — Browser Model & Hybrid Integration
|
||||||
|
|
||||||
|
**Blocked by**: Phase 4 (server inference pipeline)
|
||||||
|
**Est. time**: 2-3 days
|
||||||
|
**Machine**: Any (development on Strix Halo or M3 Pro)
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Train a lightweight browser-compatible model (TF.js) and implement the hybrid routing logic: fast first pass in-browser, server fallback when confidence is low.
|
||||||
|
|
||||||
|
## Hybrid Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
User uploads image
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ Browser: │
|
||||||
|
│ EfficientNet-Lite │ ← ~5MB TF.js model in browser
|
||||||
|
│ (TF.js) │ Predicts species + top-5 diseases
|
||||||
|
│ │
|
||||||
|
│ Species confidence? │
|
||||||
|
│ ┌────┴────┐ │
|
||||||
|
│ │ ≥90% │ <90% │
|
||||||
|
│ └────┬────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ Show result │
|
||||||
|
│ (instant) │ │
|
||||||
|
└────────────┼────────┘
|
||||||
|
│ (background if >90%,
|
||||||
|
│ foreground if <90%)
|
||||||
|
▼
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ Server: │
|
||||||
|
│ Full Swin-Tiny │ ← Only when browser is uncertain
|
||||||
|
│ (ONNX Runtime) │ or user requests "detailed analysis"
|
||||||
|
│ │
|
||||||
|
│ Returns enriched │
|
||||||
|
│ results with full │
|
||||||
|
│ treatment info │
|
||||||
|
└──────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Steps
|
||||||
|
|
||||||
|
### 5.1 Train lightweight browser model
|
||||||
|
|
||||||
|
Use the hierarchical training data to train a **EfficientNet-Lite0** model that outputs both species and disease predictions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import timm
|
||||||
|
import tensorflow as tf # For TF.js export
|
||||||
|
|
||||||
|
# Train in PyTorch first (for accuracy), then convert
|
||||||
|
model = timm.create_model('efficientnet_lite0', pretrained=True)
|
||||||
|
# Add: species head (320) + disease head (11,499 flat)
|
||||||
|
# Or use hierarchical with just top-50 diseases per species
|
||||||
|
|
||||||
|
# Training: 10 epochs frozen backbone, 10 epochs fine-tune
|
||||||
|
# Target: <5MB model size, runs in <100ms on mobile device
|
||||||
|
```
|
||||||
|
|
||||||
|
**Export to TF.js**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Convert PyTorch → ONNX → TF.js
|
||||||
|
python -m tf2onnx.convert --pytorch-model browser_model.pt --output browser_model.onnx
|
||||||
|
tensorflowjs_converter --input_format=tf_saved_model browser_model/ browser_tfjs/
|
||||||
|
```
|
||||||
|
|
||||||
|
**Model size target**: < 5MB (EfficientNet-Lite0 is ~4.7MB with INT8 quantization).
|
||||||
|
|
||||||
|
### 5.2 Browser inference integration
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// src/lib/ml/inference.ts — Updated with hybrid routing
|
||||||
|
|
||||||
|
export type InferenceSource = "browser" | "server";
|
||||||
|
export type InferenceMode = "quick" | "detailed";
|
||||||
|
|
||||||
|
export async function identifyPlant(
|
||||||
|
image: HTMLImageElement | File,
|
||||||
|
mode: InferenceMode = "quick",
|
||||||
|
): Promise<InferenceResult> {
|
||||||
|
// 1. Run browser model (always, it's fast)
|
||||||
|
const browserResult = await runBrowserInference(image);
|
||||||
|
|
||||||
|
// 2. Decide: is this confident enough?
|
||||||
|
if (mode === "quick" && browserResult.topConfidence >= 0.9) {
|
||||||
|
// Browser alone is sufficient
|
||||||
|
return {
|
||||||
|
...browserResult,
|
||||||
|
source: "browser",
|
||||||
|
inferenceTimeMs: browserResult.inferenceTimeMs,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fall back to server for detailed analysis
|
||||||
|
const serverResult = await runServerInference(image);
|
||||||
|
|
||||||
|
return {
|
||||||
|
...serverResult,
|
||||||
|
source: "server",
|
||||||
|
browserConfidence: browserResult.topConfidence,
|
||||||
|
serverConfidence: serverResult.topConfidence,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async function runBrowserInference(image: HTMLImageElement): Promise<BrowserResult> {
|
||||||
|
const model = await getBrowserModel(); // Lazy load EfficientNet-Lite
|
||||||
|
const tensor = await preprocessBrowser(image); // TF.js preprocessing
|
||||||
|
const output = await model.predict(tensor);
|
||||||
|
return parseOutput(output);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 UI integration
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// src/components/ImageUpload.tsx — Updated
|
||||||
|
|
||||||
|
function ImageUpload() {
|
||||||
|
const [result, setResult] = useState<InferenceResult | null>(null);
|
||||||
|
const [mode, setMode] = useState<InferenceMode>('quick');
|
||||||
|
const [source, setSource] = useState<InferenceSource | null>(null);
|
||||||
|
|
||||||
|
async function handleUpload(image: File) {
|
||||||
|
// Run browser model (instant)
|
||||||
|
const browserResult = await identifyPlant(image, 'quick');
|
||||||
|
setResult(browserResult);
|
||||||
|
setSource(browserResult.source);
|
||||||
|
|
||||||
|
// If server was called in background, show loading indicator
|
||||||
|
if (browserResult.source === 'server') {
|
||||||
|
// Show "Getting detailed analysis..." spinner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<ImageUploader onUpload={handleUpload} />
|
||||||
|
{result && (
|
||||||
|
<div>
|
||||||
|
<ResultCard result={result} />
|
||||||
|
<ConfidenceBadge
|
||||||
|
confidence={result.topConfidence}
|
||||||
|
source={source} // "browser" or "server"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 User-facing indication
|
||||||
|
|
||||||
|
Show a subtle badge indicating which model made the prediction:
|
||||||
|
|
||||||
|
| Source | Badge | UX |
|
||||||
|
| ------------------- | -------------------- | ------------------------------------- |
|
||||||
|
| Browser (high conf) | ✅ Instant ID | Green badge, "Analyzed on device" |
|
||||||
|
| Server (full model) | 🧠 Detailed Analysis | Blue badge, "Deep analysis" |
|
||||||
|
| Server (fallback) | 🔄 Upgraded | Yellow badge, "Upgraded for accuracy" |
|
||||||
|
|
||||||
|
### 5.5 Progressive enhancement
|
||||||
|
|
||||||
|
The system should degrade gracefully:
|
||||||
|
|
||||||
|
| Scenario | Behavior |
|
||||||
|
| ---------------------------------- | --------------------------------------------------------------------- |
|
||||||
|
| Offline | Browser model only (may be less accurate for unusual diseases) |
|
||||||
|
| Slow network | Browser model shows results immediately, server updates in background |
|
||||||
|
| Server down | Browser model alone, with note: "Limited to quick analysis" |
|
||||||
|
| New disease (not in browser model) | Server model handles it, browser shows "could be unusual" |
|
||||||
|
| No camera / file | Error message, "Upload an image to identify" |
|
||||||
|
|
||||||
|
## Edge Cases & Gotchas
|
||||||
|
|
||||||
|
- **Model loading race**: If the browser model hasn't loaded yet, show a loading spinner rather than falling through to server. Lazy-load the model on page mount.
|
||||||
|
- **Discrepancy between browser and server**: If browser and server disagree on the top prediction, show both with confidence bars. The server model is authoritative.
|
||||||
|
- **Retina / high-DPI images**: TF.js may handle these differently from ONNX. Ensure preprocessing (resize, normalize) produces identical tensors.
|
||||||
|
- **Cache busting**: When the model is updated, increment a version hash in the URL to avoid stale cached models.
|
||||||
|
- **Memory**: EfficientNet-Lite takes ~5MB in memory. Older phones may struggle; add a cleanup step after inference (`model.dispose()`).
|
||||||
|
|
||||||
|
## Performance Targets
|
||||||
|
|
||||||
|
| Metric | Target |
|
||||||
|
| ------------------------------- | -------------------------------- |
|
||||||
|
| Browser model load time (warm) | < 1s |
|
||||||
|
| Browser model inference | < 100ms |
|
||||||
|
| Server model inference (warm) | < 200ms |
|
||||||
|
| Hybrid fast path (browser only) | < 200ms total |
|
||||||
|
| Hybrid server path | < 1.5s total (including network) |
|
||||||
|
| Model file size (browser) | < 5MB |
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
- [ ] Browser model loads in Chrome, Firefox, Safari (desktop + mobile)
|
||||||
|
- [ ] Browser model inference completes in < 100ms on mid-range phone
|
||||||
|
- [ ] Hybrid routing works: conf ≥90% → browser result, conf <90% → server result
|
||||||
|
- [ ] Server fallback fires within 200ms of browser model completing
|
||||||
|
- [ ] UI shows source badge ("Instant ID" vs "Deep Analysis")
|
||||||
|
- [ ] Offline mode: browser model works without network
|
||||||
|
- [ ] Server degraded: system still works with browser model only
|
||||||
|
- [ ] No memory leaks on repeated inferences (10+ images in succession)
|
||||||
|
- [ ] Identical image produces same top prediction on browser and server (within margin)
|
||||||
|
- [ ] All existing tests pass with hybrid pipeline
|
||||||
63
tasks/hierarchical-model-upgrade/README.md
Normal file
63
tasks/hierarchical-model-upgrade/README.md
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Hierarchical Model Architecture Upgrade
|
||||||
|
|
||||||
|
**Scale**: 1.47M images across 11,499 disease-plant classes
|
||||||
|
**Goal**: Replace flat MobileNetV2 (38-class PlantVillage) with hierarchical Swin-Tiny (species → disease)
|
||||||
|
**Deployment**: Hybrid — lightweight browser model (TF.js) + full server model (ONNX Runtime)
|
||||||
|
|
||||||
|
## Hardware
|
||||||
|
|
||||||
|
| Machine | Role | Specs |
|
||||||
|
| -------------- | ------------------------------- | ---------------------------------------- |
|
||||||
|
| **Strix Halo** | Primary training + inference | AI 395+ MAX (ROCm), 128GB unified memory |
|
||||||
|
| **RTX 3090** | Secondary training / CUDA path | 24GB VRAM |
|
||||||
|
| **M3 Pro** | Development only (work machine) | — |
|
||||||
|
|
||||||
|
**Key advantage**: Strix Halo's 128GB unified memory allows loading the entire 1.5M image dataset into RAM and training with extremely large effective batch sizes — the GPU accesses the full 128GB pool, no VRAM ceiling.
|
||||||
|
|
||||||
|
## Status Legend
|
||||||
|
|
||||||
|
```
|
||||||
|
[ ] not started [~] in progress [x] done [-] skipped
|
||||||
|
```
|
||||||
|
|
||||||
|
## Task Map
|
||||||
|
|
||||||
|
```
|
||||||
|
Phase 1 ──→ Phase 2 ──→ Phase 3 ──→ Phase 4 ──→ Phase 5
|
||||||
|
Dataset Model Model Server Integration
|
||||||
|
Reorg Training Export Inference + Testing
|
||||||
|
& Quant. Pipeline
|
||||||
|
```
|
||||||
|
|
||||||
|
## Phases
|
||||||
|
|
||||||
|
- [ ] [Phase 1 — Dataset Reorganization](01-dataset-reorganization.md)
|
||||||
|
Parse 11,499 flat directories into hierarchical species→disease structure, create train/val splits, build species index.
|
||||||
|
- [ ] [Phase 2 — Hierarchical Model Training](02-hierarchical-training.md)
|
||||||
|
Train Swin-Tiny backbone + species head + disease heads using PyTorch + ROCm on Strix Halo.
|
||||||
|
- [ ] [Phase 3 — ONNX Export & Quantization](03-export-quantization.md)
|
||||||
|
Export trained models to ONNX, apply INT8 quantization, verify accuracy.
|
||||||
|
- [ ] [Phase 4 — Server Inference Pipeline](04-server-inference.md)
|
||||||
|
Build server-side inference API with ONNX Runtime, OOD detection, species routing.
|
||||||
|
- [ ] [Phase 5 — Browser Model & Hybrid Integration](05-browser-hybrid.md)
|
||||||
|
Lightweight TF.js model for client, hybrid confidence-based routing, full integration.
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
```
|
||||||
|
01 (dataset) ──→ 02 (training) ──→ 03 (export) ──→ 04 (server)
|
||||||
|
│
|
||||||
|
└──→ 05 (browser + hybrid)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Exit Criteria
|
||||||
|
|
||||||
|
- [ ] Species classifier achieves ≥95% top-1 accuracy on held-out val set
|
||||||
|
- [ ] Disease classifiers achieve ≥90% top-3 accuracy per species
|
||||||
|
- [ ] ONNX INT8 models infer in <200ms on CPU, <50ms on GPU
|
||||||
|
- [ ] Browser TF.js model loads and runs in <100ms on mid-range devices
|
||||||
|
- [ ] Hybrid routing works: high-confidence results served instantly from browser
|
||||||
|
- [ ] Server fallback fires automatically when browser confidence is low
|
||||||
|
- [ ] OOD detection rejects non-plant images with ≥99% precision
|
||||||
|
- [ ] Full integration: upload → result in <500ms (browser) or <1s (server)
|
||||||
|
- [ ] Existing app functionality preserved (all routes, pages, API endpoints)
|
||||||
Reference in New Issue
Block a user