03. Load Trained Swin-Tiny Model with Species/Disease Routing

meta:

  • id: multi-image-user-feedback-03
  • feature: multi-image-user-feedback
  • priority: P2
  • depends_on: []
  • tags: [ml, model-loader, inference]

objective:

  • Create a new model loader backend that loads the trained Swin-Tiny checkpoint (species_final_final.pt) and routes through the species head and disease heads to produce 11,818-class logits.
  • This task requires the PyTorch model to finish training on the Strix Halo machine; the checkpoint must then be exported to the correct format before the loader can be implemented.

deliverables:

  • src/lib/ml/hierarchical-model.ts — new PlantDiseaseModel implementation for the Swin-Tiny model
  • scripts/export-model.js — script to export the PyTorch checkpoint to ONNX format (TF.js as a fallback, see notes)
  • public/models/plant-disease-classifier-v2/ — exported model directory (ONNX preferred, TF.js as a fallback)

steps:

  1. Create scripts/export-model.js:

    • Load the PyTorch checkpoint from checkpoints/species_final/species_final_final.pt
    • Export to ONNX format with NCHW input shape [1, 3, 224, 224]
    • Also export species_index.json and class_hierarchy.json alongside the model
    • Output to public/models/plant-disease-classifier-v2/
  2. Create src/lib/ml/hierarchical-model.ts:

    • Implement the PlantDiseaseModel interface
    • Load the ONNX model via onnxruntime-node
    • Load species/disease index files
    • Implement predict():
      • Preprocess to 224×224 (Swin-Tiny input size, not 160)
      • Run forward pass → get [1, 768] features → species logits → disease routing
      • The exported model does all of this in a single forward pass and already produces 11,818 logits from the combined species + disease heads
      • Return the full 11,818-dimension logits array
    • Implement getStatus() returning model metadata with numClasses: 11818
  3. Update src/lib/ml/model-loader.ts:

    • Add detection for v2 model directory (model-v2.json or similar)
    • Try loading the v2 model first (if available); fall back to v1, then to the mock
    • Export MODEL_NUM_CLASSES constant for use by other modules
    • Export getModelVersion() to distinguish v1 (38-class) from v2 (11,818-class)
  4. Handle edge cases:

    • No model checkpoint available → fall back through v1 → mock
    • CUDA/ROCm execution provider not available to ONNX Runtime → use the CPU execution provider
    • Model version mismatch → clear error message
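The preprocessing in step 2 can be sketched as a pure function that turns an already-resized 224×224 RGBA buffer into the NCHW layout the export expects. This is a minimal sketch, not the project's implementation: the function name `rgbaToNchw` is hypothetical, and the mean/std values are the common ImageNet defaults, which is an assumption that should be checked against the training pipeline.

```typescript
// Sketch only. MEAN/STD are the usual ImageNet defaults (an assumption);
// the actual training transforms may use different values.
const INPUT_SIZE = 224;
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

/**
 * Convert an already-resized 224x224 RGBA byte buffer (row-major, HWC order)
 * into a normalized Float32Array laid out as NCHW, shape [1, 3, 224, 224].
 * The alpha channel is dropped.
 */
function rgbaToNchw(rgba: Uint8Array | Uint8ClampedArray): Float32Array {
  const pixels = INPUT_SIZE * INPUT_SIZE;
  if (rgba.length !== pixels * 4) {
    throw new Error(`expected ${pixels * 4} RGBA bytes, got ${rgba.length}`);
  }
  const out = new Float32Array(3 * pixels);
  for (let i = 0; i < pixels; i++) {
    for (let c = 0; c < 3; c++) {
      const value = rgba[i * 4 + c] / 255; // scale byte to [0, 1]
      out[c * pixels + i] = (value - MEAN[c]) / STD[c]; // channel-planar write
    }
  }
  return out;
}
```

The returned array holds 3 × 224 × 224 values and can be wrapped in whatever tensor type the chosen runtime provides, with dimensions [1, 3, 224, 224].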

tests:

  • Integration: export model from checkpoint and verify output shape is [1, 11818]
  • Integration: load exported model and run inference on a test image
  • Unit: model loader graceful fallback chain (v2 → v1 → mock)
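The fallback-chain unit test is easier to write if the chain is expressed as an ordered list of loader functions. The sketch below shows one possible shape; `LoadedModel`, `Loader`, and `loadWithFallback` are hypothetical names for illustration, not the existing API in model-loader.ts.

```typescript
// Hypothetical types for illustration; the real PlantDiseaseModel interface
// in the codebase may look different.
type ModelVersion = "v2" | "v1" | "mock";

interface LoadedModel {
  version: ModelVersion;
  numClasses: number;
}

// A loader resolves to a model, or to null when its files are not present.
type Loader = () => Promise<LoadedModel | null>;

/**
 * Try each loader in order (for example v2, then v1, then mock) and return
 * the first model that loads. A loader that throws is logged and skipped,
 * the same as one that resolves to null.
 */
async function loadWithFallback(loaders: Loader[]): Promise<LoadedModel> {
  for (const load of loaders) {
    try {
      const model = await load();
      if (model !== null) return model;
    } catch (err) {
      console.warn("model load failed, trying next backend:", err);
    }
  }
  throw new Error("no model backend could be loaded");
}
```

A unit test can pass stub loaders that resolve to null or throw, then assert which version comes back, without touching any model files.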

acceptance_criteria:

  • Exported model produces 11,818 logits from a 224×224 image
  • Model loader loads v2 model when available, falls back gracefully when not
  • All existing v1 model consumers continue to work unmodified (via version detection)
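One way to approach version detection is to key it on the number of logits a loaded model produces, using the two class counts given in this spec. This is a sketch under that assumption; `detectVersion` is a hypothetical helper, and the real getModelVersion() may rely on the model directory instead.

```typescript
// Class counts taken from this spec: v1 has 38 classes, v2 has 11,818.
const V1_NUM_CLASSES = 38;
const V2_NUM_CLASSES = 11818;

type DetectedVersion = "v1" | "v2";

/**
 * Map a logit count to a model version, or fail with a clear message when
 * the count matches neither known model.
 */
function detectVersion(logitCount: number): DetectedVersion {
  if (logitCount === V2_NUM_CLASSES) return "v2";
  if (logitCount === V1_NUM_CLASSES) return "v1";
  throw new Error(
    `Model version mismatch: got ${logitCount} logits, expected ` +
      `${V1_NUM_CLASSES} (v1) or ${V2_NUM_CLASSES} (v2)`
  );
}
```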

validation:

  • node scripts/export-model.js produces model files
  • npx tsc --noEmit passes
  • POST to /api/identify returns predictions (may be limited if species→disease label mapping not yet complete)

notes:

  • This task is blocked on model training completion. The task file is the implementation spec; actual work begins after species_final_final.pt exists.
  • The ONNX export path is preferred for server-side inference (no Python runtime needed once exported).
  • If the ONNX export degrades output quality, export to a TF.js-loadable format (such as a TensorFlow SavedModel) instead.
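Whether the export "degrades the output" could be checked by comparing logits from the exported model against reference logits saved from PyTorch for the same image. The helpers below are a sketch of such a check; the function names are hypothetical and the default tolerance of 1e-3 is an assumed starting point, not a value from this spec.

```typescript
/** Largest absolute element-wise difference between two logit vectors. */
function maxAbsDiff(a: ArrayLike<number>, b: ArrayLike<number>): number {
  if (a.length !== b.length) {
    throw new Error(`length mismatch: ${a.length} vs ${b.length}`);
  }
  let worst = 0;
  for (let i = 0; i < a.length; i++) {
    worst = Math.max(worst, Math.abs(a[i] - b[i]));
  }
  return worst;
}

/** Index of the largest value (first one wins on ties). */
function argmax(v: ArrayLike<number>): number {
  let best = 0;
  for (let i = 1; i < v.length; i++) {
    if (v[i] > v[best]) best = i;
  }
  return best;
}

/**
 * True when the exported logits stay within `tolerance` of the reference
 * and both vectors agree on the top class. Tolerance is an assumption.
 */
function exportMatches(
  reference: ArrayLike<number>,
  exported: ArrayLike<number>,
  tolerance = 1e-3
): boolean {
  return (
    maxAbsDiff(reference, exported) <= tolerance &&
    argmax(reference) === argmax(exported)
  );
}
```

Running this over a handful of test images gives a concrete signal for deciding between the ONNX path and the fallback format.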