plant-disease-id/tasks/production-ml-pipeline/04-confidence-calibration.md
2026-06-08 16:42:04 -04:00

# 04. Confidence Calibration for PlantVillage Model
meta:
  id: production-ml-pipeline-04
  feature: production-ml-pipeline
  priority: P1
  depends_on: [production-ml-pipeline-03]
  tags: [implementation, ml, tests-required]
objective:
- Implement proper confidence calibration for the PlantVillage model's softmax output
- Replace the trivial `raw * 1.02` linear calibration with temperature scaling or entropy-based confidence
- Produce meaningful confidence labels (high/medium/low) that correlate with actual correctness
- Handle the "healthy" class output correctly (healthy predictions need a different confidence interpretation)
deliverables:
- `src/lib/ml/confidence.ts` — rewritten calibration with temperature scaling
- `src/lib/ml/calibration-params.ts` — calibration parameters (temperature, bias) for PlantVillage model
- `src/lib/ml/confidence.test.ts` — updated tests for new calibration logic
- `scripts/calibrate-model.ts` — script to compute optimal temperature from validation data
steps:
1. **Determine output type** — based on task 03's findings:
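- If the model output is already softmax probabilities: use entropy-based confidence, or recover logits with an inverse softmax (`Math.log(p)`, which equals the logits up to an additive constant) and then apply temperature scaling
- If the model output is logits: apply temperature-scaled softmax directly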
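
A minimal sketch of the output-type check and the inverse-softmax path, in case task 03 does not already provide an equivalent. The names `looksLikeProbabilities` and `probabilitiesToLogits`, the `1e-3` tolerance, and the `1e-10` floor are all hypothetical.

```typescript
// Hypothetical helpers; names and tolerances are illustrative assumptions.

// Heuristic: a softmax output has every value in [0, 1] and sums to ~1.
function looksLikeProbabilities(output: Float32Array, tolerance = 1e-3): boolean {
  let sum = 0;
  for (let i = 0; i < output.length; i++) {
    const v = output[i];
    if (!(v >= 0 && v <= 1)) return false; // also rejects NaN
    sum += v;
  }
  return Math.abs(sum - 1) <= tolerance;
}

// Inverse softmax: log(p) equals the original logits up to an additive
// constant, which softmax ignores. The floor avoids log(0) = -Infinity.
function probabilitiesToLogits(probs: Float32Array, floor = 1e-10): Float32Array {
  return probs.map((p) => Math.log(Math.max(p, floor)));
}
```

The recovered logits can then be divided by the temperature and passed through softmax like any other logit vector.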
2. **Implement temperature scaling**:
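```typescript
// src/lib/ml/confidence.ts
const DEFAULT_TEMPERATURE = 1.5; // Default for PlantVillage (typically 1.0 to 3.0)

// softmaxFloat32 is assumed to be an existing softmax helper already in scope.
export function temperatureScaledSoftmax(
  logits: Float32Array,
  temperature: number = DEFAULT_TEMPERATURE,
): Float32Array {
  // Guard against division by zero and negative temperatures.
  if (!(temperature > 0)) {
    throw new RangeError("temperature must be a positive number");
  }
  const scaled = new Float32Array(logits.length);
  for (let i = 0; i < logits.length; i++) {
    scaled[i] = logits[i] / temperature;
  }
  return softmaxFloat32(scaled);
}
```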
- Temperature > 1.0 softens the distribution (less confident, more uniform)
- Temperature < 1.0 sharpens the distribution (more confident)
- Temperature = 1.0 is standard softmax (no calibration)
- Typical value for MobileNetV2 on PlantVillage: 1.2 to 1.8
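
`softmaxFloat32` is referenced in this step but not defined in this task. The block below is a minimal, numerically stable sketch of what such a helper might look like (an assumption, not the project's actual implementation), followed by a small demonstration of how temperature changes the top probability. `softmaxStable` and `softmaxWithTemperature` are hypothetical names used so the block runs on its own.

```typescript
// Assumed shape of a softmax helper; the project's real softmaxFloat32 may differ.
function softmaxStable(logits: Float32Array): Float32Array {
  // Subtract the max logit so Math.exp never overflows.
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    if (logits[i] > max) max = logits[i];
  }
  const out = new Float32Array(logits.length);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    out[i] = Math.exp(logits[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < out.length; i++) {
    out[i] /= sum;
  }
  return out;
}

// Divide logits by T before softmax, as in the step above.
function softmaxWithTemperature(logits: Float32Array, temperature: number): Float32Array {
  return softmaxStable(logits.map((z) => z / temperature));
}

const exampleLogits = new Float32Array([4, 1, 0]);
const sharp = softmaxWithTemperature(exampleLogits, 1.0); // top probability ≈ 0.94
const soft = softmaxWithTemperature(exampleLogits, 2.0); // top probability ≈ 0.74
```

The predicted class is unchanged by temperature; only the spread of the distribution moves.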
3. **Implement entropy-based confidence**:
```typescript
export function computeEntropy(probabilities: Float32Array): number {
  let entropy = 0;
  for (let i = 0; i < probabilities.length; i++) {
    if (probabilities[i] > 1e-10) {
      entropy -= probabilities[i] * Math.log(probabilities[i]);
    }
  }
  return entropy;
}

export function entropyToConfidence(
  entropy: number,
  maxEntropy: number, // ln(numClasses)
): number {
  // Normalize entropy to [0, 1], then invert (low entropy = high confidence)
  const normalized = entropy / maxEntropy;
  return 1 - normalized;
}
```
- For 38 classes: `maxEntropy = Math.log(38) ≈ 3.64`
- Entropy close to 0 → one class dominates → high confidence
- Entropy close to max → uniform distribution → low confidence
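
As a worked example of the two functions above, consider a hypothetical 38-class output with 0.9 on one class and the remaining 0.1 spread evenly over the other 37. The helper `entropyOf` repeats the entropy formula so the block runs on its own.

```typescript
const NUM_CLASSES = 38;

// Same formula as computeEntropy above, repeated so this block is self-contained.
function entropyOf(probabilities: Float32Array): number {
  let entropy = 0;
  for (let i = 0; i < probabilities.length; i++) {
    const p = probabilities[i];
    if (p > 1e-10) entropy -= p * Math.log(p);
  }
  return entropy;
}

// Hypothetical output: 0.9 on class 0, the remaining 0.1 split over 37 classes.
const peaked = new Float32Array(NUM_CLASSES).fill(0.1 / 37);
peaked[0] = 0.9;

const maxEntropy = Math.log(NUM_CLASSES); // ≈ 3.64
const peakedEntropy = entropyOf(peaked); // ≈ 0.69
const peakedConfidence = 1 - peakedEntropy / maxEntropy; // ≈ 0.81
```

Note that even a 0.9 top probability leaves the entropy confidence near 0.81, because the long tail of 37 small probabilities still contributes entropy.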
4. **Implement combined calibration**:
```typescript
export function calibratePrediction(
  output: Float32Array,
  isLogits: boolean,
  temperature: number = DEFAULT_TEMPERATURE,
): ConfidenceResult {
  // Get probabilities (apply softmax if logits, or use directly if already probabilities)
  const probs = isLogits ? temperatureScaledSoftmax(output, temperature) : output;

  // Get top prediction
  let maxIdx = 0;
  for (let i = 1; i < probs.length; i++) {
    if (probs[i] > probs[maxIdx]) maxIdx = i;
  }
  const topProb = probs[maxIdx];

  // Compute entropy-based confidence
  const entropy = computeEntropy(probs);
  const maxEntropy = Math.log(probs.length);
  const entropyConfidence = entropyToConfidence(entropy, maxEntropy);

  // Combine: weighted average of top probability and entropy confidence
  const adjusted = 0.7 * topProb + 0.3 * entropyConfidence;

  return {
    raw: topProb,
    adjusted: Math.min(1, Math.max(0, adjusted)),
    label: getConfidenceLabel(adjusted),
    entropy,
    classIndex: maxIdx,
  };
}
```
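
`ConfidenceResult` is used above but not defined in this task. Inferred from the object `calibratePrediction` returns, its shape is presumably close to the sketch below; the `ConfidenceLabel` union is an assumption based on the high/medium/low labels named in the objective. The example values are hypothetical and only illustrate the 0.7 / 0.3 blend.

```typescript
// Assumed label type, based on the high/medium/low labels in the objective.
type ConfidenceLabel = "high" | "medium" | "low";

// Inferred from the fields calibratePrediction returns; the real type may differ.
interface ConfidenceResult {
  raw: number; // top-class probability
  adjusted: number; // weighted blend of raw and entropy confidence, clamped to [0, 1]
  label: ConfidenceLabel;
  entropy: number; // Shannon entropy of the distribution, in nats
  classIndex: number; // index of the top class
}

// Worked example of the blend for a hypothetical peaked prediction.
const topProb = 0.9;
const entropyConfidence = 0.81;
const example: ConfidenceResult = {
  raw: topProb,
  adjusted: 0.7 * topProb + 0.3 * entropyConfidence, // 0.63 + 0.243 = 0.873
  label: "high",
  entropy: 0.69,
  classIndex: 0,
};
```

For a perfectly flat 38-class output the same blend gives 0.7 × (1/38) + 0.3 × 0 ≈ 0.018, well inside the "low" band.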
5. **Update `getConfidenceLabel` thresholds** for PlantVillage's 38-class output:
```typescript
const CONFIDENCE_THRESHOLDS = {
  HIGH: 0.65, // Lowered from 0.8; PlantVillage softmax is less peaked
  MEDIUM: 0.35, // Lowered from 0.5
} as const;
```
- With 38 classes, even correct predictions may have lower top probability
- These thresholds should be tuned against a validation set (start with defaults, adjust after testing)
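
A minimal sketch of how `getConfidenceLabel` might apply these thresholds. The function name comes from this task, but the body is an assumption: it treats both thresholds as inclusive lower bounds.

```typescript
const CONFIDENCE_THRESHOLDS = {
  HIGH: 0.65,
  MEDIUM: 0.35,
} as const;

type ConfidenceLabel = "high" | "medium" | "low";

// Assumed behavior: each threshold is an inclusive lower bound.
function getConfidenceLabel(adjusted: number): ConfidenceLabel {
  if (adjusted >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
  if (adjusted >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
  return "low";
}
```

With these bounds, 0.7 maps to "high", 0.4 to "medium", and 0.2 to "low".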
6. **Handle healthy class confidence**:
- When the top prediction is a healthy class (indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37), the confidence represents "how confident the model is that the plant is healthy"
- Healthy predictions with high confidence → "No disease detected" (good)
- Healthy predictions with low confidence → "Uncertain — may have early symptoms"
- Update `calibrateConfidence()` to accept an `isHealthy` flag and adjust the label accordingly
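
One possible shape for the healthy-class handling, as a sketch. `HEALTHY_CLASS_INDICES`, `isHealthyClass`, and `describePrediction` are hypothetical names. The message wording for disease predictions, and the choice to treat medium confidence like low confidence for healthy predictions, are assumptions.

```typescript
// Healthy class indices as listed in the step above.
const HEALTHY_CLASS_INDICES: ReadonlySet<number> = new Set([
  3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37,
]);

type ConfidenceLabel = "high" | "medium" | "low";

function isHealthyClass(classIndex: number): boolean {
  return HEALTHY_CLASS_INDICES.has(classIndex);
}

// Hypothetical mapping from (label, isHealthy) to a user-facing message.
function describePrediction(label: ConfidenceLabel, isHealthy: boolean): string {
  if (!isHealthy) return `Disease detected (${label} confidence)`;
  if (label === "high") return "No disease detected";
  // Assumption: anything below "high" on a healthy class is reported as uncertain.
  return "Uncertain: may have early symptoms";
}
```

Keeping the index set in one place means the UI and the calibration code cannot disagree about which classes count as healthy.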
7. **Create calibration parameter module**:
```typescript
// src/lib/ml/calibration-params.ts
export const PLANTVILLAGE_CALIBRATION = {
  temperature: 1.5,
  confidenceHigh: 0.65,
  confidenceMedium: 0.35,
  maxEntropy: Math.log(38),
  entropyWeight: 0.3,
  probabilityWeight: 0.7,
} as const;
```
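
The earlier snippets hard-code `0.7`, `0.3`, and the thresholds. A sketch of how the blend and the label could read from the parameter object instead, so retuning happens in one place. `CalibrationParams`, `blendConfidence`, and `labelFor` are hypothetical names, and `params` simply copies the values above so the block is self-contained.

```typescript
// Hypothetical interface matching the fields of PLANTVILLAGE_CALIBRATION.
interface CalibrationParams {
  readonly temperature: number;
  readonly confidenceHigh: number;
  readonly confidenceMedium: number;
  readonly maxEntropy: number;
  readonly entropyWeight: number;
  readonly probabilityWeight: number;
}

// Values copied from PLANTVILLAGE_CALIBRATION above.
const params: CalibrationParams = {
  temperature: 1.5,
  confidenceHigh: 0.65,
  confidenceMedium: 0.35,
  maxEntropy: Math.log(38),
  entropyWeight: 0.3,
  probabilityWeight: 0.7,
};

// Weighted blend of top probability and entropy confidence, clamped to [0, 1].
function blendConfidence(topProb: number, entropy: number, p: CalibrationParams): number {
  const entropyConfidence = 1 - entropy / p.maxEntropy;
  const blended = p.probabilityWeight * topProb + p.entropyWeight * entropyConfidence;
  return Math.min(1, Math.max(0, blended));
}

// Thresholds treated as inclusive lower bounds (an assumption).
function labelFor(adjusted: number, p: CalibrationParams): "high" | "medium" | "low" {
  if (adjusted >= p.confidenceHigh) return "high";
  if (adjusted >= p.confidenceMedium) return "medium";
  return "low";
}
```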
8. **Create calibration script** `scripts/calibrate-model.ts`:
- Load the model
- Run inference on a set of labeled validation images (from PlantVillage validation split)
- Compute optimal temperature using Nelder-Mead or grid search on negative log-likelihood
- Output the optimal temperature value
- This is optional — start with default 1.5 and refine later
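
The temperature search itself can be sketched as a grid search over mean negative log-likelihood. This assumes the validation logits have already been collected; model loading and image preprocessing are omitted. `LabeledLogits`, `meanNll`, `findOptimalTemperature`, and the search range and step are hypothetical choices.

```typescript
// One validation example: raw logits plus the index of the true class.
interface LabeledLogits {
  logits: Float32Array;
  label: number;
}

// Mean negative log-likelihood of the true labels at a given temperature.
function meanNll(samples: LabeledLogits[], temperature: number): number {
  if (samples.length === 0) throw new Error("no validation samples");
  let total = 0;
  for (const { logits, label } of samples) {
    // Log-softmax via log-sum-exp on logits / T, for numerical stability.
    let max = -Infinity;
    for (let i = 0; i < logits.length; i++) {
      max = Math.max(max, logits[i] / temperature);
    }
    let sumExp = 0;
    for (let i = 0; i < logits.length; i++) {
      sumExp += Math.exp(logits[i] / temperature - max);
    }
    total -= logits[label] / temperature - max - Math.log(sumExp);
  }
  return total / samples.length;
}

// Grid search over [min, max]; the range and step are illustrative defaults.
function findOptimalTemperature(
  samples: LabeledLogits[],
  min = 0.5,
  max = 5.0,
  step = 0.05,
): number {
  let best = min;
  let bestNll = Infinity;
  const steps = Math.round((max - min) / step);
  for (let k = 0; k <= steps; k++) {
    const t = min + k * step;
    const nll = meanNll(samples, t);
    if (nll < bestNll) {
      bestNll = nll;
      best = t;
    }
  }
  return best;
}
```

Grid search is used here because the objective has a single free parameter; Nelder-Mead would also work but adds little for a one-dimensional problem.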
9. **Update `InferenceResult` type** to include calibration metadata:
```typescript
export interface InferenceResult {
  predictions: RawPrediction[];
  inferenceTimeMs: number;
  calibration?: {
    temperature: number;
    entropy: number;
    entropyConfidence: number;
  };
}
```
tests:
- Unit: `temperatureScaledSoftmax` with T=1.0 equals standard softmax
- Unit: `temperatureScaledSoftmax` with T=2.0 produces a more uniform distribution than T=1.0
- Unit: `computeEntropy` of a uniform 38-class distribution = `Math.log(38)` ≈ 3.64 (compare with a tolerance, since the inputs are `Float32Array`)
- Unit: `computeEntropy` of one-hot distribution = 0
- Unit: `entropyToConfidence(0, maxEntropy)` = 1.0 (maximum confidence)
- Unit: `entropyToConfidence(maxEntropy, maxEntropy)` = 0.0 (minimum confidence)
- Unit: `calibratePrediction` with high-peak input returns high confidence
- Unit: `calibratePrediction` with flat input returns low confidence
- Unit: `getConfidenceLabel(0.7)` returns "high"
- Unit: `getConfidenceLabel(0.4)` returns "medium"
- Unit: `getConfidenceLabel(0.2)` returns "low"
- Integration: calibration on a known PlantVillage test image yields a "high" label (adjusted ≥ 0.65) for the correct class
acceptance_criteria:
- `calibratePrediction()` produces meaningful confidence scores that correlate with prediction quality
- Temperature scaling is implemented and configurable (default T=1.5)
- Entropy-based confidence is implemented
- Combined calibration (weighted probability + entropy) is the default
- Healthy class predictions are handled correctly
- Confidence thresholds are tuned for 38-class output (HIGH ≥ 0.65, MEDIUM ≥ 0.35)
- All unit tests pass
- Calibration parameters are documented and configurable
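
One way to check the first criterion, that confidence tracks prediction quality, is to bucket validation predictions by label and compare per-bucket accuracy. The sketch below assumes a hypothetical `LabeledOutcome` record. If calibration is working, accuracy should fall from "high" to "medium" to "low".

```typescript
type ConfidenceLabel = "high" | "medium" | "low";

// Hypothetical record: the label assigned to a prediction and whether it was right.
interface LabeledOutcome {
  label: ConfidenceLabel;
  correct: boolean;
}

// Fraction of correct predictions within each confidence bucket (NaN if a bucket is empty).
function accuracyByLabel(outcomes: LabeledOutcome[]): Record<ConfidenceLabel, number> {
  const hits = { high: 0, medium: 0, low: 0 };
  const totals = { high: 0, medium: 0, low: 0 };
  for (const { label, correct } of outcomes) {
    totals[label] += 1;
    if (correct) hits[label] += 1;
  }
  return {
    high: hits.high / totals.high,
    medium: hits.medium / totals.medium,
    low: hits.low / totals.low,
  };
}
```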
validation:
- `npx vitest run src/lib/ml/confidence.test.ts`
- Manual: run identification on a known disease image → confidence should be "high" (≥ 0.65)
- Manual: run identification on a random/unrelated image → confidence should be "low" (< 0.35)
- Check server logs: entropy values should be reasonable (1.0 to 3.5 range for 38 classes)
notes:
- Temperature scaling is a post-hoc calibration method — it doesn't change the model, only the confidence interpretation
- The default temperature of 1.5 is a reasonable starting point for MobileNetV2 on PlantVillage. Optimal value depends on the specific training run.
- If a validation set of PlantVillage images is available, run `scripts/calibrate-model.ts` to find the optimal temperature
- The entropy-based approach works even without a validation set — it's a model-agnostic confidence measure
- For healthy predictions, consider showing a different UI (e.g., "No disease detected" with confidence) rather than treating them as disease predictions