# 04. Confidence Calibration for PlantVillage Model

meta:
id: production-ml-pipeline-04
feature: production-ml-pipeline
priority: P1
depends_on: [production-ml-pipeline-03]
tags: [implementation, ml, tests-required]

objective:

- Implement proper confidence calibration for the PlantVillage model's softmax output
- Replace the trivial `raw * 1.02` linear calibration with temperature scaling or entropy-based confidence
- Produce meaningful confidence labels (high/medium/low) that correlate with actual correctness
- Handle the "healthy" classes correctly (healthy predictions need a different confidence interpretation)

deliverables:

- `src/lib/ml/confidence.ts`: rewritten calibration with temperature scaling
- `src/lib/ml/calibration-params.ts`: calibration parameters (temperature, confidence thresholds, probability/entropy weights) for the PlantVillage model
- `src/lib/ml/confidence.test.ts`: updated tests for the new calibration logic
- `scripts/calibrate-model.ts`: script to compute the optimal temperature from validation data

steps:

1. **Determine the output type** based on task 03's findings:
- If the model output is already softmax probabilities: use entropy-based confidence, or inverse-softmax (take the log of each probability) followed by temperature scaling
- If the model output is logits: apply temperature-scaled softmax directly

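The inverse-softmax route for probability outputs works because `log(p_i)` equals the original logit minus a constant shared by every class, and softmax ignores a shared shift, so `softmax(log(p) / T)` is the same as `softmax(logits / T)`. A minimal sketch, assuming a hypothetical helper name `temperatureScaleProbabilities` and an illustrative `eps` floor for zero probabilities (neither is part of the existing code):

```typescript
// Hypothetical helper: temperature scaling for models that emit probabilities.
// log(p_i) = z_i - logSumExp(z), so softmax(log(p) / T) = softmax(z / T).
function temperatureScaleProbabilities(
  probs: Float32Array,
  temperature: number,
  eps = 1e-12, // floor so that log(0) does not produce -Infinity
): Float32Array {
  if (!(temperature > 0)) {
    throw new RangeError(`temperature must be > 0, got ${temperature}`);
  }
  // Recover the logits (up to a constant) and apply the temperature.
  const scaled = new Float64Array(probs.length);
  let max = -Infinity;
  for (let i = 0; i < probs.length; i++) {
    scaled[i] = Math.log(Math.max(probs[i], eps)) / temperature;
    if (scaled[i] > max) max = scaled[i];
  }
  // Numerically stable softmax: subtract the max before exponentiating.
  let sum = 0;
  for (let i = 0; i < scaled.length; i++) {
    scaled[i] = Math.exp(scaled[i] - max);
    sum += scaled[i];
  }
  const out = new Float32Array(probs.length);
  for (let i = 0; i < out.length; i++) out[i] = scaled[i] / sum;
  return out;
}

// Example: softening an over-confident output with T = 2.
const probs = Float32Array.from([0.8, 0.15, 0.05]);
const softened = temperatureScaleProbabilities(probs, 2);
// softened[0] ≈ 0.594 (down from 0.8); the top class is unchanged.
```
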
2. **Implement temperature scaling**:

```typescript
// src/lib/ml/confidence.ts
// `softmaxFloat32` is assumed to be the existing softmax helper over a Float32Array.
const DEFAULT_TEMPERATURE = 1.5; // Default for PlantVillage (typically 1.0–3.0)

export function temperatureScaledSoftmax(
  logits: Float32Array,
  temperature: number = DEFAULT_TEMPERATURE,
): Float32Array {
  // T <= 0 would divide by zero or flip the ranking of the classes
  if (!(temperature > 0)) {
    throw new RangeError(`temperature must be > 0, got ${temperature}`);
  }
  const scaled = new Float32Array(logits.length);
  for (let i = 0; i < logits.length; i++) {
    scaled[i] = logits[i] / temperature;
  }
  return softmaxFloat32(scaled);
}
```

- Temperature > 1.0 softens the distribution (less confident, more uniform)
- Temperature < 1.0 sharpens the distribution (more confident)
- Temperature = 1.0 is standard softmax (no calibration)
- Dividing every logit by the same positive T preserves their order, so the predicted class never changes; only its probability does
- Typical value for MobileNetV2 on PlantVillage: 1.2–1.8

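The step above calls a `softmaxFloat32` helper that this task does not define (presumably it already exists from earlier work). If it has to be written, a numerically stable version is sketched below as a stand-in; the example logits are made up and only illustrate that raising T lowers the top probability without changing the predicted class.

```typescript
// Stand-in for the `softmaxFloat32` helper assumed by temperatureScaledSoftmax.
// Subtracting the max value first keeps Math.exp from overflowing.
function softmaxFloat32(values: Float32Array): Float32Array {
  let max = -Infinity;
  for (let i = 0; i < values.length; i++) {
    if (values[i] > max) max = values[i];
  }
  const out = new Float32Array(values.length);
  let sum = 0;
  for (let i = 0; i < values.length; i++) {
    out[i] = Math.exp(values[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= sum;
  return out;
}

function temperatureScaledSoftmax(logits: Float32Array, temperature = 1.5): Float32Array {
  if (!(temperature > 0)) throw new RangeError("temperature must be > 0");
  const scaled = new Float32Array(logits.length);
  for (let i = 0; i < logits.length; i++) scaled[i] = logits[i] / temperature;
  return softmaxFloat32(scaled);
}

// Illustrative logits, compared at T = 1 and T = 2.
const logits = Float32Array.from([4, 2, 1]);
const sharp = temperatureScaledSoftmax(logits, 1); // sharp[0] ≈ 0.844
const soft = temperatureScaledSoftmax(logits, 2); // soft[0] ≈ 0.629
```
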
3. **Implement entropy-based confidence**:

```typescript
export function computeEntropy(probabilities: Float32Array): number {
  let entropy = 0;
  for (let i = 0; i < probabilities.length; i++) {
    // Skip near-zero terms: p * ln(p) tends to 0 as p tends to 0
    if (probabilities[i] > 1e-10) {
      entropy -= probabilities[i] * Math.log(probabilities[i]);
    }
  }
  return entropy;
}

export function entropyToConfidence(
  entropy: number,
  maxEntropy: number, // ln(numClasses)
): number {
  // A single-class output has no uncertainty (and would divide by zero)
  if (maxEntropy <= 0) return 1;
  // Normalize entropy to [0, 1] (clamped against float rounding), then invert:
  // low entropy = high confidence
  const normalized = Math.min(1, Math.max(0, entropy / maxEntropy));
  return 1 - normalized;
}
```

- For 38 classes: `maxEntropy = Math.log(38) ≈ 3.64`
- Entropy close to 0 → one class dominates → high confidence
- Entropy close to max → uniform distribution → low confidence

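In formula form, the two functions above compute the Shannon entropy (natural log) of a K-class distribution p, using the usual convention 0 · ln 0 = 0, and map it onto [0, 1]:

```math
\begin{aligned}
H(p) &= -\sum_{i=1}^{K} p_i \ln p_i, \qquad 0 \le H(p) \le \ln K \\
c_{\text{entropy}}(p) &= 1 - \frac{H(p)}{\ln K}
\end{aligned}
```
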
4. **Implement combined calibration**:

```typescript
// `ConfidenceResult` and `getConfidenceLabel` (step 5) are assumed to be defined elsewhere in confidence.ts.
export function calibratePrediction(
  output: Float32Array,
  isLogits: boolean,
  temperature: number = DEFAULT_TEMPERATURE,
): ConfidenceResult {
  // Get probabilities (temperature-scaled softmax if logits, otherwise use as-is;
  // note that `temperature` is ignored for probability input)
  const probs = isLogits ? temperatureScaledSoftmax(output, temperature) : output;

  // Get top prediction
  let maxIdx = 0;
  for (let i = 1; i < probs.length; i++) {
    if (probs[i] > probs[maxIdx]) maxIdx = i;
  }
  const topProb = probs[maxIdx];

  // Compute entropy-based confidence
  const entropy = computeEntropy(probs);
  const maxEntropy = Math.log(probs.length);
  const entropyConfidence = entropyToConfidence(entropy, maxEntropy);

  // Combine: weighted average of top probability and entropy confidence,
  // clamped before labelling so that `adjusted` and `label` always agree
  const adjusted = Math.min(1, Math.max(0, 0.7 * topProb + 0.3 * entropyConfidence));

  return {
    raw: topProb,
    adjusted,
    label: getConfidenceLabel(adjusted),
    entropy,
    classIndex: maxIdx,
  };
}
```

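To see how the 0.7/0.3 blend behaves at the two extremes, the sketch below recomputes the combined score on plain number arrays for a made-up 38-class distribution with 0.9 on the top class (the remaining 0.1 spread evenly) and for a uniform distribution. `blendedConfidence` is an illustrative helper that mirrors the formula in `calibratePrediction` without the temperature step; it is not the task's API. Because both terms lie in [0, 1] and the weights sum to 1, the blend already stays in [0, 1]; the clamp only guards against rounding.

```typescript
const PROBABILITY_WEIGHT = 0.7;
const ENTROPY_WEIGHT = 0.3;

// Shannon entropy (natural log), skipping near-zero terms.
function entropyOf(probs: number[]): number {
  let h = 0;
  for (const p of probs) {
    if (p > 1e-10) h -= p * Math.log(p);
  }
  return h;
}

// Hypothetical helper mirroring the blend in calibratePrediction.
function blendedConfidence(probs: number[]): number {
  const top = Math.max(...probs);
  const entropyConfidence = 1 - entropyOf(probs) / Math.log(probs.length);
  const blended = PROBABILITY_WEIGHT * top + ENTROPY_WEIGHT * entropyConfidence;
  return Math.min(1, Math.max(0, blended));
}

const K = 38;
// 0.9 on class 0, remaining 0.1 shared by the other 37 classes.
const peaked = Array.from({ length: K }, (_, i) => (i === 0 ? 0.9 : 0.1 / (K - 1)));
const uniform = Array.from({ length: K }, () => 1 / K);

const peakedScore = blendedConfidence(peaked); // ≈ 0.873, labelled "high"
const uniformScore = blendedConfidence(uniform); // ≈ 0.018, labelled "low"
```
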
5. **Update `getConfidenceLabel` thresholds** for PlantVillage's 38-class output:

```typescript
const CONFIDENCE_THRESHOLDS = {
  HIGH: 0.65, // Lowered from 0.8: temperature-scaled 38-class output is less peaked
  MEDIUM: 0.35, // Lowered from 0.5
} as const;
```

- With 38 classes, even correct predictions may have a lower top probability
- These thresholds should be tuned against a validation set (start with the defaults, adjust after testing)

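A minimal sketch of `getConfidenceLabel` with these thresholds, assuming inclusive lower bounds (matching the acceptance criteria: HIGH ≥ 0.65, MEDIUM ≥ 0.35) and a `ConfidenceLabel` union type whose name is illustrative:

```typescript
type ConfidenceLabel = "high" | "medium" | "low";

const CONFIDENCE_THRESHOLDS = {
  HIGH: 0.65,
  MEDIUM: 0.35,
} as const;

function getConfidenceLabel(score: number): ConfidenceLabel {
  // Inclusive lower bounds: a score of exactly 0.65 is "high".
  if (score >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
  if (score >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
  // Also reached for NaN, since every comparison with NaN is false.
  return "low";
}
```
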
6. **Handle healthy class confidence**:
- When the top prediction is a healthy class (index 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37), the confidence represents how confident the model is that the plant is healthy
- Healthy predictions with high confidence → "No disease detected" (good)
- Healthy predictions with low confidence → "Uncertain: may have early symptoms"
- Update `calibrateConfidence()` to accept an `isHealthy` flag and adjust the label accordingly

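One possible shape for the healthy-class handling, assuming the model uses the class ordering behind the indices listed above. `isHealthyClass`, `HealthyOutcome`, and `healthyOutcome` are hypothetical names. Step 6 does not say what a medium-confidence healthy prediction should show; this sketch treats anything below "high" as uncertain, which is an assumption to revisit.

```typescript
type ConfidenceLabel = "high" | "medium" | "low";

// Healthy class indices from step 6.
const HEALTHY_CLASS_INDICES: ReadonlySet<number> = new Set([
  3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37,
]);

function isHealthyClass(classIndex: number): boolean {
  return HEALTHY_CLASS_INDICES.has(classIndex);
}

// What the UI should convey for a healthy top prediction.
type HealthyOutcome = "no-disease-detected" | "uncertain-possible-early-symptoms";

function healthyOutcome(label: ConfidenceLabel): HealthyOutcome {
  // Only a "high" label counts as a confident all-clear (assumption for "medium").
  return label === "high" ? "no-disease-detected" : "uncertain-possible-early-symptoms";
}
```
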
7. **Create calibration parameter module**:

```typescript
// src/lib/ml/calibration-params.ts
export const PLANTVILLAGE_CALIBRATION = {
  temperature: 1.5,
  confidenceHigh: 0.65,
  confidenceMedium: 0.35,
  maxEntropy: Math.log(38),
  entropyWeight: 0.3,
  probabilityWeight: 0.7,
} as const;
```

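The blend in step 4 is a weighted average only when the two weights sum to 1, and the labels only make sense when MEDIUM < HIGH, so a small check over the parameter object can catch bad edits to `calibration-params.ts`. A sketch with a hypothetical `CalibrationParams` interface and `validateCalibrationParams` helper; typing the object through an interface (rather than `as const`) lets overrides be checked by the same function.

```typescript
interface CalibrationParams {
  temperature: number;
  confidenceHigh: number;
  confidenceMedium: number;
  maxEntropy: number;
  entropyWeight: number;
  probabilityWeight: number;
}

const PLANTVILLAGE_CALIBRATION: CalibrationParams = {
  temperature: 1.5,
  confidenceHigh: 0.65,
  confidenceMedium: 0.35,
  maxEntropy: Math.log(38),
  entropyWeight: 0.3,
  probabilityWeight: 0.7,
};

// Returns a list of problems; an empty list means the parameters are consistent.
function validateCalibrationParams(p: CalibrationParams, numClasses: number): string[] {
  const problems: string[] = [];
  if (!(p.temperature > 0)) problems.push("temperature must be > 0");
  if (!(p.confidenceMedium >= 0 && p.confidenceMedium < p.confidenceHigh && p.confidenceHigh <= 1)) {
    problems.push("thresholds must satisfy 0 <= medium < high <= 1");
  }
  if (
    p.entropyWeight < 0 ||
    p.probabilityWeight < 0 ||
    Math.abs(p.entropyWeight + p.probabilityWeight - 1) > 1e-9
  ) {
    problems.push("weights must be non-negative and sum to 1");
  }
  if (Math.abs(p.maxEntropy - Math.log(numClasses)) > 1e-9) {
    problems.push("maxEntropy must equal ln(numClasses)");
  }
  return problems;
}
```
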
8. **Create calibration script** `scripts/calibrate-model.ts`:
- Load the model
- Run inference on a set of labeled validation images (from the PlantVillage validation split) and record the raw outputs (logits, or log-probabilities if the model emits softmax)
- Find the temperature that minimizes the negative log-likelihood (NLL) of the true labels, using Nelder-Mead or a grid search
- Output the optimal temperature value
- This step is optional: start with the default of 1.5 and refine later

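The temperature search can be sketched as a grid search (one of the two options listed above) over the average NLL. This assumes the validation logits and integer labels have already been collected; model loading and image decoding are left out, and `averageNll` / `findOptimalTemperature` are hypothetical names. If the model emits probabilities, their logarithms can stand in for logits, since softmax ignores a constant shift. The search range and step size are arbitrary defaults.

```typescript
// Average negative log-likelihood of the true labels under softmax(z / T).
function averageNll(logits: number[][], labels: number[], temperature: number): number {
  let total = 0;
  for (let n = 0; n < logits.length; n++) {
    const z = logits[n];
    // log-softmax via log-sum-exp, shifted by the max for stability
    let max = -Infinity;
    for (const v of z) max = Math.max(max, v / temperature);
    let sumExp = 0;
    for (const v of z) sumExp += Math.exp(v / temperature - max);
    total -= z[labels[n]] / temperature - max - Math.log(sumExp);
  }
  return total / logits.length;
}

// Evaluates T = min, min + step, ..., max and keeps the lowest NLL.
function findOptimalTemperature(
  logits: number[][],
  labels: number[],
  min = 0.5,
  max = 5.0,
  step = 0.05,
): number {
  let bestT = min;
  let bestNll = Infinity;
  const count = Math.round((max - min) / step);
  for (let k = 0; k <= count; k++) {
    const t = min + k * step; // index-based to avoid accumulated float drift
    const nll = averageNll(logits, labels, t);
    if (nll < bestNll) {
      bestNll = nll;
      bestT = t;
    }
  }
  return bestT;
}

// Toy over-confident data: the same logits [4, 0] are right 3 times out of 4.
// The exact optimum solves sigmoid(4 / T) = 0.75, i.e. T = 4 / ln(3) ≈ 3.64.
const toyLogits = [[4, 0], [4, 0], [4, 0], [4, 0]];
const toyLabels = [0, 0, 0, 1];
const bestT = findOptimalTemperature(toyLogits, toyLabels); // ≈ 3.65 on a 0.05 grid
```
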
9. **Update `InferenceResult` type** to include calibration metadata:

```typescript
export interface InferenceResult {
  predictions: RawPrediction[];
  inferenceTimeMs: number;
  calibration?: {
    temperature: number;
    entropy: number;
    entropyConfidence: number;
  };
}
```

tests:

- Unit: `temperatureScaledSoftmax` with T=1.0 equals standard softmax (within floating-point tolerance)
- Unit: `temperatureScaledSoftmax` with T=2.0 produces a more uniform distribution than T=1.0
- Unit: `computeEntropy` of a uniform 38-class distribution = `Math.log(38)` ≈ 3.64
- Unit: `computeEntropy` of a one-hot distribution = 0
- Unit: `entropyToConfidence(0, maxEntropy)` = 1.0 (maximum confidence)
- Unit: `entropyToConfidence(maxEntropy, maxEntropy)` = 0.0 (minimum confidence)
- Unit: `calibratePrediction` with high-peak input returns high confidence
- Unit: `calibratePrediction` with flat input returns low confidence
- Unit: `getConfidenceLabel(0.7)` returns "high"
- Unit: `getConfidenceLabel(0.4)` returns "medium"
- Unit: `getConfidenceLabel(0.2)` returns "low"
- Integration: calibration on a known PlantVillage test image predicts the correct class with a "high" label

acceptance_criteria:

- `calibratePrediction()` produces meaningful confidence scores that correlate with prediction quality
- Temperature scaling is implemented and configurable (default T=1.5)
- Entropy-based confidence is implemented
- Combined calibration (weighted probability + entropy) is the default
- Healthy class predictions are handled as described in step 6
- Confidence thresholds are tuned for 38-class output (HIGH ≥ 0.65, MEDIUM ≥ 0.35)
- All unit tests pass
- Calibration parameters are documented and configurable

validation:

- `npx vitest run src/lib/ml/confidence.test.ts`
- Manual: run identification on a known disease image → confidence should be "high" (≥ 0.65)
- Manual: run identification on a random/unrelated image → confidence should be "low" (< 0.35)
- Check server logs: entropy values should lie between 0 and ln(38) ≈ 3.64, with confident predictions near the low end and unrelated images near the top

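Whether confidence actually correlates with correctness can be checked on the same labeled validation images with a reliability table: bucket predictions by confidence (most naturally the temperature-scaled top probability, `raw`, since that is what temperature scaling calibrates) and compare each bucket's average confidence with its accuracy. The sketch below computes the expected calibration error (ECE), a standard one-number summary of that table; lower is better, and 0 means every bucket's accuracy matches its average confidence. `ScoredPrediction` and the ten equal-width bins are assumptions, not part of this task's spec.

```typescript
interface ScoredPrediction {
  confidence: number; // in [0, 1]
  correct: boolean; // predicted class === true class
}

// Expected calibration error over equal-width confidence bins.
function expectedCalibrationError(preds: ScoredPrediction[], numBins = 10): number {
  if (preds.length === 0) return 0;
  const count = new Array<number>(numBins).fill(0);
  const confSum = new Array<number>(numBins).fill(0);
  const correctSum = new Array<number>(numBins).fill(0);
  for (const p of preds) {
    // A confidence of exactly 1.0 goes into the last bin.
    const b = Math.min(numBins - 1, Math.max(0, Math.floor(p.confidence * numBins)));
    count[b] += 1;
    confSum[b] += p.confidence;
    if (p.correct) correctSum[b] += 1;
  }
  let ece = 0;
  for (let b = 0; b < numBins; b++) {
    if (count[b] === 0) continue;
    // Bin-size-weighted gap between accuracy and mean confidence.
    const gap = Math.abs(correctSum[b] / count[b] - confSum[b] / count[b]);
    ece += (count[b] / preds.length) * gap;
  }
  return ece;
}
```
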
notes:

- Temperature scaling is a post-hoc calibration method: it doesn't change the model, only the confidence interpretation
- The default temperature of 1.5 is a reasonable starting point for MobileNetV2 on PlantVillage. The optimal value depends on the specific training run.
- If a validation set of PlantVillage images is available, run `scripts/calibrate-model.ts` to find the optimal temperature
- The entropy-based approach works even without a validation set: it's a model-agnostic confidence measure
- For healthy predictions, consider showing a different UI (e.g., "No disease detected" with confidence) rather than treating them as disease predictions