Phase 4 — Server Inference Pipeline
- Blocked by: Phase 3 (exported ONNX models)
- Blocks: Phase 5 (hybrid integration)
- Est. time: 2-3 days
- Machine: Strix Halo (will serve inference in production)
Objective
Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.
Architecture
```text
POST /api/identify-v2
           │
           ▼
┌──────────────────────┐
│ 1. Preprocess        │  Load image → resize to 224×224 → NCHW tensor
│ (sharp + buffer)     │  Normalize with ImageNet stats
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 2. OOD Detection     │  Extract embedding from Swin-Tiny (species model)
│ (Mahalanobis)        │  Compute Mahalanobis distance → reject if > threshold
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 3. Species Inference │  Run swin-species-int8.onnx
│ (ONNX Runtime)       │  softmax over 320 species logits
│                      │  Return top-1 species + embedding vector
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 4. Disease Routing   │  Look up disease head ONNX for predicted species
│ (species→head)       │  Feed embedding through disease head
│                      │  softmax over species-conditional disease logits
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 5. Enrichment        │  Map class indices → disease/plant objects from
│ (knowledge base)     │  src/data/diseases.json and src/data/plants.json
│                      │  Return top-K with treatment info
└──────────┬───────────┘
           ▼
      JSON Response
```
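Step 1 of the diagram ends with a normalized NCHW tensor. A minimal sketch of that last conversion, assuming sharp has already produced interleaved 8-bit RGB pixels (HWC order) at 224×224; `toNormalizedNchw` is a hypothetical name, and the mean/std values are the standard ImageNet statistics:

```typescript
// Standard ImageNet per-channel statistics (RGB order).
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

/**
 * Hypothetical helper: convert interleaved RGB bytes (HWC) into a normalized
 * NCHW Float32Array suitable for an ort.Tensor of shape [1, 3, height, width].
 */
export function toNormalizedNchw(rgb: Uint8Array, width: number, height: number): Float32Array {
  const plane = width * height;
  if (rgb.length !== plane * 3) {
    throw new Error(`expected ${plane * 3} RGB bytes, got ${rgb.length}`);
  }
  const out = new Float32Array(plane * 3);
  for (let i = 0; i < plane; i++) {
    for (let c = 0; c < 3; c++) {
      // Scale to [0, 1], then standardize; each channel gets its own contiguous plane.
      out[c * plane + i] = (rgb[i * 3 + c] / 255 - MEAN[c]) / STD[c];
    }
  }
  return out;
}
```

In image-preprocessor.ts the pixels could come from something like `sharp(buffer).rotate().resize(224, 224).removeAlpha().raw().toBuffer({ resolveWithObject: true })`, where `.rotate()` with no argument applies the EXIF orientation common in phone photos. Check that `info.channels` is 3 (greyscale uploads may need `.toColourspace("srgb")`), and keep the resize/crop consistent with whatever Phase 3 used in training.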
File Structure
```text
src/
├── lib/
│   ├── server/
│   │   ├── inference-server.ts    ← Main orchestration
│   │   ├── onnx-loader.ts         ← ONNX runtime session manager
│   │   ├── ood-detector.ts        ← Mahalanobis OOD detection
│   │   ├── species-classifier.ts  ← Species ONNX inference
│   │   ├── disease-classifier.ts  ← Disease head routing + inference
│   │   └── image-preprocessor.ts  ← Sharp-based preprocessing
│   └── ml/
│       ├── inference.ts           ← Existing browser inference (kept as-is)
│       └── ...
└── app/
    ├── api/
    │   ├── identify/route.ts      ← Existing endpoint (keep for backward compat)
    │   └── identify-v2/route.ts   ← New server-side endpoint
    └── ...
```
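The tree lists ood-detector.ts as a Mahalanobis OOD detector. A minimal sketch of the distance itself, assuming a single Gaussian fitted to plant embeddings whose mean μ and inverse covariance Σ⁻¹ are exported alongside the Phase 3 models; the `OodStats` shape, `mahalanobisDistance`, and `classifyOod` are hypothetical names. A class-conditional variant would instead take the minimum distance over per-species means.

```typescript
/** Hypothetical OOD statistics fitted on training-set plant embeddings. */
export interface OodStats {
  mean: Float32Array;      // μ, length D (768 for the Swin-Tiny embedding)
  precision: Float32Array; // Σ⁻¹, row-major D×D
  threshold: number;       // reject when distance > threshold
}

/** Mahalanobis distance sqrt((x - μ)ᵀ Σ⁻¹ (x - μ)). */
export function mahalanobisDistance(x: ArrayLike<number>, stats: OodStats): number {
  const d = stats.mean.length;
  if (x.length !== d || stats.precision.length !== d * d) {
    throw new Error(`dimension mismatch: x=${x.length}, mean=${d}, precision=${stats.precision.length}`);
  }
  const diff = new Float64Array(d);
  for (let i = 0; i < d; i++) diff[i] = x[i] - stats.mean[i];
  let sq = 0;
  for (let i = 0; i < d; i++) {
    let row = 0; // (Σ⁻¹ · diff)_i
    for (let j = 0; j < d; j++) row += stats.precision[i * d + j] * diff[j];
    sq += diff[i] * row;
  }
  // Clamp tiny negative values from floating-point error before the square root.
  return Math.sqrt(Math.max(sq, 0));
}

/** Accept the input as a plant when its distance is within the threshold. */
export function classifyOod(embedding: ArrayLike<number>, stats: OodStats) {
  const distance = mahalanobisDistance(embedding, stats);
  return { isPlant: distance <= stats.threshold, mahalanobisDistance: distance, threshold: stats.threshold };
}
```

For a 768-dimensional embedding this is 768² ≈ 590k multiply-adds per request, small next to the Swin-Tiny forward pass. The `oodDetect` call in 4.3 could wrap `classifyOod` with the stats loaded once at startup.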
Key Implementation Details
4.1 ONNX Session Manager
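```typescript
// src/lib/server/onnx-loader.ts
import * as ort from "onnxruntime-node";

// Cache the creation promise rather than the session, so concurrent cold-start
// requests share one load instead of each creating a duplicate session.
const sessions = new Map<string, Promise<ort.InferenceSession>>();

export function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
  let pending = sessions.get(path);
  if (!pending) {
    pending = ort.InferenceSession.create(path, {
      executionProviders: ["cpu"], // see the execution-provider note below before adding a GPU provider
      graphOptimizationLevel: "all",
    });
    const created = pending;
    // Forget failed loads so a later request can retry.
    created.catch(() => {
      if (sessions.get(path) === created) sessions.delete(path);
    });
    sessions.set(path, created);
  }
  return pending;
}
```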
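Execution provider: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If an AMD GPU provider is available in your ONNX Runtime build on Strix Halo (ROCm or MIGraphX on Linux; DirectML on Windows, which is a DirectX 12 provider rather than a ROCm one), test GPU execution for the species model (the compute-heavy part). The prebuilt onnxruntime-node package may not include the ROCm or MIGraphX providers, so confirm which providers your build supports first.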
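One way to act on this without per-machine code changes is to try provider lists in order and fall back to CPU. A minimal sketch, assuming session creation rejects when a requested provider isn't built into the installed onnxruntime-node; `createWithProviderFallback` is a hypothetical helper, kept generic so it can be tested without a model file:

```typescript
/**
 * Hypothetical helper: try each execution-provider list in order and keep the
 * first one that creates a session. `create` wraps ort.InferenceSession.create.
 */
export async function createWithProviderFallback<S>(
  create: (providers: string[]) => Promise<S>,
  candidates: string[][],
): Promise<{ session: S; providers: string[]; failures: string[] }> {
  const failures: string[] = [];
  for (const providers of candidates) {
    try {
      return { session: await create(providers), providers, failures };
    } catch (err) {
      // Record why this list failed (e.g. provider not in this build), then try the next.
      failures.push(`${providers.join(",")}: ${err instanceof Error ? err.message : String(err)}`);
    }
  }
  throw new Error(`No execution provider worked:\n${failures.join("\n")}`);
}
```

In onnx-loader.ts it might be called as `createWithProviderFallback((eps) => ort.InferenceSession.create(modelPath, { executionProviders: eps, graphOptimizationLevel: "all" }), [["dml", "cpu"], ["cpu"]])`; `"dml"` only applies to Windows builds, so on Linux list whichever GPU provider your build actually includes. Logging `failures` at startup makes a silent CPU fallback visible.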
4.2 Lazy Loading Strategy
Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:
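```typescript
// src/lib/server/disease-classifier.ts
import path from "node:path";
import type { InferenceSession } from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";

const DISEASE_HEAD_DIR = path.join(process.cwd(), "public", "models", "disease-heads");

// getOrCreateSession already memoizes by file path, so each species' head is
// loaded on its first request and reused afterwards; no second cache is needed.
export function getDiseaseHead(speciesName: string): Promise<InferenceSession> {
  return getOrCreateSession(path.join(DISEASE_HEAD_DIR, `${speciesName}-int8.onnx`));
}
```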
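The Edge Cases section below asks for a fallback when a predicted species has no disease head. A minimal sketch of the lookup `getDiseaseHead` could do before creating a session; the `general-int8.onnx` file name, `resolveDiseaseHead`, and the `HeadResolution` type are hypothetical, and the existence check is injectable so the logic can be tested without touching disk:

```typescript
import { existsSync } from "node:fs";
import * as path from "node:path";

export type HeadResolution =
  | { kind: "species"; path: string }
  | { kind: "general"; path: string }
  | { kind: "none" };

// Hypothetical layout: <dir>/<species>-int8.onnx plus an optional <dir>/general-int8.onnx.
export function resolveDiseaseHead(
  speciesName: string,
  dir: string,
  exists: (p: string) => boolean = existsSync,
): HeadResolution {
  // Refuse names containing path separators so a lookup cannot leave the model directory.
  if (path.basename(speciesName) !== speciesName) return { kind: "none" };
  const specific = path.join(dir, `${speciesName}-int8.onnx`);
  if (exists(specific)) return { kind: "species", path: specific };
  const general = path.join(dir, "general-int8.onnx");
  if (exists(general)) return { kind: "general", path: general };
  return { kind: "none" };
}
```

A `"none"` result would let serverIdentify skip step 4 and return a species-only result. The basename check is defensive: species names come from the model's own index, but it keeps a malformed name from resolving outside the model directory.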
4.3 Main Inference Pipeline
```typescript
// src/lib/server/inference-server.ts
import path from "node:path";
import * as ort from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";
import { preprocessImage } from "./image-preprocessor";
import { oodDetect } from "./ood-detector";
import { getDiseaseHead } from "./disease-classifier";
// Hypothetical module for the remaining helpers and types used below.
import { softmax, topK, speciesIndex, enrichResults, type InferenceResult } from "./helpers";

const SPECIES_MODEL = path.join(process.cwd(), "public", "models", "swin-species-int8.onnx");

export async function serverIdentify(imageBuffer: Buffer): Promise<InferenceResult> {
  const start = performance.now();
  const elapsedMs = () => Math.round(performance.now() - start);

  // 1. Preprocess
  const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224]

  // 2 + 3. One species forward pass yields both the logits and the embedding
  // that OOD detection needs, so the model is not run twice per request.
  const speciesSession = await getOrCreateSession(SPECIES_MODEL);
  const speciesOutput = await speciesSession.run({
    input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
  });
  const embedding = speciesOutput.embedding.data as Float32Array;

  const oodResult = await oodDetect(embedding);
  if (!oodResult.isPlant) {
    return {
      error: "No plant detected",
      // Rejection means distance > threshold, so 1 - distance / threshold would be
      // negative; report the raw distance instead, as in the success case.
      oodScore: oodResult.mahalanobisDistance,
      inferenceTimeMs: elapsedMs(),
    };
  }

  const speciesProbs = softmax(Array.from(speciesOutput.species_logits.data as Float32Array));
  const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0];

  // 4. Disease inference (routed by species)
  const speciesName = speciesIndex[topSpeciesIdx];
  const diseaseSession = await getDiseaseHead(speciesName);
  const diseaseOutput = await diseaseSession.run({
    embedding: new ort.Tensor("float32", embedding, [1, 768]),
  });
  const diseaseProbs = softmax(Array.from(diseaseOutput.disease_logits.data as Float32Array));
  const topDiseases = topK(diseaseProbs, 5);

  // 5. Enrichment
  const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);
  return {
    species: { id: speciesName, confidence: topSpeciesProb },
    diseases: enriched,
    oodScore: oodResult.mahalanobisDistance,
    inferenceTimeMs: elapsedMs(),
  };
}
```
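serverIdentify relies on `softmax` and `topK` without defining them, and its destructuring `const [topSpeciesIdx, topSpeciesProb] = topK(...)[0]` only works if `topK` returns `[index, probability]` pairs. A minimal sketch under that assumption:

```typescript
/** Numerically stable softmax: subtracting the max logit keeps Math.exp from overflowing. */
export function softmax(logits: ArrayLike<number>): number[] {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  const exps = Array.from(logits, (x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

/** The k highest-probability entries as [index, probability], best first. */
export function topK(probs: readonly number[], k: number): [number, number][] {
  return probs
    .map((p, i): [number, number] => [i, p])
    .sort((a, b) => b[1] - a[1])
    .slice(0, k);
}
```

Subtracting the max leaves the probabilities mathematically unchanged. Sorting all ~320 species probabilities per request is cheap; a partial selection would only matter for far larger label sets.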
4.4 API Route
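```typescript
// src/app/api/identify-v2/route.ts
import { serverIdentify } from "../../../lib/server/inference-server";

// onnxruntime-node and sharp are native modules, so pin this route to the Node.js runtime.
export const runtime = "nodejs";
const MAX_UPLOAD_BYTES = 10 * 1024 * 1024; // 10MB upload limit

export async function POST(req: Request) {
  const formData = await req.formData();
  // get() returns a string or null for non-file fields, so narrow instead of casting.
  const image = formData.get("image");
  if (!(image instanceof File) || !image.type.startsWith("image/")) {
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }
  if (image.size > MAX_UPLOAD_BYTES) {
    return Response.json({ error: "Image too large (max 10MB)" }, { status: 413 });
  }
  const buffer = Buffer.from(await image.arrayBuffer());
  const result = await serverIdentify(buffer);
  return Response.json(result);
}
```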
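For manual testing or a frontend caller, the route expects multipart/form-data with the file in an `image` field whose MIME type starts with `image/`. A minimal client-side sketch using the standard `FormData`, `Blob`, and `Request` globals (available in browsers and Node 18+); `buildIdentifyRequest` is a hypothetical helper:

```typescript
/**
 * Hypothetical client helper: build a multipart POST to /api/identify-v2.
 * The route rejects non-image MIME types, so the Blob must carry a type such as "image/jpeg".
 */
export function buildIdentifyRequest(baseUrl: string, image: Blob, filename = "upload.jpg"): Request {
  if (!image.type.startsWith("image/")) {
    throw new Error(`expected an image/* Blob, got "${image.type}"`);
  }
  const form = new FormData();
  form.append("image", image, filename);
  // Leave Content-Type unset so the runtime adds the multipart boundary itself.
  return new Request(new URL("/api/identify-v2", baseUrl), { method: "POST", body: form });
}
```

A browser caller could then run `const res = await fetch(buildIdentifyRequest(location.origin, file))`, since a `File` from an `<input type="file">` is a `Blob` that already carries its MIME type.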
4.5 Caching Strategy
- Model sessions: Cache ONNX sessions in memory (each server process loads them on its first request unless pre-warmed)
- Disease heads: Cache top-50 most common species' disease heads (LRU eviction)
- Image preprocessing results: Do NOT cache; each image is unique
- Response caching: Optionally cache responses for 5 minutes, keyed by a hash of the image buffer, so repeated uploads of the same image skip inference
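Both the disease-head LRU and the 5-minute response cache above can sit on one small structure. A minimal sketch of a Map-backed LRU cache with an optional TTL, plus a SHA-256 key for image bytes; `LruCache` and `imageCacheKey` are hypothetical names, and `createHash` from `node:crypto` is the only library call:

```typescript
import { createHash } from "node:crypto";

interface Entry<V> {
  value: V;
  expiresAt: number;
}

/** Hypothetical Map-backed LRU cache with an optional TTL in milliseconds. */
export class LruCache<K, V> {
  // Map iterates in insertion order, so the first key is the least recently used.
  private readonly entries = new Map<K, Entry<V>>();
  private readonly maxEntries: number;
  private readonly ttlMs: number;
  private readonly onEvict?: (key: K, value: V) => void;

  constructor(maxEntries: number, ttlMs = Infinity, onEvict?: (key: K, value: V) => void) {
    this.maxEntries = maxEntries;
    this.ttlMs = ttlMs;
    this.onEvict = onEvict;
  }

  get(key: K): V | undefined {
    const entry = this.entries.get(key);
    if (entry === undefined) return undefined;
    this.entries.delete(key);
    if (Date.now() > entry.expiresAt) {
      this.onEvict?.(key, entry.value); // expired: drop it
      return undefined;
    }
    this.entries.set(key, entry); // re-insert as most recently used
    return entry.value;
  }

  set(key: K, value: V): void {
    this.entries.delete(key);
    this.entries.set(key, { value, expiresAt: Date.now() + this.ttlMs });
    while (this.entries.size > this.maxEntries) {
      const oldestKey = this.entries.keys().next().value as K;
      const oldest = this.entries.get(oldestKey)!;
      this.entries.delete(oldestKey);
      this.onEvict?.(oldestKey, oldest.value);
    }
  }

  get size(): number {
    return this.entries.size;
  }
}

/** Response-cache key: SHA-256 of the uploaded bytes. */
export function imageCacheKey(buffer: Uint8Array): string {
  return createHash("sha256").update(buffer).digest("hex");
}
```

Responses could use `new LruCache<string, InferenceResult>(500, 5 * 60 * 1000)` keyed by `imageCacheKey(buffer)` (the 500-entry cap is an arbitrary example). For disease heads, `new LruCache(50, Infinity, (_, session) => void session.release())` would free evicted sessions, but only if heads are created outside `getOrCreateSession`'s unbounded map and an evicted session is not still serving an in-flight request.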
Edge Cases & Gotchas
- Cold start latency: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
- Disease head not found: If the species is predicted but no disease head ONNX exists (e.g., a new species not in training), fall back to a "general" disease head or return a species-only result.
- Large images: Client may upload 12MP photos. Resize to 224×224 before feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
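- Concurrent requests: ONNX Runtime sessions are thread-safe, and one session can serve concurrent run() calls, so the species model needs no session pool. Optionally put a queue in front of it to bound how many requests run at once.
- Memory: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. The species model is larger than the heads: Swin-Tiny has ~28M parameters, so expect roughly 28-30MB at INT8 (confirm against the exported file).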
- Error handling: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.
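Since a single session accepts concurrent run() calls, the optional queue mentioned under Concurrent requests is about bounding load rather than correctness. A minimal FIFO semaphore sketch; `ConcurrencyLimiter` is a hypothetical helper, and the right limit is something to measure on Strix Halo rather than assume:

```typescript
/** Hypothetical helper: cap in-flight tasks; extra callers wait in FIFO order. */
export class ConcurrencyLimiter {
  private active = 0;
  private readonly waiters: Array<() => void> = [];
  private readonly limit: number;

  constructor(limit: number) {
    if (!Number.isInteger(limit) || limit < 1) throw new Error("limit must be a positive integer");
    this.limit = limit;
  }

  async run<T>(task: () => Promise<T>): Promise<T> {
    if (this.active >= this.limit) {
      // Wait until a finishing task hands its slot over to us.
      await new Promise<void>((resolve) => this.waiters.push(resolve));
    } else {
      this.active++;
    }
    try {
      return await task();
    } finally {
      const next = this.waiters.shift();
      if (next) next(); // transfer the slot directly; active count stays the same
      else this.active--;
    }
  }
}
```

Wrapping the species call as `limiter.run(() => speciesSession.run(feeds))` keeps, say, ten simultaneous uploads from all competing for the same CPU threads at once. Because a finishing task hands its slot straight to the next waiter, the active count never exceeds the limit.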
Verification
- POST /api/identify-v2 returns valid JSON with species + disease predictions
- Cold start (first-ever request): < 3 seconds (model loading)
- Warm requests: < 200ms total (OOD + species + disease + enrichment)
- OOD detection correctly rejects non-plant images (rocks, buildings, animals)
- OOD detection correctly accepts plant images (false rejection rate < 1%)
- All 320+ species → disease head routes resolve correctly
- Large image (12MP) → preprocessed to 224×224 without OOM
- Concurrent 10 requests handled without errors or slowdown
- Degraded mode works if ONNX model fails (falls back to existing TF.js)
- Health endpoint reports model status, last inference time, error count
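The last check implies some in-process state behind the health endpoint. A minimal sketch of the counters it could report; `InferenceHealth`, its method names, and the `starting`/`ok`/`degraded` statuses are hypothetical:

```typescript
type ModelStatus = "loaded" | "failed";

/** Hypothetical in-memory counters backing a health endpoint. */
export class InferenceHealth {
  private readonly models = new Map<string, ModelStatus>();
  private errorCount = 0;
  private lastInferenceMs: number | null = null;
  private lastInferenceAt: string | null = null;

  markModel(name: string, status: ModelStatus): void {
    this.models.set(name, status);
  }

  recordSuccess(durationMs: number, at: Date = new Date()): void {
    this.lastInferenceMs = durationMs;
    this.lastInferenceAt = at.toISOString();
  }

  recordError(): void {
    this.errorCount++;
  }

  snapshot() {
    const statuses = Array.from(this.models.values());
    // "degraded" when any model failed to load, e.g. while serving via the TF.js fallback.
    const status = statuses.length === 0 ? "starting" : statuses.includes("failed") ? "degraded" : "ok";
    return {
      status,
      models: Object.fromEntries(this.models),
      lastInferenceMs: this.lastInferenceMs,
      lastInferenceAt: this.lastInferenceAt,
      errorCount: this.errorCount,
    };
  }
}
```

A hypothetical `src/app/api/health/route.ts` could then `export async function GET() { return Response.json(health.snapshot()); }`, with serverIdentify calling `recordSuccess(inferenceTimeMs)` or `recordError()` and the loaders calling `markModel`.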