02. Multi-Image Ensemble & Species-Constrained Inference

meta:

  • id: multi-image-user-feedback-02
  • feature: multi-image-user-feedback
  • priority: P1
  • depends_on: [multi-image-user-feedback-01]
  • tags: [inference, ml]

objective:

  • Extend the inference pipeline to support multi-image ensemble inference (averaging features or logits from 2+ images).
  • Add species-constrained softmax that renormalizes probabilities over only the disease classes belonging to a known species.

deliverables:

  • src/lib/ml/inference.ts — updated with ensemble and constrained inference functions
  • src/lib/ml/confidence.ts — updated with species-aware confidence calibration

steps:

  1. In src/lib/ml/inference.ts, add:

    • runEnsembleInference(tensors: Float32Array[], topK?: number): Promise<InferenceResult> — runs multiple images through the model, averages their logits, and returns top-K predictions. Averaging logits (before softmax) is preferred over averaging probabilities since it preserves confidence structure.
    • speciesConstrainedSoftmax(logits: Float32Array, speciesClasses: number[]): Float32Array — given the full 11,818-class logits and a list of class indices belonging to the user-specified species, compute softmax over only those indices and return a renormalized probability vector (zero everywhere else). The model output dimension (11,818) should be a configurable constant.
    • runSpeciesConstrainedInference(tensor: Float32Array, speciesClassIndices: number[], topK?: number): Promise<InferenceResult> — run inference then apply species-constrained softmax before extracting top-K.
    • runEnsembleSpeciesConstrained(tensors: Float32Array[], speciesClassIndices: number[], topK?: number): Promise<InferenceResult> — ensemble then constrain.
  2. Export CLASSIFIER_NUM_CLASSES constant (11,818) and SPECIES_CLASS_RANGES (a map from species name → [startIndex, endIndex] in the model output) from a new constants file or from labels.ts.

  3. In src/lib/ml/confidence.ts, add:

    • calibrateSpeciesConfidence(rawProb: number, numDiseaseClasses: number): ConfidenceResult — adjusts calibration factor based on how many disease classes the species has (fewer classes = higher effective confidence).
    • getEnsembleConfidence(predictions: RawPrediction[][]): ConfidenceResult — aggregate confidence from multiple images.
  4. Create src/lib/ml/species-class-ranges.ts containing the mapping from species name → [class start index, class end index] in the 11,818-class model output. This is derived from the training dataset's species_index.json or class_hierarchy.json.

  5. Handle edge cases:

    • If tensors array is empty → throw
    • If tensor length doesn't match expected model input → throw validation error
    • If species name not found in SPECIES_CLASS_RANGES → fall back to full softmax
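
The two pure building blocks from step 1 can be sketched without touching the model runtime. Averaging logits and then applying softmax amounts to a renormalized geometric mean of the per-image probabilities, which is one way to read the preference for averaging before softmax. The helper name `averageLogits`, the error messages, and the choice to throw on an empty index list are assumptions for illustration; only the `speciesConstrainedSoftmax` signature comes from this spec.

```typescript
// Sketch only: averageLogits is a hypothetical helper, not an existing export.

/** Element-wise mean of per-image logit vectors (averaged before softmax). */
function averageLogits(logitSets: Float32Array[]): Float32Array {
  if (logitSets.length === 0) {
    throw new Error("averageLogits: expected at least one logit vector");
  }
  const numClasses = logitSets[0].length;
  // Accumulate in float64 so the sum does not lose float32 precision.
  const sum = new Float64Array(numClasses);
  for (const logits of logitSets) {
    if (logits.length !== numClasses) {
      throw new Error(
        `averageLogits: expected ${numClasses} logits, got ${logits.length}`,
      );
    }
    for (let i = 0; i < numClasses; i++) sum[i] += logits[i];
  }
  const mean = new Float32Array(numClasses);
  for (let i = 0; i < numClasses; i++) mean[i] = sum[i] / logitSets.length;
  return mean;
}

/** Softmax over the given class indices only; every other class gets 0. */
function speciesConstrainedSoftmax(
  logits: Float32Array,
  speciesClasses: number[],
): Float32Array {
  // De-duplicate so a repeated index cannot be counted twice in the sum.
  const indices = Array.from(new Set(speciesClasses));
  if (indices.length === 0) {
    // Assumption: an empty constraint is a caller error rather than a fallback.
    throw new Error("speciesConstrainedSoftmax: no class indices given");
  }
  let maxLogit = -Infinity;
  for (const idx of indices) {
    if (!Number.isInteger(idx) || idx < 0 || idx >= logits.length) {
      throw new RangeError(`class index ${idx} is outside the model output`);
    }
    maxLogit = Math.max(maxLogit, logits[idx]);
  }
  const probs = new Float32Array(logits.length); // zero-initialized
  let total = 0;
  for (const idx of indices) {
    // Subtracting the max keeps Math.exp from overflowing.
    const e = Math.exp(logits[idx] - maxLogit);
    probs[idx] = e;
    total += e;
  }
  for (const idx of indices) probs[idx] /= total;
  return probs;
}

// Usage with made-up 4-class logits standing in for the model output.
const perImageLogits = [
  new Float32Array([2.0, 0.5, -1.0, 1.5]),
  new Float32Array([1.0, 1.5, -0.5, 2.5]),
];
const probs = speciesConstrainedSoftmax(averageLogits(perImageLogits), [1, 3]);
// probs[0] and probs[2] are exactly 0; probs[1] + probs[3] is ~1.
```

Because neither function hard-codes the class count, the same code works for the 38-class mock output and the 11,818-class model.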

tests:

  • Unit: test logit averaging with 2 identical tensors → results should match single-image inference (within floating-point tolerance)
  • Unit: test logit averaging with 2 different tensors → verify averaged output
  • Unit: test species-constrained softmax — verify probabilities are zero outside the constrained indices
  • Unit: test constrained softmax sums to ~1.0 within the species class range
  • Unit: test ensemble + constrained combined pipeline
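
The checks listed above reduce to two reusable comparisons, sketched here as plain functions so they can be wrapped in whichever test runner the project uses. The names `maxAbsDiff` and `constrainedSoftmaxViolations` and the default tolerance of 1e-5 are assumptions, not existing project helpers.

```typescript
/** Largest absolute element-wise difference between two equal-length vectors. */
function maxAbsDiff(a: Float32Array, b: Float32Array): number {
  if (a.length !== b.length) {
    throw new Error(`length mismatch: ${a.length} vs ${b.length}`);
  }
  let worst = 0;
  for (let i = 0; i < a.length; i++) {
    worst = Math.max(worst, Math.abs(a[i] - b[i]));
  }
  return worst;
}

/**
 * Lists every way a probability vector breaks the constrained-softmax
 * contract. An empty result means all invariants hold.
 */
function constrainedSoftmaxViolations(
  probs: Float32Array,
  allowed: number[],
  tol: number = 1e-5,
): string[] {
  const allowedSet = new Set(allowed);
  const violations: string[] = [];
  let insideSum = 0;
  for (let i = 0; i < probs.length; i++) {
    if (allowedSet.has(i)) {
      insideSum += probs[i];
    } else if (probs[i] !== 0) {
      violations.push(`class ${i} is outside the species but has p=${probs[i]}`);
    }
  }
  // Written as a negated <= so that a NaN sum is also reported.
  if (!(Math.abs(insideSum - 1) <= tol)) {
    violations.push(`in-species probabilities sum to ${insideSum}, expected ~1`);
  }
  return violations;
}

// Hand-built vectors standing in for real outputs.
const goodProbs = new Float32Array([0, 0.25, 0, 0.75]);
const badProbs = new Float32Array([0.1, 0.2, 0, 0.7]);
const goodReport = constrainedSoftmaxViolations(goodProbs, [1, 3]); // no violations
const badReport = constrainedSoftmaxViolations(badProbs, [1, 3]); // leak at class 0, sum below 1
```

For the identical-tensor test, asserting that `maxAbsDiff(ensembleLogits, singleLogits)` stays under a small tolerance is more robust than strict equality once more than two images are averaged.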

acceptance_criteria:

  • runEnsembleInference accepts multiple tensors and returns averaged top-K predictions
  • speciesConstrainedSoftmax zeros out all classes outside the species range
  • runSpeciesConstrainedInference and runEnsembleSpeciesConstrained produce constrained results
  • Confidence calibration accounts for number of disease classes in the species

validation:

  • npx tsc --noEmit passes
  • Unit tests pass with npx vitest run src/lib/ml/ --reporter=verbose

notes:

  • The current mock model outputs 38 classes. These new functions target the 11,818-class model.
  • Until the real model loads, ensemble/constrained functions should still work with mock data (just with fewer classes).
  • The species-ranges file should be auto-generated from data/organized/class_hierarchy.json and checked into version control.
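
As a rough sketch of what the generator and its consumer might look like: the layout of class_hierarchy.json is not described here, so the code below starts from a plain array giving the species of each class in model-output order. It further assumes that each species occupies one contiguous block and that endIndex is inclusive; both need checking against the real data. The function names and the sample species are made up for illustration.

```typescript
type ClassRange = [number, number]; // [startIndex, endIndex], end inclusive (assumption)

/** Builds species -> range from the species of each class, in output order. */
function buildSpeciesClassRanges(
  speciesPerClass: string[],
): Map<string, ClassRange> {
  const ranges = new Map<string, ClassRange>();
  speciesPerClass.forEach((species, index) => {
    const existing = ranges.get(species);
    if (existing === undefined) {
      ranges.set(species, [index, index]);
    } else if (existing[1] === index - 1) {
      existing[1] = index; // extend the current block by one class
    } else {
      // A gap means a [start, end] pair cannot describe this species.
      throw new Error(
        `classes for "${species}" are not contiguous at index ${index}`,
      );
    }
  });
  return ranges;
}

/**
 * Expands a species name into the index list expected by
 * speciesConstrainedSoftmax. An unknown species yields every class,
 * which makes the constrained softmax equal to the full softmax.
 */
function resolveSpeciesClassIndices(
  species: string,
  ranges: Map<string, ClassRange>,
  numClasses: number,
): number[] {
  const known = ranges.get(species);
  const start = known ? known[0] : 0;
  const end = known ? known[1] : numClasses - 1;
  const indices: number[] = [];
  for (let i = start; i <= end; i++) indices.push(i);
  return indices;
}

// Five made-up classes: two for one species, three for another.
const sampleRanges = buildSpeciesClassRanges([
  "tomato", "tomato", "apple", "apple", "apple",
]);
const appleIndices = resolveSpeciesClassIndices("apple", sampleRanges, 5);
const fallbackIndices = resolveSpeciesClassIndices("orchid", sampleRanges, 5);
```

Returning all indices for an unknown species keeps the fallback to full softmax in one place, so the inference functions do not need a separate code path.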