pre torch.compile -chkpoint made
@@ -0,0 +1,65 @@
# 02. Multi-Image Ensemble & Species-Constrained Inference

meta:
id: multi-image-user-feedback-02
feature: multi-image-user-feedback
priority: P1
depends_on: [multi-image-user-feedback-01]
tags: [inference, ml]

objective:

- Extend the inference pipeline to support multi-image ensemble inference (averaging features or logits from 2+ images).
- Add species-constrained softmax that renormalizes probabilities over only the disease classes belonging to a known species.

deliverables:

- `src/lib/ml/inference.ts`: updated with ensemble and constrained inference functions
- `src/lib/ml/confidence.ts`: updated with species-aware confidence calibration

steps:

1. In `src/lib/ml/inference.ts`, add:
   - `runEnsembleInference(tensors: Float32Array[], topK?: number): Promise<InferenceResult>`: runs each image through the model, averages the resulting logits element-wise, and returns the top-K predictions. Averaging logits (before softmax) is preferred over averaging probabilities; the softmax of the mean logits equals the normalized geometric mean of the per-image probability distributions.
   - `speciesConstrainedSoftmax(logits: Float32Array, speciesClasses: number[]): Float32Array`: given the full 11,818-class logits and the class indices belonging to the user-specified species, computes softmax over only those indices and returns a probability vector that sums to 1 over the species classes and is zero everywhere else. The model output dimension (11,818) should be a configurable constant.
   - `runSpeciesConstrainedInference(tensor: Float32Array, speciesClassIndices: number[], topK?: number): Promise<InferenceResult>`: runs inference, then applies species-constrained softmax before extracting the top-K predictions.
   - `runEnsembleSpeciesConstrained(tensors: Float32Array[], speciesClassIndices: number[], topK?: number): Promise<InferenceResult>`: averages logits across the images, then applies the species constraint.
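
The numeric core of step 1 can be sketched as below. This is a minimal sketch under stated assumptions, not the project's implementation: model execution and the `InferenceResult` type are left out, and `averageLogits`, `topKPredictions`, and `Prediction` are hypothetical names introduced only for illustration. Treating an empty index list as an error is also an assumption.

```typescript
// Sketch of the pure math only; running the model is out of scope here.
// `averageLogits`, `topKPredictions`, and `Prediction` are hypothetical names.

interface Prediction {
  classIndex: number;
  probability: number;
}

/** Element-wise mean of one logit vector per image. */
function averageLogits(perImageLogits: Float32Array[]): Float32Array {
  if (perImageLogits.length === 0) {
    throw new Error("averageLogits: expected at least one logit vector");
  }
  const numClasses = perImageLogits[0].length;
  const mean = new Float32Array(numClasses);
  for (const logits of perImageLogits) {
    if (logits.length !== numClasses) {
      throw new Error("averageLogits: logit vectors differ in length");
    }
    for (let i = 0; i < numClasses; i++) mean[i] += logits[i];
  }
  for (let i = 0; i < numClasses; i++) mean[i] /= perImageLogits.length;
  return mean;
}

/** Softmax over `speciesClasses` only; every other entry stays 0. */
function speciesConstrainedSoftmax(
  logits: Float32Array,
  speciesClasses: number[],
): Float32Array {
  const probs = new Float32Array(logits.length); // zero-initialized
  const seen = new Uint8Array(logits.length);
  const unique: number[] = [];
  let max = -Infinity;
  for (const idx of speciesClasses) {
    if (!Number.isInteger(idx) || idx < 0 || idx >= logits.length) {
      throw new RangeError(`class index ${idx} is outside the model output`);
    }
    if (seen[idx] === 1) continue; // ignore duplicate indices
    seen[idx] = 1;
    unique.push(idx);
    if (logits[idx] > max) max = logits[idx];
  }
  if (unique.length === 0) {
    // Assumption: an empty constraint is treated as a caller error.
    throw new Error("speciesConstrainedSoftmax: no class indices given");
  }
  // Subtract the max logit before exponentiating for numerical stability.
  let sum = 0;
  for (const idx of unique) {
    const e = Math.exp(logits[idx] - max);
    probs[idx] = e;
    sum += e;
  }
  for (const idx of unique) probs[idx] /= sum;
  return probs;
}

/** Highest-probability classes, best first. */
function topKPredictions(probs: Float32Array, k: number = 5): Prediction[] {
  const ranked: Prediction[] = [];
  for (let i = 0; i < probs.length; i++) {
    ranked.push({ classIndex: i, probability: probs[i] });
  }
  ranked.sort((a, b) => b.probability - a.probability);
  return ranked.slice(0, k);
}

// Two toy 4-class logit vectors; the mean is [1.5, 1, 2, -0.5].
const meanLogits = averageLogits([
  new Float32Array([2, 0.5, 1, -1]),
  new Float32Array([1, 1.5, 3, 0]),
]);
// Class 2 has the largest averaged logit but is excluded by the constraint,
// so the best remaining class is 0.
const constrainedProbs = speciesConstrainedSoftmax(meanLogits, [0, 1]);
const bestPrediction = topKPredictions(constrainedProbs, 1)[0];
```

Keeping these helpers free of model calls means they behave the same for the 38-class mock output and the 11,818-class model, since the output length is read from the logits rather than hard-coded.
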

2. Export the `CLASSIFIER_NUM_CLASSES` constant (11,818) and `SPECIES_CLASS_RANGES` (a map from species name → [startIndex, endIndex] in the model output) from a new constants file or from `labels.ts`.
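
One possible shape for these constants, together with a helper that expands a range into the explicit index list the constrained functions take, is sketched below. The species names and ranges are placeholders, `getSpeciesClassIndices` and `ClassRange` are hypothetical names, and treating `endIndex` as inclusive is an assumption the spec does not settle.

```typescript
// Value from the task description; both constants would be exported in the real module.
const CLASSIFIER_NUM_CLASSES = 11818;

type ClassRange = readonly [number, number]; // [startIndex, endIndex]

// Placeholder entries for illustration only; real ranges come from the dataset.
const SPECIES_CLASS_RANGES: Record<string, ClassRange> = {
  "example-species-a": [0, 2],
  "example-species-b": [3, 7],
};

/**
 * Expand a species' range into explicit class indices.
 * Assumes endIndex is inclusive. Returns undefined for an unknown species
 * so the caller can decide how to fall back.
 */
function getSpeciesClassIndices(
  species: string,
  ranges: Record<string, ClassRange> = SPECIES_CLASS_RANGES,
  numClasses: number = CLASSIFIER_NUM_CLASSES,
): number[] | undefined {
  if (!Object.prototype.hasOwnProperty.call(ranges, species)) return undefined;
  const [start, end] = ranges[species];
  if (
    !Number.isInteger(start) || !Number.isInteger(end) ||
    start < 0 || end < start || end >= numClasses
  ) {
    throw new RangeError(`invalid class range for species "${species}"`);
  }
  const indices: number[] = [];
  for (let i = start; i <= end; i++) indices.push(i);
  return indices;
}
```

Passing `numClasses` as a parameter keeps the helper usable against the 38-class mock model as well as the full model.
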

3. In `src/lib/ml/confidence.ts`, add:
   - `calibrateSpeciesConfidence(rawProb: number, numDiseaseClasses: number): ConfidenceResult`: adjusts the calibration factor based on how many disease classes the species has (fewer classes = higher effective confidence).
   - `getEnsembleConfidence(predictions: RawPrediction[][]): ConfidenceResult`: aggregates confidence across multiple images.
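
The spec leaves the aggregation rule for `getEnsembleConfidence` open. One possible rule, shown purely as an assumption, combines how many images agree on the top class with the mean probability they assign to it. The shapes of `RawPrediction` and the returned object are hypothetical here; the real `RawPrediction` and `ConfidenceResult` types may differ.

```typescript
// Hypothetical shapes; the project's RawPrediction / ConfidenceResult may differ.
interface RawPrediction {
  classIndex: number;
  probability: number;
}

interface EnsembleConfidenceSketch {
  topClass: number;        // class with the most top-1 votes
  agreement: number;       // fraction of images voting for topClass
  meanProbability: number; // mean probability given to topClass across images
  score: number;           // agreement * meanProbability (an assumed rule)
}

/** `predictions` holds one list per image, sorted best first (assumption). */
function getEnsembleConfidence(
  predictions: RawPrediction[][],
): EnsembleConfidenceSketch {
  if (predictions.length === 0) {
    throw new Error("getEnsembleConfidence: no predictions given");
  }
  // Count top-1 votes; ties go to the class that reached the count first.
  const votes: Record<number, number> = {};
  let topClass = -1;
  let topVotes = 0;
  for (const perImage of predictions) {
    if (perImage.length === 0) {
      throw new Error("getEnsembleConfidence: empty prediction list");
    }
    const cls = perImage[0].classIndex;
    votes[cls] = (votes[cls] || 0) + 1;
    if (votes[cls] > topVotes) {
      topVotes = votes[cls];
      topClass = cls;
    }
  }
  // Average the probability each image assigns to the winning class
  // (counted as 0 when the class is missing from an image's list).
  let probSum = 0;
  for (const perImage of predictions) {
    for (const p of perImage) {
      if (p.classIndex === topClass) {
        probSum += p.probability;
        break;
      }
    }
  }
  const agreement = topVotes / predictions.length;
  const meanProbability = probSum / predictions.length;
  return { topClass, agreement, meanProbability, score: agreement * meanProbability };
}
```

Multiplying the two terms penalizes both disagreement between images and low per-image probability; a weighted sum or a minimum would be equally defensible choices.
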

4. Create `src/lib/ml/species-class-ranges.ts` containing the mapping from species name → [class start index, class end index] in the 11,818-class model output. The mapping is derived from the training dataset's `species_index.json` or `class_hierarchy.json`.
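
How the ranges could be derived is sketched below under heavy assumptions: the real schema of `species_index.json` and `class_hierarchy.json` is not given here, so the input shape (`SpeciesEntry`), the contiguous class ordering, the inclusive end index, and the function names are all hypothetical.

```typescript
// Hypothetical input shape: species listed in model-output order,
// each with its disease classes occupying a contiguous block of indices.
interface SpeciesEntry {
  species: string;
  diseases: string[];
}

function buildSpeciesClassRanges(
  hierarchy: SpeciesEntry[],
): Record<string, [number, number]> {
  const ranges: Record<string, [number, number]> = {};
  let next = 0; // next unassigned class index
  for (const entry of hierarchy) {
    if (entry.diseases.length === 0) continue; // no classes, no range
    // Inclusive end index (assumption).
    ranges[entry.species] = [next, next + entry.diseases.length - 1];
    next += entry.diseases.length;
  }
  return ranges;
}

/** Render the source text of the generated module that would be checked in. */
function renderRangesModule(ranges: Record<string, [number, number]>): string {
  return (
    "// AUTO-GENERATED; do not edit by hand.\n" +
    "export const SPECIES_CLASS_RANGES: Record<string, [number, number]> = " +
    JSON.stringify(ranges, null, 2) +
    ";\n"
  );
}

// Placeholder data for illustration only.
const sampleRanges = buildSpeciesClassRanges([
  { species: "example-species-a", diseases: ["healthy", "disease-1", "disease-2"] },
  { species: "example-species-b", diseases: ["healthy", "disease-1"] },
]);
// sampleRanges is { "example-species-a": [0, 2], "example-species-b": [3, 4] }
```

A generator script would read the JSON file, call `buildSpeciesClassRanges`, and write the string from `renderRangesModule` to `src/lib/ml/species-class-ranges.ts`.
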

5. Handle edge cases:
   - If the `tensors` array is empty → throw an error
   - If a tensor's length doesn't match the expected model input size → throw a validation error
   - If the species name is not found in `SPECIES_CLASS_RANGES` → fall back to full softmax
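
These edge cases might be handled along the following lines. This is a sketch rather than the required design: `TensorValidationError`, `validateTensors`, and `indicesOrAllClasses` are hypothetical names, and the expected input length is passed in because the model's input size is not stated here.

```typescript
// Hypothetical error type; the codebase may already define its own.
class TensorValidationError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "TensorValidationError";
  }
}

/** Throws if there are no tensors or any tensor has the wrong length. */
function validateTensors(tensors: Float32Array[], expectedLength: number): void {
  if (tensors.length === 0) {
    throw new Error("at least one image tensor is required");
  }
  tensors.forEach((tensor, i) => {
    if (tensor.length !== expectedLength) {
      throw new TensorValidationError(
        `tensor ${i} has length ${tensor.length}, expected ${expectedLength}`,
      );
    }
  });
}

/**
 * Indices to constrain over. When the species lookup found nothing
 * (undefined), every class is selected, which makes a constrained
 * softmax identical to the full softmax.
 */
function indicesOrAllClasses(
  speciesIndices: number[] | undefined,
  numClasses: number,
): number[] {
  if (speciesIndices !== undefined) return speciesIndices;
  const all: number[] = [];
  for (let i = 0; i < numClasses; i++) all.push(i);
  return all;
}
```

Expressing the fallback as "constrain over all classes" lets the constrained code path serve both cases without a separate full-softmax branch.
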

tests:

- Unit: test logit averaging with 2 identical tensors → results should be identical to single inference
- Unit: test logit averaging with 2 different tensors → verify averaged output
- Unit: test species-constrained softmax: verify probabilities are zero outside the constrained indices
- Unit: test constrained softmax sums to ~1.0 within the species class range
- Unit: test ensemble + constrained combined pipeline

acceptance_criteria:

- `runEnsembleInference` accepts multiple tensors and returns averaged top-K predictions
- `speciesConstrainedSoftmax` zeros out all classes outside the species range
- `runSpeciesConstrainedInference` and `runEnsembleSpeciesConstrained` produce constrained results
- Confidence calibration accounts for number of disease classes in the species

validation:

- `npx tsc --noEmit` passes
- Unit tests pass with `npx vitest run src/lib/ml/ --reporter=verbose`

notes:

- The current mock model outputs 38 classes. These new functions target the 11,818-class model.
- Until the real model loads, ensemble/constrained functions should still work with mock data (just with fewer classes).
- The species-ranges file should be auto-generated from `data/organized/class_hierarchy.json` and checked into version control.