Phase 4 — Server Inference Pipeline

Blocked by: Phase 3 (exported ONNX models)
Blocks: Phase 5 (hybrid integration)
Est. time: 2-3 days
Machine: Strix Halo (will serve inference in production)

Objective

Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.

Architecture

POST /api/identify
     │
     ▼
┌──────────────────────┐
│ 1. Preprocess        │  Load image → resize to 224×224 → NCHW tensor
│    (sharp + buffer)  │  Normalize with ImageNet stats
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 2. OOD Detection     │  Extract embedding from Swin-Tiny (if species model loaded)
│    (Mahalanobis)     │  Compute Mahalanobis distance → reject if > threshold
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 3. Species Inference │  Run swin-species-int8.onnx
│    (ONNX Runtime)    │  softmax over 320 species logits
│                      │  Return top-1 species + embedding vector
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 4. Disease Routing   │  Look up disease head ONNX for predicted species
│    (species→head)    │  Feed embedding through disease head
│                      │  softmax over species-conditional disease logits
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 5. Enrichment        │  Map class indices → disease/plant objects from
│    (knowledge base)  │  src/data/diseases.json and src/data/plants.json
│                      │  Return top-K with treatment info
└──────────┬───────────┘
           ▼
     JSON Response
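Step 1 of the diagram can be sketched as a pure function. This is a minimal sketch, assuming the image has already been decoded and resized to interleaved 8-bit RGB (for example via sharp's raw output); `rgbToNchw` is a hypothetical helper name, and the mean and standard deviation are the standard ImageNet values.

```typescript
// Hypothetical helper (not from the original doc): converts interleaved
// 8-bit RGB pixels (HWC order) into a normalized NCHW Float32Array.
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

function rgbToNchw(rgb: Uint8Array, width: number, height: number): Float32Array {
  const pixels = width * height;
  if (rgb.length !== pixels * 3) {
    throw new Error(`expected ${pixels * 3} bytes, got ${rgb.length}`);
  }
  const out = new Float32Array(3 * pixels);
  for (let i = 0; i < pixels; i++) {
    for (let c = 0; c < 3; c++) {
      // Scale to [0, 1], then normalize per channel.
      const value = rgb[i * 3 + c] / 255;
      // Channel planes are contiguous in NCHW: all R, then all G, then all B.
      out[c * pixels + i] = (value - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
  }
  return out;
}
```

With sharp, the input bytes could plausibly come from `sharp(buffer).resize(224, 224, { fit: "cover" }).removeAlpha().raw().toBuffer()`. The `fit` mode is an assumption and should match whatever resizing the training pipeline used.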

File Structure

src/
├── lib/
│   ├── server/
│   │   ├── inference-server.ts      ← Main orchestration
│   │   ├── onnx-loader.ts           ← ONNX runtime session manager
│   │   ├── ood-detector.ts          ← Mahalanobis OOD detection
│   │   ├── species-classifier.ts    ← Species ONNX inference
│   │   ├── disease-classifier.ts    ← Disease head routing + inference
│   │   └── image-preprocessor.ts    ← Sharp-based preprocessing
│   └── ml/
│       ├── inference.ts             ← Existing browser inference (kept as-is)
│       └── ...
└── app/
    ├── api/
    │   ├── identify/route.ts        ← Existing endpoint (keep for backward compat)
    │   └── identify-v2/route.ts     ← New server-side endpoint
    └── ...
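For ood-detector.ts, the core of step 2 is the Mahalanobis distance between an embedding and the training distribution. The sketch below assumes a single mean and a precomputed inverse covariance matrix; the `OodStats` shape and both function names are hypothetical, and a class-conditional variant (one mean per species, minimum distance) would follow the same pattern.

```typescript
// Hypothetical statistics object, assumed to be fitted offline on training embeddings.
interface OodStats {
  mean: Float32Array;          // length d
  invCovariance: Float32Array; // row-major d x d inverse covariance
  threshold: number;           // distances above this are rejected
}

// sqrt((x - mean)^T * invCovariance * (x - mean))
function mahalanobisDistance(x: Float32Array, stats: OodStats): number {
  const d = stats.mean.length;
  if (x.length !== d || stats.invCovariance.length !== d * d) {
    throw new Error("dimension mismatch");
  }
  const diff = new Float64Array(d);
  for (let i = 0; i < d; i++) diff[i] = x[i] - stats.mean[i];

  let sum = 0;
  for (let i = 0; i < d; i++) {
    let row = 0;
    for (let j = 0; j < d; j++) row += stats.invCovariance[i * d + j] * diff[j];
    sum += diff[i] * row;
  }
  // Guard against tiny negative values caused by rounding.
  return Math.sqrt(Math.max(0, sum));
}

function isInDistribution(x: Float32Array, stats: OodStats): boolean {
  return mahalanobisDistance(x, stats) <= stats.threshold;
}
```

For a 768-dimensional embedding the inverse covariance holds 768 × 768 values, so this is one dense matrix-vector product per request.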

Key Implementation Details

4.1 ONNX Session Manager

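// src/lib/server/onnx-loader.ts
import * as ort from "onnxruntime-node";

// Cache the pending promise rather than the resolved session, so concurrent
// first requests for the same model share a single load.
const sessions = new Map<string, Promise<ort.InferenceSession>>();

export function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
  let session = sessions.get(path);
  if (!session) {
    session = ort.InferenceSession.create(path, {
      executionProviders: ["cpu"], // GPU providers depend on the onnxruntime-node build
      graphOptimizationLevel: "all",
    });
    // Drop failed loads so a later request can retry.
    session.catch(() => sessions.delete(path));
    sessions.set(path, session);
  }
  return session;
}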

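Execution provider: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). The AMD-specific providers are ROCm and MIGraphX; DirectML is a Windows DirectX 12 provider, not a ROCm one. If the installed onnxruntime-node build exposes a GPU provider that works on Strix Halo, test GPU execution for the species model (the compute-heavy part).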

4.2 Lazy Loading Strategy

Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:

const diseaseHeadCache = new Map<string, ort.InferenceSession>();

async function getDiseaseHead(speciesName: string) {
  if (!diseaseHeadCache.has(speciesName)) {
    const path = `public/models/disease-heads/${speciesName}-int8.onnx`;
    // getOrCreateSession is the session manager from 4.1
    diseaseHeadCache.set(speciesName, await getOrCreateSession(path));
  }
  return diseaseHeadCache.get(speciesName)!;
}

4.3 Main Inference Pipeline

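// src/lib/server/inference-server.ts
import * as ort from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";
// preprocessImage, oodDetect, getDiseaseHead, softmax, topK, speciesIndex and
// enrichResults are assumed to come from the sibling modules under src/lib/server/.

export async function serverIdentify(imageBuffer: Buffer): Promise<InferenceResult> {
  const start = performance.now();

  // 1. Preprocess
  const tensor = await preprocessImage(imageBuffer); // Float32Array, laid out as [1,3,224,224]

  // 2. OOD detection (quick, using embedding from species model)
  const oodResult = await oodDetect(tensor);
  if (!oodResult.isPlant) {
    return {
      error: "No plant detected",
      // The distance exceeds the threshold on this branch, so
      // 1 - distance / threshold would be negative. Report the raw score instead.
      confidence: 0,
      oodScore: oodResult.mahalanobisDistance,
      inferenceTimeMs: Math.round(performance.now() - start),
    };
  }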

  // 3. Species inference
  const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx");
  const speciesOutput = await speciesSession.run({
    input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
  });
  const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array);
  const speciesProbs = softmax(speciesLogits);
  const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0];
  const embedding = speciesOutput.embedding.data as Float32Array;

  // 4. Disease inference (routed by species)
  const speciesName = speciesIndex[topSpeciesIdx];
  const diseaseSession = await getDiseaseHead(speciesName);
  const diseaseOutput = await diseaseSession.run({
    embedding: new ort.Tensor("float32", embedding, [1, 768]),
  });
  const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array);
  const diseaseProbs = softmax(diseaseLogits);
  const topDiseases = topK(diseaseProbs, 5);

  // 5. Enrichment
  const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);

  return {
    species: { id: speciesName, confidence: topSpeciesProb },
    diseases: enriched,
    oodScore: oodResult.mahalanobisDistance,
    inferenceTimeMs: Math.round(performance.now() - start),
  };
}
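The pipeline above calls `softmax` and `topK` without defining them. A minimal sketch of both follows, assuming `topK` returns `[index, probability]` pairs sorted by descending probability, which is the shape the destructuring in step 3 expects.

```typescript
// Numerically stable softmax: subtract the max logit before exponentiating
// so large logits do not overflow to Infinity.
function softmax(logits: ArrayLike<number>): number[] {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);

  const exps: number[] = [];
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    const e = Math.exp(logits[i] - max);
    exps.push(e);
    sum += e;
  }
  return exps.map((e) => e / sum);
}

// Returns the k highest-probability entries as [classIndex, probability] pairs.
function topK(probs: number[], k: number): Array<[number, number]> {
  return probs
    .map((p, i): [number, number] => [i, p])
    .sort((a, b) => b[1] - a[1])
    .slice(0, k);
}
```

Sorting all classes is O(n log n), which is negligible for 320 species logits.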

4.4 API Route

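// src/app/api/identify-v2/route.ts
import { serverIdentify } from "../../../lib/server/inference-server";

const MAX_UPLOAD_BYTES = 10 * 1024 * 1024; // 10MB upload limit (see Edge Cases)

export async function POST(req: Request) {
  const formData = await req.formData();
  const image = formData.get("image");

  // formData.get() returns string | File | null, so narrow before reading .type
  if (!(image instanceof File) || !image.type.startsWith("image/")) {
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }
  if (image.size > MAX_UPLOAD_BYTES) {
    return Response.json({ error: "Image too large" }, { status: 413 });
  }

  const buffer = Buffer.from(await image.arrayBuffer());
  const result = await serverIdentify(buffer);

  return Response.json(result);
}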

4.5 Caching Strategy

  • Model sessions: Cache ONNX sessions in memory (warm on first request per deployment)
  • Disease heads: Cache top-50 most common species' disease heads (LRU eviction)
  • Image preprocessing results: Do NOT cache — each image is unique
  • Response caching: Optionally cache identical responses for 5 minutes (hash of image buffer, for repeated uploads of same image)
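The LRU eviction for disease heads can be sketched with a plain `Map`, which iterates in insertion order. This is a hypothetical `LruCache` class, not an existing module in the repo; the capacity of 50 comes from the list above, and the `onEvict` hook is an assumption meant for freeing evicted sessions.

```typescript
// Minimal LRU cache: the Map's first key is always the least recently used.
class LruCache<V> {
  private readonly entries = new Map<string, V>();
  private readonly capacity: number;
  private readonly onEvict?: (key: string, value: V) => void;

  constructor(capacity: number, onEvict?: (key: string, value: V) => void) {
    this.capacity = capacity;
    this.onEvict = onEvict;
  }

  get(key: string): V | undefined {
    if (!this.entries.has(key)) return undefined;
    const value = this.entries.get(key) as V;
    // Re-insert to mark the entry as most recently used.
    this.entries.delete(key);
    this.entries.set(key, value);
    return value;
  }

  set(key: string, value: V): void {
    this.entries.delete(key);
    this.entries.set(key, value);
    if (this.entries.size > this.capacity) {
      const oldestKey = this.entries.keys().next().value as string;
      const oldest = this.entries.get(oldestKey) as V;
      this.entries.delete(oldestKey);
      if (this.onEvict) this.onEvict(oldestKey, oldest);
    }
  }

  get size(): number {
    return this.entries.size;
  }
}
```

For disease heads this might be instantiated as `new LruCache<ort.InferenceSession>(50, (_, session) => void session.release())`, so that evicted sessions release their native memory.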

Edge Cases & Gotchas

  • Cold start latency: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
  • Disease head not found: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return species-only result.
  • Large images: Client may upload 12MP photos. Resize to 224×224 before feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
  • Concurrent requests: ONNX Runtime sessions are thread-safe, so a single species-model session can serve concurrent run() calls and no session pool is needed. Add a request queue or concurrency cap only if load testing shows CPU oversubscription.
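  • Memory: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. The species model is listed as ~1.1MB (INT8), but Swin-Tiny has about 28M parameters, so an INT8 export of the full backbone would be closer to 28MB. Confirm the figure against the Phase 3 artifact.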
  • Error handling: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.

Verification

  • POST /api/identify-v2 returns valid JSON with species + disease predictions
  • Cold start (first ever request): < 3 seconds (model loading)
  • Warm requests: < 200ms total (OOD + species + disease + enrichment)
  • OOD detection correctly rejects non-plant images (rocks, buildings, animals)
  • OOD detection correctly accepts plant images (false rejection rate < 1%)
  • All 320+ species → disease head routes resolve correctly
  • Large image (12MP) → preprocessed to 224×224 without OOM
  • 10 concurrent requests handled without errors or slowdown
  • Degraded mode works if ONNX model fails (falls back to existing TF.js)
  • Health endpoint reports model status, last inference time, error count
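The false-rejection target in the checklist constrains where the OOD threshold can sit. One way to pick it, sketched here as an assumption rather than the project's method, is an empirical quantile over Mahalanobis distances measured on held-out plant images. Both function names are hypothetical.

```typescript
// Picks the smallest observed distance such that at most
// floor(maxFalseRejectionRate * n) held-out plant images exceed it.
function calibrateThreshold(inDistDistances: number[], maxFalseRejectionRate: number): number {
  if (inDistDistances.length === 0) throw new Error("need at least one distance");
  if (!(maxFalseRejectionRate >= 0 && maxFalseRejectionRate < 1)) {
    throw new Error("rate must be in [0, 1)");
  }
  const sorted = [...inDistDistances].sort((a, b) => a - b);
  const allowedRejections = Math.floor(maxFalseRejectionRate * sorted.length);
  return sorted[sorted.length - 1 - allowedRejections];
}

// Fraction of in-distribution samples rejected (distance > threshold).
function falseRejectionRate(inDistDistances: number[], threshold: number): number {
  const rejected = inDistDistances.filter((d) => d > threshold).length;
  return rejected / inDistDistances.length;
}
```

A threshold chosen this way only bounds false rejections on the calibration set. The non-plant rejection check still needs its own set of out-of-distribution images.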