plant-disease-id/tasks/multi-image-user-feedback/03-hierarchical-model-loader.md

# 03. Load Trained Swin-Tiny Model with Species/Disease Routing
meta:
id: multi-image-user-feedback-03
feature: multi-image-user-feedback
priority: P2
depends_on: []
tags: [ml, model-loader, inference]
objective:
- Create a new model loader backend that loads the trained Swin-Tiny checkpoint (`species_final_final.pt`) and routes through the species head and disease heads to produce 11,818-class logits.
- This task requires the PyTorch model to finish training on the Strix Halo machine; the resulting checkpoint must then be exported to the correct format before implementation begins.
deliverables:
- `src/lib/ml/hierarchical-model.ts` — new PlantDiseaseModel implementation for the Swin-Tiny model
- `scripts/export-model.js` — script to export the PyTorch checkpoint to ONNX (TF.js only as a fallback, see notes)
- `public/models/plant-disease-classifier-v2/` — exported model directory (TF.js or ONNX)
steps:
1. Create `scripts/export-model.js`:
- Load the PyTorch checkpoint from `checkpoints/species_final/species_final_final.pt`
- Export to ONNX format with NCHW input shape [1, 3, 224, 224]
- Also export `species_index.json` and `class_hierarchy.json` alongside the model
- Output to `public/models/plant-disease-classifier-v2/`
2. Create `src/lib/ml/hierarchical-model.ts`:
- Implement the `PlantDiseaseModel` interface
- Load the ONNX model via `onnxruntime-node`
- Load species/disease index files
- Implement `predict()`:
  - Preprocess the image to 224×224 (the Swin-Tiny input size, not 160)
  - Run the forward pass: [1, 768] features → species logits → disease routing
  - A single forward pass through the exported model already produces the 11,818 logits from the combined species + disease heads
  - Return the full 11,818-dimension logits array
- Implement `getStatus()` returning model metadata with `numClasses: 11818`
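
The preprocessing step inside `predict()` could look roughly like the sketch below. It assumes the image has already been decoded and resized to 224×224 by some image library, and that the model was trained with the common ImageNet mean/std normalization; neither is stated in this task, so both should be checked against the training pipeline. The function name `toNchwTensorData` is hypothetical.

```typescript
// Assumed normalization constants (ImageNet mean/std). Confirm against training code.
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];
const INPUT_SIZE = 224; // Swin-Tiny input size from this task

/**
 * Convert interleaved RGB or RGBA bytes (HWC layout) into normalized
 * planar NCHW float data for a single image. Alpha, if present, is dropped.
 */
function toNchwTensorData(
  pixels: Uint8Array,
  channels: 3 | 4,
  size: number = INPUT_SIZE
): Float32Array {
  const plane = size * size;
  if (pixels.length !== plane * channels) {
    throw new Error(
      `Expected ${plane * channels} bytes for a ${size}x${size} image, got ${pixels.length}`
    );
  }
  const out = new Float32Array(3 * plane);
  for (let i = 0; i < plane; i++) {
    for (let c = 0; c < 3; c++) {
      const value = pixels[i * channels + c] / 255; // scale to [0, 1]
      out[c * plane + i] = (value - MEAN[c]) / STD[c]; // write into channel plane c
    }
  }
  return out;
}
```

The returned buffer has `3 * 224 * 224` elements and can be wrapped in a float32 tensor with dims `[1, 3, 224, 224]` before being passed to the ONNX session.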
3. Update `src/lib/ml/model-loader.ts`:
- Add detection for v2 model directory (`model-v2.json` or similar)
- Try loading the v2 model first (if available), then fall back to v1, then to the mock
- Export `MODEL_NUM_CLASSES` constant for use by other modules
- Export `getModelVersion()` to distinguish v1 (38-class) from v2 (11,818-class)
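
One possible shape for the version detection described in step 3 is sketched below. The v2 directory and the `model-v2.json` marker come from this task (which itself says "or similar"); the v1 path, the `detectModelVersion` and `NUM_CLASSES_BY_VERSION` names, and the injected `exists` predicate are assumptions made so the logic can be exercised without touching the filesystem.

```typescript
type ModelVersion = "v2" | "v1" | "mock";

// Class counts as stated in this task: v1 is 38-class, v2 is 11,818-class.
const NUM_CLASSES_BY_VERSION = { v1: 38, v2: 11818 } as const;

// Marker file for v2 (name taken from the task text, which allows "or similar").
const V2_MARKER = "public/models/plant-disease-classifier-v2/model-v2.json";
// Hypothetical v1 location; replace with the path the existing loader uses.
const V1_MARKER = "public/models/plant-disease-classifier/model.json";

/**
 * Pick the best available model version: v2 first, then v1, then the mock.
 * `exists` is injected so callers can pass `fs.existsSync` in production
 * and a stub in unit tests.
 */
function detectModelVersion(exists: (path: string) => boolean): ModelVersion {
  if (exists(V2_MARKER)) return "v2";
  if (exists(V1_MARKER)) return "v1";
  return "mock";
}
```

Injecting the existence check keeps the fallback order (v2 → v1 → mock) testable as a plain unit test, which matches the unit test listed for this task.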
4. Handle edge cases:
- No model checkpoint available → fall back through v1 → mock
- CUDA/ROCm not available to ONNX Runtime → use the CPU execution provider
- Model version mismatch → clear error message
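
For the version-mismatch case, a minimal sketch of a shape check that fails with a readable message is shown below. It assumes the loader can read the output tensor's dims after a forward pass; the function name `assertLogitsShape` and the wording of the error are illustrative, not part of the existing codebase.

```typescript
const EXPECTED_V2_CLASSES = 11818; // class count for the v2 model, from this task

/**
 * Throw a descriptive error unless the output dims are [1, expectedClasses].
 * Intended to run once after the first forward pass of a freshly loaded model.
 */
function assertLogitsShape(
  dims: readonly number[],
  expectedClasses: number = EXPECTED_V2_CLASSES
): void {
  const ok =
    dims.length === 2 && dims[0] === 1 && dims[1] === expectedClasses;
  if (!ok) {
    throw new Error(
      `Model version mismatch: expected output shape [1, ${expectedClasses}] ` +
        `but got [${dims.join(", ")}]. Re-export the checkpoint or load the matching model version.`
    );
  }
}
```

Calling `assertLogitsShape([1, 38])` against the v2 expectation, for example, reports both the expected and the actual shape, so a v1 model loaded from the v2 directory is identified immediately instead of producing misaligned predictions.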
tests:
- Integration: export model from checkpoint and verify output shape is [1, 11818]
- Integration: load exported model and run inference on a test image
- Unit: model loader graceful fallback chain (v2 → v1 → mock)
acceptance_criteria:
- Exported model produces 11,818 logits from a 224×224 image
- Model loader loads v2 model when available, falls back gracefully when not
- All existing v1 model consumers continue to work unmodified (via version detection)
validation:
- `node scripts/export-model.js` produces model files
- `npx tsc --noEmit` passes
- POST to `/api/identify` returns predictions (may be limited if species→disease label mapping not yet complete)
notes:
- This task is **blocked on model training completion**. The task file is the implementation spec; actual work begins after `species_final_final.pt` exists.
- The ONNX export path is preferred for server-side inference (no Python runtime needed once exported).
- If the ONNX export degrades output quality, export to a TF.js-loadable format (for example, a TensorFlow SavedModel) instead.