pre torch.compile -chkpoint made
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
# 03. Load Trained Swin-Tiny Model with Species/Disease Routing
|
||||
|
||||
meta:
|
||||
id: multi-image-user-feedback-03
|
||||
feature: multi-image-user-feedback
|
||||
priority: P2
|
||||
depends_on: []
|
||||
tags: [ml, model-loader, inference]
|
||||
|
||||
objective:
|
||||
|
||||
- Create a new model loader backend that loads the trained Swin-Tiny checkpoint (`species_final_final.pt`) and routes through the species head and disease heads to produce 11,818-class logits.
|
||||
- This task requires the PyTorch model to finish training on the Strix Halo machine and must be exported to the correct format before implementation.
|
||||
|
||||
deliverables:
|
||||
|
||||
- `src/lib/ml/hierarchical-model.ts` — new PlantDiseaseModel implementation for the Swin-Tiny model
|
||||
- `scripts/export-model.js` — script to export the PyTorch checkpoint to TF.js format
|
||||
- `public/models/plant-disease-classifier-v2/` — exported model directory (TF.js or ONNX)
|
||||
|
||||
steps:
|
||||
|
||||
1. Create `scripts/export-model.js`:
|
||||
- Load the PyTorch checkpoint from `checkpoints/species_final/species_final_final.pt`
|
||||
- Export to ONNX format with NCHW input shape [1, 3, 224, 224]
|
||||
- Also export `species_index.json` and `class_hierarchy.json` alongside the model
|
||||
- Output to `public/models/plant-disease-classifier-v2/`
|
||||
|
||||
2. Create `src/lib/ml/hierarchical-model.ts`:
|
||||
- Implement the `PlantDiseaseModel` interface
|
||||
- Load the ONNX model via `onnxruntime-node`
|
||||
- Load species/disease index files
|
||||
- Implement `predict()`:
|
||||
- Preprocess to 224×224 (Swin-Tiny input size, not 160)
|
||||
- Run forward pass → get [1, 768] features → species logits → disease routing
|
||||
- The model checkpoint is a single forward pass that already produces 11,818 logits from the combined species + disease heads
|
||||
- Return the full 11,818-dimension logits array
|
||||
- Implement `getStatus()` returning model metadata with `numClasses: 11818`
|
||||
|
||||
3. Update `src/lib/ml/model-loader.ts`:
|
||||
- Add detection for v2 model directory (`model-v2.json` or similar)
|
||||
- Try loading v2 model first (if available), fall back to v1 then mock
|
||||
- Export `MODEL_NUM_CLASSES` constant for use by other modules
|
||||
- Export `getModelVersion()` to distinguish v1 (38-class) from v2 (11,818-class)
|
||||
|
||||
4. Handle edge cases:
|
||||
- No model checkpoint available → fall back through v1 → mock
|
||||
- CUDA/ROCm not available for ONNX → use CPU backend
|
||||
- Model version mismatch → clear error message
|
||||
|
||||
tests:
|
||||
|
||||
- Integration: export model from checkpoint and verify output shape is [1, 11818]
|
||||
- Integration: load exported model and run inference on a test image
|
||||
- Unit: model loader graceful fallback chain (v2 → v1 → mock)
|
||||
|
||||
acceptance_criteria:
|
||||
|
||||
- Exported model produces 11,818 logits from a 224×224 image
|
||||
- Model loader loads v2 model when available, falls back gracefully when not
|
||||
- All existing v1 model consumers continue to work unmodified (via version detection)
|
||||
|
||||
validation:
|
||||
|
||||
- `node scripts/export-model.js` produces model files
|
||||
- `npx tsc --noEmit` passes
|
||||
- POST to `/api/identify` returns predictions (may be limited if species→disease label mapping not yet complete)
|
||||
|
||||
notes:
|
||||
|
||||
- This task is **blocked on model training completion**. The task file is the implementation spec; actual work begins after `species_final_final.pt` exists.
|
||||
- The ONNX export path is preferred for server-side inference (no Python runtime needed once exported).
|
||||
- If ONNX export quality degrades the output, export to TF.js SavedModel format instead.
|
||||
Reference in New Issue
Block a user