# Phase 4 — Server Inference Pipeline

**Blocked by**: Phase 3 (exported ONNX models)
**Blocks**: Phase 5 (hybrid integration)
**Est. time**: 2-3 days
**Machine**: Strix Halo (will serve inference in production)

## Objective

Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.

## Architecture

```
POST /api/identify
           │
           ▼
┌──────────────────────┐
│ 1. Preprocess        │  Load image → resize to 224×224 → NCHW tensor
│    (sharp + buffer)  │  Normalize with ImageNet stats
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 2. OOD Detection     │  Extract embedding from Swin-Tiny (if species model loaded)
│    (Mahalanobis)     │  Compute Mahalanobis distance → reject if > threshold
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 3. Species Inference │  Run swin-species-int8.onnx
│    (ONNX Runtime)    │  softmax over 320 species logits
│                      │  Return top-1 species + embedding vector
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 4. Disease Routing   │  Look up disease head ONNX for predicted species
│    (species→head)    │  Feed embedding through disease head
│                      │  softmax over species-conditional disease logits
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 5. Enrichment        │  Map class indices → disease/plant objects from
│    (knowledge base)  │  src/data/diseases.json and src/data/plants.json
│                      │  Return top-K with treatment info
└──────────┬───────────┘
           ▼
      JSON Response
```

## File Structure

```
src/
├── lib/
│   ├── server/
│   │   ├── inference-server.ts    ← Main orchestration
│   │   ├── onnx-loader.ts         ← ONNX runtime session manager
│   │   ├── ood-detector.ts        ← Mahalanobis OOD detection
│   │   ├── species-classifier.ts  ← Species ONNX inference
│   │   ├── disease-classifier.ts  ← Disease head routing + inference
│   │   └── image-preprocessor.ts  ← Sharp-based preprocessing
│   └── ml/
│       ├── inference.ts           ← Existing browser inference (kept as-is)
│       └── ...
└── app/
    ├── api/
    │   ├── identify/route.ts      ← Existing endpoint (keep for backward compat)
    │   └── identify-v2/route.ts   ← New server-side endpoint
    └── ...
```

## Key Implementation Details

### 4.1 ONNX Session Manager

```typescript
// src/lib/server/onnx-loader.ts
import * as ort from "onnxruntime-node";

// One InferenceSession per model path, created on first use and reused afterwards.
const sessions = new Map<string, ort.InferenceSession>();

export async function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
  if (!sessions.has(path)) {
    sessions.set(
      path,
      await ort.InferenceSession.create(path, {
        executionProviders: ["cpu"], // or ["rocm", "cpu"] on Strix Halo, if the build supports it
        graphOptimizationLevel: "all",
      }),
    );
  }
  return sessions.get(path)!;
}
```

**Execution provider**: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If a GPU execution provider is available on Strix Halo (ROCm or MIGraphX on Linux, DirectML on Windows), test GPU execution for the species model (the compute-heavy part).

### 4.2 Lazy Loading Strategy

Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:

```typescript
const diseaseHeadCache = new Map<string, ort.InferenceSession>();

async function getDiseaseHead(speciesName: string): Promise<ort.InferenceSession> {
  if (!diseaseHeadCache.has(speciesName)) {
    const path = `public/models/disease-heads/${speciesName}-int8.onnx`;
    // Reuse the session manager from 4.1 so each head is created only once.
    diseaseHeadCache.set(speciesName, await getOrCreateSession(path));
  }
  return diseaseHeadCache.get(speciesName)!;
}
```

### 4.3 Main Inference Pipeline

```typescript
// src/lib/server/inference-server.ts
import * as ort from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";
// preprocessImage, oodDetect, getDiseaseHead, softmax, topK, speciesIndex and
// enrichResults are assumed to come from the sibling modules in src/lib/server/.

export async function serverIdentify(imageBuffer: Buffer) {
  const start = performance.now();

  // 1. Preprocess
  const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224]

  // 2. OOD detection (quick, using embedding from species model)
  const oodResult = await oodDetect(tensor);
  if (!oodResult.isPlant) {
    return {
      error: "No plant detected",
      // Negative whenever distance > threshold; clamp before showing it to users.
      confidence: 1 - oodResult.mahalanobisDistance / oodResult.threshold,
      inferenceTimeMs: Math.round(performance.now() - start),
    };
  }

  // 3. Species inference
  const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx");
  const speciesOutput = await speciesSession.run({
    input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
  });
  const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array);
  const speciesProbs = softmax(speciesLogits);
  const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0];
  const embedding = speciesOutput.embedding.data as Float32Array;

  // 4. Disease inference (routed by species)
  const speciesName = speciesIndex[topSpeciesIdx];
  const diseaseSession = await getDiseaseHead(speciesName);
  const diseaseOutput = await diseaseSession.run({
    embedding: new ort.Tensor("float32", embedding, [1, 768]),
  });
  const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array);
  const diseaseProbs = softmax(diseaseLogits);
  const topDiseases = topK(diseaseProbs, 5);

  // 5. Enrichment
  const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);

  return {
    species: { id: speciesName, confidence: topSpeciesProb },
    diseases: enriched,
    oodScore: oodResult.mahalanobisDistance,
    inferenceTimeMs: Math.round(performance.now() - start),
  };
}
```

### 4.4 API Route

```typescript
// src/app/api/identify-v2/route.ts
import { serverIdentify } from "../../../lib/server/inference-server";

export async function POST(req: Request) {
  const formData = await req.formData();
  const image = formData.get("image");

  // formData.get() returns a string, a File, or null, so narrow before using it.
  if (!image || typeof image === "string" || !image.type.startsWith("image/")) {
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }

  const buffer = Buffer.from(await image.arrayBuffer());
  const result = await serverIdentify(buffer);
  return Response.json(result);
}
```

### 4.5 Caching Strategy

- **Model sessions**: Cache ONNX sessions in memory (warm on first request per deployment)
- **Disease heads**: Cache top-50 most common species' disease heads (LRU eviction)
- **Image preprocessing results**: Do NOT cache; each image is unique
- **Response caching**: Optionally cache identical responses for 5 minutes
  (hash of image buffer, for repeated uploads of the same image)

## Edge Cases & Gotchas

- **Cold start latency**: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
- **Disease head not found**: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return a species-only result.
- **Large images**: Clients may upload 12MP photos. Resize to 224×224 _before_ feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
- **Concurrent requests**: ONNX Runtime sessions are thread-safe, so a single species-model session can serve concurrent `run()` calls. Add a request queue or session pool only if load testing shows contention.
- **Memory**: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. Species model is ~1.1MB (INT8).
- **Error handling**: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.

## Verification

- [ ] `POST /api/identify-v2` returns valid JSON with species + disease predictions
- [ ] Cold start (first ever request): < 3 seconds (model loading)
- [ ] Warm requests: < 200ms total (OOD + species + disease + enrichment)
- [ ] OOD detection correctly rejects non-plant images (rocks, buildings, animals)
- [ ] OOD detection correctly accepts plant images (false rejection rate < 1%)
- [ ] All 320+ species → disease head routes resolve correctly
- [ ] Large image (12MP) → preprocessed to 224×224 without OOM
- [ ] 10 concurrent requests handled without errors or slowdown
- [ ] Degraded mode works if ONNX model fails (falls back to existing TF.js)
- [ ] Health endpoint reports model status, last inference time, error count
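
Step 1 of the architecture asks for an NCHW tensor normalized with ImageNet stats, but the plan never shows the conversion. The following is a minimal sketch of that step only, assuming the image has already been resized and decoded to interleaved 8-bit RGB (for example by sharp's raw output). The function name `hwcToNormalizedNchw` is hypothetical and not part of the existing codebase.

```typescript
// Standard ImageNet per-channel statistics (RGB order, for values scaled to 0..1).
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

// Hypothetical helper: converts interleaved RGB bytes (HWC layout) into a
// planar, normalized Float32Array (CHW layout) ready for a [1,3,H,W] tensor.
export function hwcToNormalizedNchw(
  rgb: Uint8Array,
  width: number,
  height: number,
): Float32Array {
  const pixels = width * height;
  if (rgb.length !== pixels * 3) {
    throw new Error(`expected ${pixels * 3} bytes, got ${rgb.length}`);
  }
  const out = new Float32Array(3 * pixels);
  for (let p = 0; p < pixels; p++) {
    for (let c = 0; c < 3; c++) {
      // Scale to 0..1, then apply (x - mean) / std for this channel.
      const scaled = rgb[p * 3 + c] / 255;
      out[c * pixels + p] = (scaled - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
  }
  return out;
}
```

Because a Node `Buffer` is a `Uint8Array`, the raw RGB buffer produced by the resize step can be passed in directly with `width = height = 224`.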
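
Step 2 rejects an image when its Mahalanobis distance exceeds a threshold. As a rough sketch of that computation, assuming the class mean and the inverse covariance matrix were precomputed offline and shipped alongside the models: the `OodStats` shape and both function names are hypothetical, and a production version for 768-dimensional embeddings may prefer a diagonal or low-rank covariance.

```typescript
// Hypothetical container for precomputed in-distribution statistics.
export interface OodStats {
  mean: Float32Array; // length d
  invCovariance: Float32Array; // d*d values, row-major
  threshold: number; // reject when distance > threshold
}

// Mahalanobis distance: sqrt((x - mean)^T * invCov * (x - mean)).
export function mahalanobisDistance(embedding: Float32Array, stats: OodStats): number {
  const d = stats.mean.length;
  if (embedding.length !== d || stats.invCovariance.length !== d * d) {
    throw new Error("dimension mismatch between embedding and OOD statistics");
  }
  const diff = new Float64Array(d);
  for (let i = 0; i < d; i++) diff[i] = embedding[i] - stats.mean[i];

  let sum = 0;
  for (let i = 0; i < d; i++) {
    let row = 0;
    for (let j = 0; j < d; j++) row += stats.invCovariance[i * d + j] * diff[j];
    sum += diff[i] * row;
  }
  // Guard against tiny negative values caused by rounding.
  return Math.sqrt(Math.max(0, sum));
}

// Matches the diagram's rule: accept unless the distance is above the threshold.
export function isInDistribution(embedding: Float32Array, stats: OodStats): boolean {
  return mahalanobisDistance(embedding, stats) <= stats.threshold;
}
```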
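
The pipeline in 4.3 calls `softmax` and `topK` without defining them. One possible implementation is sketched below, assuming `topK` returns `[index, probability]` pairs sorted by descending probability, which is the shape the destructuring in 4.3 expects.

```typescript
// Numerically stable softmax: subtracting the max logit avoids overflow in exp().
export function softmax(logits: number[]): number[] {
  if (logits.length === 0) return [];
  const max = logits.reduce((m, v) => Math.max(m, v), -Infinity);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Returns up to k [index, probability] pairs, highest probability first.
export function topK(probs: number[], k: number): Array<[number, number]> {
  return probs
    .map((p, i): [number, number] => [i, p])
    .sort((a, b) => b[1] - a[1])
    .slice(0, Math.max(0, k));
}
```

Sorting the full array is adequate for 320 species logits; a partial selection would only matter for much larger outputs.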
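
Section 4.5 calls for LRU eviction of disease heads, while the `Map` in 4.2 grows without bound. A small LRU cache could look like the sketch below, which relies on `Map` preserving insertion order. The class name `LruCache` and the `onEvict` callback are illustrative assumptions, not existing code.

```typescript
// Hypothetical bounded cache: the least recently used entry is evicted first.
export class LruCache<K, V> {
  private readonly entries = new Map<K, V>();
  private readonly capacity: number;
  private readonly onEvict?: (key: K, value: V) => void;

  constructor(capacity: number, onEvict?: (key: K, value: V) => void) {
    if (capacity < 1) throw new Error("capacity must be at least 1");
    this.capacity = capacity;
    this.onEvict = onEvict;
  }

  get(key: K): V | undefined {
    if (!this.entries.has(key)) return undefined;
    const value = this.entries.get(key) as V;
    // Re-insert so this key becomes the most recently used.
    this.entries.delete(key);
    this.entries.set(key, value);
    return value;
  }

  set(key: K, value: V): void {
    this.entries.delete(key);
    this.entries.set(key, value);
    if (this.entries.size > this.capacity) {
      // The first key in iteration order is the least recently used.
      const oldestKey = this.entries.keys().next().value as K;
      const oldestValue = this.entries.get(oldestKey) as V;
      this.entries.delete(oldestKey);
      this.onEvict?.(oldestKey, oldestValue);
    }
  }

  get size(): number {
    return this.entries.size;
  }
}
```

With a capacity of 50 this would match the "top-50" figure in 4.5. If the cached values are ONNX sessions, the eviction callback is a natural place to release the evicted session's native resources.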