# Phase 4 — Server Inference Pipeline
**Blocked by**: Phase 3 (exported ONNX models)
**Blocks**: Phase 5 (hybrid integration)
**Est. time**: 2-3 days
**Machine**: Strix Halo (will serve inference in production)
## Objective
Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.
## Architecture
```
POST /api/identify-v2
           ▼
┌──────────────────────┐
│ 1. Preprocess        │  Load image → resize to 224×224 → NCHW tensor
│    (sharp + buffer)  │  Normalize with ImageNet stats
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 2. OOD Detection     │  Extract embedding from Swin-Tiny (if species model loaded)
│    (Mahalanobis)     │  Compute Mahalanobis distance → reject if > threshold
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 3. Species Inference │  Run swin-species-int8.onnx
│    (ONNX Runtime)    │  softmax over 320 species logits
│                      │  Return top-1 species + embedding vector
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 4. Disease Routing   │  Look up disease head ONNX for predicted species
│    (species→head)    │  Feed embedding through disease head
│                      │  softmax over species-conditional disease logits
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│ 5. Enrichment        │  Map class indices → disease/plant objects from
│    (knowledge base)  │  src/data/diseases.json and src/data/plants.json
│                      │  Return top-K with treatment info
└──────────┬───────────┘
           ▼
     JSON Response
```
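Step 1 of the diagram can be illustrated with a minimal sketch of the layout conversion and normalization. It assumes the image has already been resized and decoded to interleaved 8-bit RGB (for example by sharp); the helper name `toNchwTensor` is hypothetical, and the mean/std values are the standard ImageNet statistics.

```typescript
// Hypothetical helper: assumes `rgb` holds width*height interleaved RGB bytes,
// already resized to the model input size by an image library.
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

function toNchwTensor(rgb: Uint8Array, width = 224, height = 224): Float32Array {
  const pixels = width * height;
  if (rgb.length !== pixels * 3) {
    throw new Error(`expected ${pixels * 3} bytes, got ${rgb.length}`);
  }
  const out = new Float32Array(3 * pixels);
  for (let i = 0; i < pixels; i++) {
    for (let c = 0; c < 3; c++) {
      // Interleaved (HWC) index i*3+c maps to planar (CHW) index c*pixels+i.
      const value = rgb[i * 3 + c] / 255;
      out[c * pixels + i] = (value - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
  }
  return out;
}
```

The returned array has length 3×224×224 for the default size, which matches the `[1, 3, 224, 224]` shape the species model expects.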
## File Structure
```
src/
├── lib/
│   ├── server/
│   │   ├── inference-server.ts      ← Main orchestration
│   │   ├── onnx-loader.ts           ← ONNX runtime session manager
│   │   ├── ood-detector.ts          ← Mahalanobis OOD detection
│   │   ├── species-classifier.ts    ← Species ONNX inference
│   │   ├── disease-classifier.ts    ← Disease head routing + inference
│   │   └── image-preprocessor.ts    ← Sharp-based preprocessing
│   └── ml/
│       ├── inference.ts             ← Existing browser inference (kept as-is)
│       └── ...
└── app/
    ├── api/
    │   ├── identify/route.ts        ← Existing endpoint (keep for backward compat)
    │   └── identify-v2/route.ts     ← New server-side endpoint
    └── ...
```
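The core of `ood-detector.ts` is the Mahalanobis distance between an embedding and the training distribution. The sketch below assumes the mean vector and the inverse covariance (precision) matrix were computed offline from training embeddings and are loaded as flat arrays; the function names and the storage layout are assumptions, not the project's actual API.

```typescript
// Hypothetical sketch: `precision` is the row-major D×D inverse covariance matrix.
function mahalanobisDistance(
  x: ArrayLike<number>,
  mean: ArrayLike<number>,
  precision: ArrayLike<number>,
): number {
  const d = x.length;
  if (mean.length !== d || precision.length !== d * d) {
    throw new Error("dimension mismatch");
  }
  const diff = new Float64Array(d);
  for (let i = 0; i < d; i++) diff[i] = x[i] - mean[i];

  // Quadratic form diffᵀ · precision · diff
  let sum = 0;
  for (let i = 0; i < d; i++) {
    let row = 0;
    for (let j = 0; j < d; j++) row += precision[i * d + j] * diff[j];
    sum += diff[i] * row;
  }
  // Guard against tiny negative values from floating-point rounding.
  return Math.sqrt(Math.max(sum, 0));
}

// Matches the rule in the architecture diagram: reject if distance > threshold.
function isPlant(distance: number, threshold: number): boolean {
  return distance <= threshold;
}
```

For a 768-dimensional embedding this is about 590k multiply-adds per request, which is small next to the backbone pass. Because the embedding comes from the species model, the output of a single species run could feed both this check and the disease head.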
## Key Implementation Details
### 4.1 ONNX Session Manager
```typescript
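// src/lib/server/onnx-loader.ts
import * as ort from "onnxruntime-node";

// Cache the in-flight promise rather than the resolved session, so concurrent
// first requests for the same model share a single load.
const sessions = new Map<string, Promise<ort.InferenceSession>>();

export function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
  let session = sessions.get(path);
  if (!session) {
    session = ort.InferenceSession.create(path, {
      // or a GPU provider first, e.g. ['rocm', 'cpu'] on Strix Halo, if the installed build supports it
      executionProviders: ["cpu"],
      graphOptimizationLevel: "all",
    });
    // Drop failed loads so a later request can retry.
    session.catch(() => sessions.delete(path));
    sessions.set(path, session);
  }
  return session;
}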
```
**Execution provider**: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If a GPU execution provider is available on Strix Halo (for example ROCm or MIGraphX on Linux, or DirectML on Windows) and the installed `onnxruntime-node` build supports it, test GPU execution for the species model (the compute-heavy part).
### 4.2 Lazy Loading Strategy
Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:
```typescript
import * as ort from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";

const diseaseHeadCache = new Map<string, ort.InferenceSession>();

async function getDiseaseHead(speciesName: string): Promise<ort.InferenceSession> {
  if (!diseaseHeadCache.has(speciesName)) {
    const path = `public/models/disease-heads/${speciesName}-int8.onnx`;
    // Reuse the session loader from 4.1.
    diseaseHeadCache.set(speciesName, await getOrCreateSession(path));
  }
  return diseaseHeadCache.get(speciesName)!;
}
```
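If the disease-head cache needs an upper bound (section 4.5 mentions LRU eviction), a small LRU wrapper around `Map` is enough. The sketch below is a generic illustration: the class name `LruCache` and the `onEvict` hook are hypothetical, and it relies only on the fact that a JavaScript `Map` iterates in insertion order.

```typescript
// Hypothetical LRU cache: the first key in Map iteration order is the least recently used.
class LruCache<K, V> {
  private readonly entries = new Map<K, V>();
  private readonly capacity: number;
  private readonly onEvict?: (key: K, value: V) => void;

  constructor(capacity: number, onEvict?: (key: K, value: V) => void) {
    this.capacity = capacity;
    this.onEvict = onEvict;
  }

  get(key: K): V | undefined {
    if (!this.entries.has(key)) return undefined;
    const value = this.entries.get(key) as V;
    // Re-insert to mark this key as most recently used.
    this.entries.delete(key);
    this.entries.set(key, value);
    return value;
  }

  set(key: K, value: V): void {
    this.entries.delete(key);
    this.entries.set(key, value);
    if (this.entries.size > this.capacity) {
      const oldest = this.entries.keys().next().value as K;
      const evicted = this.entries.get(oldest) as V;
      this.entries.delete(oldest);
      this.onEvict?.(oldest, evicted);
    }
  }

  get size(): number {
    return this.entries.size;
  }
}
```

A cache such as `new LruCache<string, ort.InferenceSession>(50)` could replace the plain `Map` above. The `onEvict` hook is the natural place to free a session's native resources, if the installed ONNX Runtime version exposes a way to release sessions.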
### 4.3 Main Inference Pipeline
```typescript
// src/lib/server/inference-server.ts
import * as ort from "onnxruntime-node";
import { getOrCreateSession } from "./onnx-loader";
// preprocessImage, oodDetect, getDiseaseHead, softmax, topK, enrichResults, speciesIndex
// and InferenceResult are assumed to come from the sibling modules under File Structure.

export async function serverIdentify(imageBuffer: Buffer): Promise<InferenceResult> {
  const start = performance.now();

  // 1. Preprocess
  const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224]

  // 2. OOD detection (quick, using embedding from species model)
  const oodResult = await oodDetect(tensor);
  if (!oodResult.isPlant) {
    return {
      error: "No plant detected",
      // Rejection implies distance > threshold, so this ratio keeps confidence in (0, 1).
      confidence: 1 - oodResult.threshold / oodResult.mahalanobisDistance,
      inferenceTimeMs: Math.round(performance.now() - start),
    };
  }

  // 3. Species inference
  const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx");
  const speciesOutput = await speciesSession.run({
    input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
  });
  const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array);
  const speciesProbs = softmax(speciesLogits);
  const [topSpeciesIdx, topSpeciesProb] = topK(speciesProbs, 1)[0];
  const embedding = speciesOutput.embedding.data as Float32Array;

  // 4. Disease inference (routed by species)
  const speciesName = speciesIndex[topSpeciesIdx];
  const diseaseSession = await getDiseaseHead(speciesName);
  const diseaseOutput = await diseaseSession.run({
    embedding: new ort.Tensor("float32", embedding, [1, 768]),
  });
  const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array);
  const diseaseProbs = softmax(diseaseLogits);
  const topDiseases = topK(diseaseProbs, 5);

  // 5. Enrichment
  const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);
  return {
    species: { id: speciesName, confidence: topSpeciesProb },
    diseases: enriched,
    oodScore: oodResult.mahalanobisDistance,
    inferenceTimeMs: Math.round(performance.now() - start),
  };
}
```
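The pipeline calls `softmax` and `topK` without defining them. A minimal sketch of both follows, assuming `topK` returns `[index, probability]` pairs sorted by descending probability, which is the shape the destructuring in `serverIdentify` expects.

```typescript
// Numerically stable softmax: subtracting the max logit avoids overflow in Math.exp.
function softmax(logits: ArrayLike<number>): number[] {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  const exps = Array.from(logits, (v) => Math.exp(v - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / total);
}

// Returns the k highest-probability entries as [classIndex, probability] pairs.
function topK(probs: number[], k: number): Array<[number, number]> {
  return probs
    .map((p, i): [number, number] => [i, p])
    .sort((a, b) => b[1] - a[1])
    .slice(0, k);
}
```

Sorting all 320 species probabilities per request is cheap at this size; a partial selection would only matter for much larger class counts.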
### 4.4 API Route
```typescript
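// src/app/api/identify-v2/route.ts
import { serverIdentify } from "../../../lib/server/inference-server";

// onnxruntime-node is a native addon, so this route must run on the Node.js runtime
// (assumes a Next.js App Router route handler).
export const runtime = "nodejs";

export async function POST(req: Request) {
  const formData = await req.formData();
  const image = formData.get("image");
  // formData.get() may return null or a string, so check the type before reading .type.
  if (!(image instanceof File) || !image.type.startsWith("image/")) {
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }
  const buffer = Buffer.from(await image.arrayBuffer());
  const result = await serverIdentify(buffer);
  return Response.json(result);
}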
```
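The route's checks can be pulled into a pure helper so the 10MB upload limit from the edge cases below is enforced in the same place. This is a sketch under assumptions: `validateUpload`, `MAX_UPLOAD_BYTES`, and the choice of HTTP 413 for oversized files are illustrative and not part of the existing endpoint.

```typescript
// Hypothetical validation helper; the 10MB limit is taken from the edge-case notes.
const MAX_UPLOAD_BYTES = 10 * 1024 * 1024;

type UploadCheck = { ok: true } | { ok: false; status: number; error: string };

function validateUpload(file: { type: string; size: number } | null): UploadCheck {
  if (!file) {
    return { ok: false, status: 400, error: "Missing image field" };
  }
  if (!file.type.startsWith("image/")) {
    return { ok: false, status: 400, error: "Invalid image" };
  }
  if (file.size === 0) {
    return { ok: false, status: 400, error: "Empty image" };
  }
  if (file.size > MAX_UPLOAD_BYTES) {
    return { ok: false, status: 413, error: "Image exceeds 10MB limit" };
  }
  return { ok: true };
}
```

In the route, a failed check would translate to `Response.json({ error: check.error }, { status: check.status })` before any buffer is allocated.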
### 4.5 Caching Strategy
- **Model sessions**: Cache ONNX sessions in memory (warm on first request per deployment)
- **Disease heads**: Bound the cache to about 50 disease heads with LRU eviction, so the heads for the most frequently requested species stay resident
- **Image preprocessing results**: Do NOT cache — each image is unique
- **Response caching**: Optionally cache responses for 5 minutes, keyed by a hash of the image buffer, to cover repeated uploads of the same image
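
The optional response cache can be sketched as a SHA-256 key over the image bytes plus a time-to-live check. The class name `ResponseCache` is hypothetical, and the `now` parameters exist only to make expiry testable.

```typescript
import { createHash } from "node:crypto";

// Hypothetical TTL cache keyed by a hash of the uploaded image bytes.
class ResponseCache<T> {
  private readonly entries = new Map<string, { value: T; expiresAt: number }>();
  private readonly ttlMs: number;

  constructor(ttlMs: number = 5 * 60 * 1000) {
    this.ttlMs = ttlMs;
  }

  // Identical byte content always yields the same key.
  static keyFor(image: Uint8Array): string {
    return createHash("sha256").update(image).digest("hex");
  }

  get(key: string, now: number = Date.now()): T | undefined {
    const entry = this.entries.get(key);
    if (!entry) return undefined;
    if (entry.expiresAt <= now) {
      this.entries.delete(key); // expired: drop lazily on read
      return undefined;
    }
    return entry.value;
  }

  set(key: string, value: T, now: number = Date.now()): void {
    this.entries.set(key, { value, expiresAt: now + this.ttlMs });
  }
}
```

Expired entries are only removed when they are read again, so a production version would also need a size bound or a periodic sweep.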
## Edge Cases & Gotchas
- **Cold start latency**: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
- **Disease head not found**: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return species-only result.
- **Large images**: Client may upload 12MP photos. Resize to 224×224 _before_ feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
- **Concurrent requests**: ONNX Runtime sessions are thread-safe, so a single species-model session can serve concurrent `run()` calls and no session pool is needed. Use a queue only to cap how many inferences run at once under load.
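- **Memory**: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. Species model is listed as ~1.1MB (INT8); verify this against the Phase 3 export, since Swin-Tiny has roughly 28M parameters and an INT8 export at one byte per weight would be closer to 28MB.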
- **Error handling**: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.
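
For the queue mentioned under concurrent requests, a small promise-based limiter is one option. The sketch below is a generic pattern, not project code; `ConcurrencyLimiter` is a hypothetical name, and the limit would need tuning on the target machine.

```typescript
// Hypothetical limiter: at most `limit` tasks run at once, the rest wait in FIFO order.
class ConcurrencyLimiter {
  private active = 0;
  private readonly waiting: Array<() => void> = [];
  private readonly limit: number;

  constructor(limit: number) {
    this.limit = limit;
  }

  async run<T>(task: () => Promise<T>): Promise<T> {
    if (this.active >= this.limit) {
      // Wait until a finishing task hands over its slot.
      await new Promise<void>((resolve) => this.waiting.push(resolve));
    } else {
      this.active++;
    }
    try {
      return await task();
    } finally {
      const next = this.waiting.shift();
      if (next) {
        next(); // the slot passes directly to the next waiter, so `active` is unchanged
      } else {
        this.active--;
      }
    }
  }
}
```

Wrapping the inference call, for example `limiter.run(() => serverIdentify(buffer))`, would keep a burst of uploads from oversubscribing the CPU while still returning every result.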
## Verification
- [ ] `POST /api/identify-v2` returns valid JSON with species + disease predictions
- [ ] Cold start (first ever request): < 3 seconds (model loading)
- [ ] Warm requests: < 200ms total (OOD + species + disease + enrichment)
- [ ] OOD detection correctly rejects non-plant images (rocks, buildings, animals)
- [ ] OOD detection correctly accepts plant images (false rejection rate < 1%)
- [ ] All 320+ species → disease head routes resolve correctly
- [ ] Large image (12MP) → preprocessed to 224×224 without OOM
- [ ] 10 concurrent requests handled without errors or slowdown
- [ ] Degraded mode works if ONNX model fails (falls back to existing TF.js)
- [ ] Health endpoint reports model status, last inference time, error count