03. Load Trained Swin-Tiny Model with Species/Disease Routing
meta:
  id: multi-image-user-feedback-03
  feature: multi-image-user-feedback
  priority: P2
  depends_on: []
  tags: [ml, model-loader, inference]
objective:
- Create a new model loader backend that loads the trained Swin-Tiny checkpoint (species_final_final.pt) and routes through the species head and disease heads to produce 11,818-class logits.
- This task requires the PyTorch model to finish training on the Strix Halo machine. The checkpoint must then be exported to the correct format before implementation begins.
deliverables:
- src/lib/ml/hierarchical-model.ts: new PlantDiseaseModel implementation for the Swin-Tiny model
- scripts/export-model.js: script to export the PyTorch checkpoint to ONNX (preferred) or TF.js format
- public/models/plant-disease-classifier-v2/: exported model directory (TF.js or ONNX)
steps:
- Create scripts/export-model.js:
  - Load the PyTorch checkpoint from checkpoints/species_final/species_final_final.pt
  - Export to ONNX format with NCHW input shape [1, 3, 224, 224]
  - Also export species_index.json and class_hierarchy.json alongside the model
  - Output to public/models/plant-disease-classifier-v2/
- Create src/lib/ml/hierarchical-model.ts:
  - Implement the PlantDiseaseModel interface
  - Load the ONNX model via onnxruntime-node
  - Load species/disease index files
  - Implement predict():
    - Preprocess to 224×224 (Swin-Tiny input size, not 160)
    - Run the forward pass → [1, 768] features → species logits → disease routing
    - The exported model does this in a single forward pass that already produces 11,818 logits from the combined species + disease heads
    - Return the full 11,818-dimension logits array
  - Implement getStatus() returning model metadata with numClasses: 11818
- Update src/lib/ml/model-loader.ts:
  - Add detection for v2 model directory (model-v2.json or similar)
  - Try loading v2 model first (if available), fall back to v1 then mock
  - Export MODEL_NUM_CLASSES constant for use by other modules
  - Export getModelVersion() to distinguish v1 (38-class) from v2 (11,818-class)
- Handle edge cases:
  - No model checkpoint available → fall back through v1 → mock
  - CUDA/ROCm not available for ONNX → use CPU backend
  - Model version mismatch → clear error message
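
The preprocessing step in predict() could be sketched roughly as below. This is a minimal illustration, not the project's implementation: it assumes the image has already been resized to 224×224 and decoded to interleaved RGB bytes, and it assumes ImageNet mean/std normalization, which this spec does not state and which should be checked against the training pipeline. The names toNCHW, INPUT_SIZE, MEAN, and STD are hypothetical.

```typescript
// Sketch only: converts an already-resized RGB image (interleaved HWC bytes)
// into the planar NCHW Float32 layout implied by the [1, 3, 224, 224] input.
// ASSUMPTION: ImageNet mean/std normalization; confirm against training code.
const INPUT_SIZE = 224;
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

function toNCHW(rgb: Uint8Array, size: number = INPUT_SIZE): Float32Array {
  const pixels = size * size;
  if (rgb.length !== pixels * 3) {
    throw new Error(
      `Expected ${pixels * 3} RGB bytes for a ${size}x${size} image, got ${rgb.length}`
    );
  }
  const out = new Float32Array(3 * pixels);
  for (let i = 0; i < pixels; i++) {
    for (let c = 0; c < 3; c++) {
      // Scale to [0, 1], normalize per channel, and write into channel plane c.
      const value = rgb[i * 3 + c] / 255;
      out[c * pixels + i] = (value - MEAN[c]) / STD[c];
    }
  }
  return out;
}
```

The resulting array has 3 × 224 × 224 = 150,528 values and can be wrapped in an onnxruntime-node tensor with dims [1, 3, 224, 224] before running the session.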
tests:
- Integration: export model from checkpoint and verify output shape is [1, 11818]
- Integration: load exported model and run inference on a test image
- Unit: model loader graceful fallback chain (v2 → v1 → mock)
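
The fallback-chain unit test could target logic along these lines. This is a sketch under assumptions: LoadedModel, Loader, and loadWithFallback are hypothetical names, and the real PlantDiseaseModel interface in the repo may look different. Only the ordering and error handling are the point.

```typescript
// Hypothetical shape; the real PlantDiseaseModel interface may differ.
interface LoadedModel {
  version: "v2" | "v1" | "mock";
}

type Loader = () => Promise<LoadedModel>;

// Tries each loader in order and returns the first one that succeeds.
// Failure messages are collected so the final error explains every attempt.
async function loadWithFallback(loaders: Loader[]): Promise<LoadedModel> {
  const failures: string[] = [];
  for (const load of loaders) {
    try {
      return await load();
    } catch (err) {
      failures.push(err instanceof Error ? err.message : String(err));
    }
  }
  throw new Error(`No model could be loaded: ${failures.join("; ")}`);
}
```

Passing the loaders in the order v2, v1, mock gives the chain described in the steps, and a unit test can substitute loaders that always throw to simulate a missing checkpoint.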
acceptance_criteria:
- Exported model produces 11,818 logits from a 224×224 image
- Model loader loads v2 model when available, falls back gracefully when not
- All existing v1 model consumers continue to work unmodified (via version detection)
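
One way version detection might be expressed is by checking the logits length against the two class counts named in this spec. The helper below is a hypothetical illustration (versionFromLogits and NUM_CLASSES are not names from the codebase). The spec's getModelVersion() may instead rely on the model directory or a metadata file.

```typescript
type ModelVersion = "v1" | "v2";

// Class counts from this spec: v1 has 38 classes, v2 has 11,818.
const NUM_CLASSES: Record<ModelVersion, number> = { v1: 38, v2: 11818 };

// Infers the model version from the length of a logits array and fails
// loudly on anything else (the "model version mismatch" edge case).
function versionFromLogits(logits: ArrayLike<number>): ModelVersion {
  for (const version of Object.keys(NUM_CLASSES) as ModelVersion[]) {
    if (logits.length === NUM_CLASSES[version]) return version;
  }
  throw new Error(
    `Model version mismatch: got ${logits.length} logits, expected 38 (v1) or 11818 (v2)`
  );
}
```

A consumer written for the 38-class model can then branch on the returned version instead of assuming a fixed output size.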
validation:
- node scripts/export-model.js produces model files
- npx tsc --noEmit passes
- POST to /api/identify returns predictions (may be limited if species→disease label mapping not yet complete)
notes:
- This task is blocked on model training completion. The task file is the implementation spec; actual work begins after species_final_final.pt exists.
- The ONNX export path is preferred for server-side inference (no Python runtime needed once exported).
- If the ONNX export degrades output quality, export to TF.js SavedModel format instead.
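
Whether the export "degrades the output" could be judged by comparing logits from the exported model against reference logits for the same input, for example values dumped from the PyTorch model. The sketch below is an assumption about how that check might look. compareLogits, argmax, and LogitDrift are hypothetical names, and this spec does not define an acceptable drift threshold.

```typescript
// Summary of how far exported-model logits drift from reference logits.
interface LogitDrift {
  maxAbsDiff: number;
  sameTopClass: boolean;
}

// Index of the largest value (first occurrence wins on ties).
function argmax(values: ArrayLike<number>): number {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

function compareLogits(
  reference: ArrayLike<number>,
  exported: ArrayLike<number>
): LogitDrift {
  if (reference.length !== exported.length) {
    throw new Error(
      `Logit length mismatch: ${reference.length} vs ${exported.length}`
    );
  }
  let maxAbsDiff = 0;
  for (let i = 0; i < reference.length; i++) {
    maxAbsDiff = Math.max(maxAbsDiff, Math.abs(reference[i] - exported[i]));
  }
  return { maxAbsDiff, sameTopClass: argmax(reference) === argmax(exported) };
}
```

A large maxAbsDiff or a changed top class across several test images would be a signal to try the TF.js export path instead.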