pre torch.compile -chkpoint made

This commit is contained in:
2026-06-16 10:40:38 -05:00
parent 34855eff55
commit 6650d3c5ea
19 changed files with 2519 additions and 0 deletions

View File

@@ -0,0 +1,73 @@
# 03. Load Trained Swin-Tiny Model with Species/Disease Routing
meta:
id: multi-image-user-feedback-03
feature: multi-image-user-feedback
priority: P2
depends_on: []
tags: [ml, model-loader, inference]
objective:
- Create a new model loader backend that loads the trained Swin-Tiny checkpoint (`species_final_final.pt`) and routes through the species head and disease heads to produce 11,818-class logits.
- This task requires the PyTorch model to finish training on the Strix Halo machine and must be exported to the correct format before implementation.
deliverables:
- `src/lib/ml/hierarchical-model.ts` — new PlantDiseaseModel implementation for the Swin-Tiny model
- `scripts/export-model.js` — script to export the PyTorch checkpoint to TF.js format
- `public/models/plant-disease-classifier-v2/` — exported model directory (TF.js or ONNX)
steps:
1. Create `scripts/export-model.js`:
- Load the PyTorch checkpoint from `checkpoints/species_final/species_final_final.pt`
- Export to ONNX format with NCHW input shape [1, 3, 224, 224]
- Also export `species_index.json` and `class_hierarchy.json` alongside the model
- Output to `public/models/plant-disease-classifier-v2/`
2. Create `src/lib/ml/hierarchical-model.ts`:
- Implement the `PlantDiseaseModel` interface
- Load the ONNX model via `onnxruntime-node`
- Load species/disease index files
- Implement `predict()`:
- Preprocess to 224×224 (Swin-Tiny input size, not 160)
- Run forward pass → get [1, 768] features → species logits → disease routing
- The model checkpoint is a single forward pass that already produces 11,818 logits from the combined species + disease heads
- Return the full 11,818-dimension logits array
- Implement `getStatus()` returning model metadata with `numClasses: 11818`
3. Update `src/lib/ml/model-loader.ts`:
- Add detection for v2 model directory (`model-v2.json` or similar)
- Try loading v2 model first (if available), fall back to v1 then mock
- Export `MODEL_NUM_CLASSES` constant for use by other modules
- Export `getModelVersion()` to distinguish v1 (38-class) from v2 (11,818-class)
4. Handle edge cases:
- No model checkpoint available → fall back through v1 → mock
- CUDA/ROCm not available for ONNX → use CPU backend
- Model version mismatch → clear error message
tests:
- Integration: export model from checkpoint and verify output shape is [1, 11818]
- Integration: load exported model and run inference on a test image
- Unit: model loader graceful fallback chain (v2 → v1 → mock)
acceptance_criteria:
- Exported model produces 11,818 logits from a 224×224 image
- Model loader loads v2 model when available, falls back gracefully when not
- All existing v1 model consumers continue to work unmodified (via version detection)
validation:
- `node scripts/export-model.js` produces model files
- `npx tsc --noEmit` passes
- POST to `/api/identify` returns predictions (may be limited if species→disease label mapping not yet complete)
notes:
- This task is **blocked on model training completion**. The task file is the implementation spec; actual work begins after `species_final_final.pt` exists.
- The ONNX export path is preferred for server-side inference (no Python runtime needed once exported).
- If ONNX export quality degrades the output, export to TF.js SavedModel format instead.