task to get this here done
This commit is contained in:
211
tasks/hierarchical-model-upgrade/04-server-inference.md
Normal file
211
tasks/hierarchical-model-upgrade/04-server-inference.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# Phase 4 — Server Inference Pipeline
|
||||
|
||||
**Blocked by**: Phase 3 (exported ONNX models)
|
||||
**Blocks**: Phase 5 (hybrid integration)
|
||||
**Est. time**: 2-3 days
|
||||
**Machine**: Strix Halo (will serve inference in production)
|
||||
|
||||
## Objective
|
||||
|
||||
Build the server-side inference API that loads ONNX models, runs OOD detection, predicts species, routes to the correct disease head, and returns enriched results.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
POST /api/identify
|
||||
│
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ 1. Preprocess │ Load image → resize to 224×224 → NCHW tensor
|
||||
│ (sharp + buffer) │ Normalize with ImageNet stats
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ 2. OOD Detection │ Extract embedding from Swin-Tiny (if species model loaded)
|
||||
│ (Mahalanobis) │ Compute Mahalanobis distance → reject if > threshold
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ 3. Species Inference │ Run swin-species-int8.onnx
|
||||
│ (ONNX Runtime) │ softmax over 320 species logits
|
||||
│ │ Return top-1 species + embedding vector
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ 4. Disease Routing │ Look up disease head ONNX for predicted species
|
||||
│ (species→head) │ Feed embedding through disease head
|
||||
│ │ softmax over species-conditional disease logits
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ 5. Enrichment │ Map class indices → disease/plant objects from
|
||||
│ (knowledge base) │ src/data/diseases.json and src/data/plants.json
|
||||
│ │ Return top-K with treatment info
|
||||
└──────────┬───────────┘
|
||||
▼
|
||||
JSON Response
|
||||
```
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── lib/
|
||||
│ ├── server/
|
||||
│ │ ├── inference-server.ts ← Main orchestration
|
||||
│ │ ├── onnx-loader.ts ← ONNX runtime session manager
|
||||
│ │ ├── ood-detector.ts ← Mahalanobis OOD detection
|
||||
│ │ ├── species-classifier.ts ← Species ONNX inference
|
||||
│ │ ├── disease-classifier.ts ← Disease head routing + inference
|
||||
│ │ └── image-preprocessor.ts ← Sharp-based preprocessing
|
||||
│ └── ml/
|
||||
│ ├── inference.ts ← Existing browser inference (kept as-is)
|
||||
│ └── ...
|
||||
└── app/
|
||||
├── api/
|
||||
│ ├── identify/route.ts ← Existing endpoint (keep for backward compat)
|
||||
│ └── identify-v2/route.ts ← New server-side endpoint
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Key Implementation Details
|
||||
|
||||
### 4.1 ONNX Session Manager
|
||||
|
||||
```typescript
|
||||
// src/lib/server/onnx-loader.ts
|
||||
import ort from "onnxruntime-node";
|
||||
|
||||
/**
 * In-memory cache of ONNX Runtime sessions, keyed by model file path.
 *
 * The map stores the in-flight creation promise rather than the resolved
 * session, so concurrent first requests for the same model share one load
 * instead of each creating (and leaking) their own session.
 */
const sessions = new Map<string, Promise<ort.InferenceSession>>();

/**
 * Returns the cached inference session for `path`, creating it on first use.
 *
 * @param path - Path of the ONNX model file to load.
 * @returns The session for that model; repeated calls with the same path
 *   resolve to the same session instance.
 * @throws Whatever `ort.InferenceSession.create` rejects with. A failed load
 *   is removed from the cache so that a later call can retry.
 */
export async function getOrCreateSession(path: string): Promise<ort.InferenceSession> {
  const cached = sessions.get(path);
  if (cached !== undefined) {
    return cached;
  }

  const created = ort.InferenceSession.create(path, {
    executionProviders: ["cpu"], // or ['rocm', 'cpu'] on Strix Halo
    graphOptimizationLevel: "all",
  });
  // Register the promise before awaiting it: this is what closes the
  // check-then-await race between concurrent callers.
  sessions.set(path, created);

  // Do not keep a rejected promise cached forever; the caller still sees the
  // rejection through the returned promise.
  created.catch(() => {
    if (sessions.get(path) === created) {
      sessions.delete(path);
    }
  });

  return created;
}
|
||||
```
|
||||
|
||||
**Execution provider**: Start with CPU (ONNX Runtime's CPU path is well-optimized for INT8). If a GPU execution provider is available on Strix Halo (e.g., ROCm or MIGraphX on Linux, DirectML on Windows), test GPU execution for the species model (the compute-heavy part).
|
||||
|
||||
### 4.2 Lazy Loading Strategy
|
||||
|
||||
Don't load all 320+ disease heads on startup. Load them lazily on first request for each species and cache them:
|
||||
|
||||
```typescript
|
||||
/**
 * Lazily populated cache of disease-head sessions, keyed by species name.
 * Stores the pending load so concurrent requests for the same species share it.
 */
const diseaseHeadCache = new Map<string, Promise<ort.InferenceSession>>();

/**
 * Returns the disease-head session for a species, loading it on first request.
 *
 * @param speciesName - Species identifier used to build the model file name
 *   (`public/models/disease-heads/<speciesName>-int8.onnx`).
 * @returns The ONNX session for that species' disease head.
 * @throws If the model cannot be loaded (for example, no head exists for the
 *   species). The failed entry is evicted so a later call can retry.
 */
async function getDiseaseHead(speciesName: string): Promise<ort.InferenceSession> {
  const cached = diseaseHeadCache.get(speciesName);
  if (cached !== undefined) {
    return cached;
  }

  const path = `public/models/disease-heads/${speciesName}-int8.onnx`;
  // Use the session manager from src/lib/server/onnx-loader.ts (section 4.1);
  // the document defines no `createSession` helper.
  const created = getOrCreateSession(path);
  diseaseHeadCache.set(speciesName, created);

  created.catch(() => {
    if (diseaseHeadCache.get(speciesName) === created) {
      diseaseHeadCache.delete(speciesName);
    }
  });

  return created;
}
|
||||
```
|
||||
|
||||
### 4.3 Main Inference Pipeline
|
||||
|
||||
```typescript
|
||||
// src/lib/server/inference-server.ts
|
||||
/**
 * Runs the server-side identification pipeline on one uploaded image:
 * preprocess → OOD check → species inference → species-routed disease head →
 * enrichment.
 *
 * @param imageBuffer - Raw bytes of the uploaded image.
 * @returns When the OOD check rejects the image: `error`, `confidence`
 *   (in [0, 1]) and `inferenceTimeMs`. Otherwise: `species`, `diseases`,
 *   `oodScore` and `inferenceTimeMs`.
 * @throws Error if the species model yields no prediction or the predicted
 *   class index has no entry in `speciesIndex`.
 */
export async function serverIdentify(imageBuffer: Buffer): Promise<InferenceResult> {
  const start = performance.now();

  // 1. Preprocess
  const tensor = await preprocessImage(imageBuffer); // Float32Array [1,3,224,224]

  // 2. OOD detection (quick, using embedding from species model)
  // NOTE(review): oodDetect receives only the input tensor, and step 3 runs the
  // species session on the same tensor. If oodDetect obtains its embedding by
  // running the species model, the backbone executes twice per request — confirm.
  const oodResult = await oodDetect(tensor);
  if (!oodResult.isPlant) {
    const { mahalanobisDistance, threshold } = oodResult;
    // Assumes rejection means distance > threshold (per the architecture
    // diagram). The previous `1 - distance / threshold` is negative in that
    // case; `1 - threshold / distance` grows from 0 toward 1 as the image
    // moves farther from the plant distribution.
    const margin = 1 - threshold / mahalanobisDistance;
    const confidence = Number.isFinite(margin) ? Math.min(1, Math.max(0, margin)) : 0;
    return {
      error: "No plant detected",
      confidence,
      inferenceTimeMs: Math.round(performance.now() - start),
    };
  }

  // 3. Species inference
  const speciesSession = await getOrCreateSession("public/models/swin-species-int8.onnx");
  const speciesOutput = await speciesSession.run({
    input: new ort.Tensor("float32", tensor, [1, 3, 224, 224]),
  });
  const speciesLogits = Array.from(speciesOutput.species_logits.data as Float32Array);
  const speciesProbs = softmax(speciesLogits);
  const topSpecies = topK(speciesProbs, 1)[0];
  if (topSpecies === undefined) {
    throw new Error("Species model returned no predictions");
  }
  const [topSpeciesIdx, topSpeciesProb] = topSpecies;
  const embedding = speciesOutput.embedding.data as Float32Array;

  // 4. Disease inference (routed by species)
  const speciesName = speciesIndex[topSpeciesIdx];
  if (speciesName === undefined) {
    // Without this guard the path below would load "undefined-int8.onnx".
    throw new Error(`No species name for class index ${topSpeciesIdx}`);
  }
  const diseaseSession = await getDiseaseHead(speciesName);
  const diseaseOutput = await diseaseSession.run({
    embedding: new ort.Tensor("float32", embedding, [1, 768]),
  });
  const diseaseLogits = Array.from(diseaseOutput.disease_logits.data as Float32Array);
  const diseaseProbs = softmax(diseaseLogits);
  const topDiseases = topK(diseaseProbs, 5);

  // 5. Enrichment
  const enriched = enrichResults(topSpeciesIdx, speciesName, topDiseases);

  return {
    species: { id: speciesName, confidence: topSpeciesProb },
    diseases: enriched,
    oodScore: oodResult.mahalanobisDistance,
    inferenceTimeMs: Math.round(performance.now() - start),
  };
}
|
||||
```
|
||||
|
||||
### 4.4 API Route
|
||||
|
||||
```typescript
|
||||
// src/app/api/identify-v2/route.ts
|
||||
/**
 * Handles `POST /api/identify-v2`: validates the uploaded `image` form field
 * and returns the result of `serverIdentify` as JSON.
 *
 * @param req - Incoming request; expected to carry multipart form data with an
 *   `image` file field.
 * @returns 400 if the body is not parseable form data or `image` is not an
 *   image file, 413 if the file exceeds the 10MB upload limit, otherwise the
 *   inference result.
 */
export async function POST(req: Request): Promise<Response> {
  // Upload cap from the "Large images" edge case in this document.
  const maxUploadBytes = 10 * 1024 * 1024;

  let formData: FormData;
  try {
    formData = await req.formData();
  } catch {
    // Body was not valid multipart/urlencoded form data.
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }

  // `get` returns File | string | null; narrow at runtime instead of casting,
  // otherwise a plain-text field crashes on `image.type.startsWith`.
  const image = formData.get("image");
  if (!(image instanceof File) || !image.type.startsWith("image/")) {
    return Response.json({ error: "Invalid image" }, { status: 400 });
  }

  if (image.size > maxUploadBytes) {
    return Response.json({ error: "Image too large" }, { status: 413 });
  }

  const buffer = Buffer.from(await image.arrayBuffer());
  const result = await serverIdentify(buffer);

  return Response.json(result);
}
|
||||
```
|
||||
|
||||
### 4.5 Caching Strategy
|
||||
|
||||
- **Model sessions**: Cache ONNX sessions in memory (warm on first request per deployment)
|
||||
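- **Disease heads**: Load lazily (see 4.2) and keep at most 50 disease-head sessions in memory, evicting the least recently used, so the heads for the most frequently requested species stay warm. Note that the 4.2 sketch uses an unbounded `Map`; the LRU bound still has to be added.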
|
||||
- **Image preprocessing results**: Do NOT cache — each image is unique
|
||||
- **Response caching**: Optionally cache identical responses for 5 minutes (hash of image buffer, for repeated uploads of same image)
|
||||
|
||||
## Edge Cases & Gotchas
|
||||
|
||||
- **Cold start latency**: First request loads the species model + OOD detector (~500ms). Subsequent requests are <200ms. Consider pre-warming on server boot.
|
||||
- **Disease head not found**: If the species is predicted but no disease head ONNX exists (e.g., new species not in training), fall back to a "general" disease head or return species-only result.
|
||||
- **Large images**: Client may upload 12MP photos. Resize to 224×224 _before_ feeding to ONNX (sharp is fast for this). Set a 10MB upload limit.
|
||||
- **Concurrent requests**: ONNX Runtime sessions are thread-safe, so a single species-model session can serve concurrent `run()` calls — no session pool is needed. Add a request queue or concurrency cap only if load testing shows CPU oversubscription.
|
||||
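- **Memory**: 320 disease heads at ~100KB each = 32MB total if all cached. Acceptable. Species model is ~1.1MB (INT8) — TODO verify against the exported file: a standard Swin-Tiny backbone has roughly 28M parameters, which would be about 28MB at INT8.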
|
||||
- **Error handling**: If ONNX inference fails, fall back to the existing browser-style TF.js model as a degraded mode.
|
||||
|
||||
## Verification
|
||||
|
||||
- [ ] `POST /api/identify-v2` returns valid JSON with species + disease predictions
|
||||
- [ ] Cold start (first ever request): < 3 seconds (model loading)
|
||||
- [ ] Warm requests: < 200ms total (OOD + species + disease + enrichment)
|
||||
- [ ] OOD detection correctly rejects non-plant images (rocks, buildings, animals)
|
||||
- [ ] OOD detection correctly accepts plant images (false rejection rate < 1%)
|
||||
- [ ] All 320+ species → disease head routes resolve correctly
|
||||
- [ ] Large image (12MP) → preprocessed to 224×224 without OOM
|
||||
- [ ] Concurrent 10 requests handled without errors or slowdown
|
||||
- [ ] Degraded mode works if ONNX model fails (falls back to existing TF.js)
|
||||
- [ ] Health endpoint reports model status, last inference time, error count
|
||||
Reference in New Issue
Block a user