Files
plant-disease-id/apps/web/src/lib/ml/confidence.ts
Michael Freno 820a872f07 Initial commit: Plant Disease Identification app
- Next.js 16 App Router project with Tailwind CSS
- Plant disease knowledge base (93 diseases, 25 plants)
- Image upload with client+server preprocessing
- ML inference pipeline with mock/demo fallback
- Responsive results page with disease cards and treatment
- Full test suite (285 passing tests)
2026-06-05 19:21:16 -04:00

205 lines
6.4 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Confidence calibration and threshold logic for ML predictions.
*
* Provides softmax conversion, confidence calibration, and threshold-based
* filtering of predictions.
*/
import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";
// ─── Configuration ───────────────────────────────────────────────────────────
/**
 * Default probability floor used by `filterByConfidence`: any prediction whose
 * probability is below this value is dropped.
 */
export const DEFAULT_MIN_CONFIDENCE = 0.15;
/**
 * Inclusive lower bounds consulted by `getConfidenceLabel`:
 * a score ≥ HIGH is "high", a score ≥ MEDIUM is "medium", anything else is "low".
 */
const CONFIDENCE_THRESHOLDS = { HIGH: 0.8, MEDIUM: 0.5 } as const;
// ─── Softmax ─────────────────────────────────────────────────────────────────
/**
 * Apply softmax to a vector of logits, converting them to probabilities.
 *
 * Uses the numerically stable form: the maximum logit is subtracted before
 * exp() so large logits cannot overflow.
 *
 * Non-finite edge cases:
 * - empty input returns `[]`;
 * - if every logit is -Infinity, a uniform distribution is returned;
 * - if any logit is +Infinity, the probability mass is split evenly among the
 *   +Infinity entries and every other entry gets 0;
 * - a NaN logit anywhere makes every output NaN.
 *
 * @param logits - Array of raw model output values
 * @returns Array of probabilities that sum to ~1.0
 */
export function softmax(logits: number[]): number[] {
  // reduce() instead of Math.max(...logits): spreading a very large array
  // passes every element as a call argument and can throw a RangeError.
  // Math.max propagates NaN, so a NaN anywhere makes maxLogit NaN.
  const maxLogit = logits.reduce((acc, logit) => Math.max(acc, logit), -Infinity);

  if (maxLogit === -Infinity) {
    // Empty input, or all logits are -Infinity. Subtracting the max would
    // compute -Infinity - -Infinity = NaN, so fall back to uniform.
    const uniform = 1 / logits.length;
    return logits.map(() => uniform);
  }

  if (maxLogit === Infinity) {
    // Infinity - Infinity is NaN; use the limiting distribution instead:
    // all mass shared equally by the +Infinity entries.
    const infCount = logits.filter((l) => l === Infinity).length;
    return logits.map((l) => (l === Infinity ? 1 / infCount : 0));
  }

  const expValues = logits.map((l) => Math.exp(l - maxLogit));
  // The max entry contributes exp(0) = 1, so sumExp ≥ 1 here (or NaN if a
  // logit is NaN) and the division below is always well-defined.
  const sumExp = expValues.reduce((a, b) => a + b, 0);
  return expValues.map((e) => e / sumExp);
}
/**
 * Apply softmax to a Float32Array of logits.
 *
 * Same contract as the number[] variant: the max logit is subtracted before
 * exp() for numerical stability, all -Infinity input (or empty input) yields a
 * uniform result, +Infinity entries share all of the probability mass, and a
 * NaN logit makes every output NaN. Intermediate exp values are stored in a
 * Float32Array, so results carry float32 precision.
 *
 * @param logits - Float32Array of raw model output values
 * @returns Float32Array of probabilities that sum to ~1.0
 */
export function softmaxFloat32(logits: Float32Array): Float32Array {
  // Math.max propagates NaN, so a NaN anywhere makes maxLogit NaN.
  let maxLogit = -Infinity;
  for (const logit of logits) {
    maxLogit = Math.max(maxLogit, logit);
  }

  if (maxLogit === -Infinity) {
    // Empty input, or every logit is -Infinity: -Infinity - -Infinity would be
    // NaN, so return a uniform distribution instead.
    return new Float32Array(logits.length).fill(1 / logits.length);
  }

  if (maxLogit === Infinity) {
    // Infinity - Infinity is NaN; use the limiting distribution instead.
    let infCount = 0;
    for (const logit of logits) {
      if (logit === Infinity) infCount++;
    }
    return logits.map((l) => (l === Infinity ? 1 / infCount : 0));
  }

  // Float32Array.prototype.map returns a new Float32Array.
  const expValues = logits.map((l) => Math.exp(l - maxLogit));
  let sumExp = 0;
  for (const e of expValues) {
    sumExp += e;
  }
  // sumExp ≥ 1 here (the max entry contributes exp(0)), or NaN on NaN input.
  return expValues.map((e) => e / sumExp);
}
// ─── Confidence Calibration ──────────────────────────────────────────────────
/**
 * Calibrate a raw probability into an adjusted confidence score with a label.
 *
 * Applies a uniform linear rescale:
 *   adjusted = clamp(rawProb * calibrationFactor, 0, 1)
 * With the default factor of 1.02 every probability is raised by 2 %
 * (relative). That increases confidence; it does not correct for model
 * overconfidence. Pass a factor below 1 to shrink scores instead. The rescale
 * is the same for every prediction, regardless of how well separated it is
 * from the others.
 *
 * The adjusted value is rounded to 4 decimal places first, and the label is
 * derived from that rounded value. This keeps the returned number and label
 * consistent:
 *   high   ≥ 0.8
 *   medium ≥ 0.5
 *   low    < 0.5
 *
 * `raw` is the input rounded to 4 decimals, without clamping. A NaN input
 * yields NaN for both numbers and the "low" label.
 *
 * @param rawProb - Raw softmax probability, expected in [0, 1]
 * @param calibrationFactor - Linear calibration factor (default 1.02)
 * @returns `{ raw, adjusted, label }`
 */
export function calibrateConfidence(
  rawProb: number,
  calibrationFactor = 1.02,
): ConfidenceResult {
  // Clamp so a factor > 1 cannot push a probability past 1.
  const clamped = Math.min(1, Math.max(0, rawProb * calibrationFactor));
  // Label the rounded value. Otherwise e.g. 0.79996 would be reported as 0.8
  // but labelled "medium".
  const adjusted = roundToDecimals(clamped, 4);
  return {
    raw: roundToDecimals(rawProb, 4),
    adjusted,
    label: getConfidenceLabel(adjusted),
  };
}
/**
 * Map a confidence score to a coarse label.
 *
 * Each threshold is an inclusive lower bound:
 *   high   — score ≥ 0.8
 *   medium — 0.5 ≤ score < 0.8
 *   low    — score < 0.5, and also NaN, since NaN fails every ≥ comparison
 *
 * @param score - Confidence score, normally in [0, 1]
 * @returns Confidence label
 */
export function getConfidenceLabel(score: number): ConfidenceLabel {
  const { HIGH, MEDIUM } = CONFIDENCE_THRESHOLDS;
  return score >= HIGH ? "high" : score >= MEDIUM ? "medium" : "low";
}
/**
 * Logistic sigmoid, 1 / (1 + e^(-x)).
 *
 * For large |x|, floating point saturates the result to exactly 0 or 1.
 *
 * NOTE(review): this helper is not exported and is not called by any other
 * function in this file. It would be reported under `noUnusedLocals`; remove
 * it or wire it in.
 */
function sigmoid(x: number): number {
  const expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
}
/**
 * Round `value` to `decimals` places after the decimal point.
 *
 * The value is scaled by 10^decimals, rounded with `Math.round` (exact halves
 * go toward +Infinity), then scaled back. Results therefore carry ordinary
 * binary floating-point error.
 */
function roundToDecimals(value: number, decimals: number): number {
  const scale = 10 ** decimals;
  const scaled = Math.round(value * scale);
  return scaled / scale;
}
// ─── Top-K Extraction ────────────────────────────────────────────────────────
/**
 * Extract the top-K predictions from a probability array.
 *
 * The input array is not mutated. Ties keep ascending class-index order
 * because Array.prototype.sort is stable. NaN probabilities are ranked lowest.
 * A k of 0 or less returns an empty array.
 *
 * @param probabilities - Array of probabilities (from softmax)
 * @param k - Number of top predictions to return (default 5)
 * @returns Array of { classIndex, probability } sorted by probability descending
 */
export function getTopK(
  probabilities: number[],
  k = 5,
): RawPrediction[] {
  // slice(0, negative) counts from the end and would return nearly every
  // element, so clamp k at 0.
  const limit = Math.max(0, k);
  return probabilities
    .map((probability, classIndex) => ({ classIndex, probability }))
    .sort((a, b) => {
      // A comparator that returns NaN is treated as "equal" and makes the
      // ordering inconsistent; rank NaN as the lowest possible value instead.
      const pa = Number.isNaN(a.probability) ? -Infinity : a.probability;
      const pb = Number.isNaN(b.probability) ? -Infinity : b.probability;
      return pb - pa;
    })
    .slice(0, limit);
}
/**
 * Extract top-K predictions from a Float32Array of probabilities.
 *
 * Same contract as the number[] variant:
 * - results are sorted by probability, descending;
 * - ties keep ascending class-index order;
 * - NaN probabilities are ranked lowest;
 * - a k of 0 or less returns an empty array.
 *
 * @param probabilities - Float32Array of probabilities
 * @param k - Number of top predictions (default 5)
 * @returns Array of { classIndex, probability } sorted descending
 */
export function getTopKFloat32(
  probabilities: Float32Array,
  k = 5,
): RawPrediction[] {
  // slice(0, negative) counts from the end, so clamp k at 0.
  const limit = Math.max(0, k);
  return Array.from(probabilities, (probability, classIndex) => ({
    classIndex,
    probability,
  }))
    .sort((a, b) => {
      // Rank NaN lowest so the comparator never returns NaN.
      const pa = Number.isNaN(a.probability) ? -Infinity : a.probability;
      const pb = Number.isNaN(b.probability) ? -Infinity : b.probability;
      return pb - pa;
    })
    .slice(0, limit);
}
// ─── Filtering ───────────────────────────────────────────────────────────────
/**
 * Keep only predictions whose probability reaches the given floor.
 *
 * The comparison is inclusive (probability ≥ minConfidence). A NaN
 * probability never passes. Input order is preserved, and the input array is
 * not mutated.
 *
 * @param predictions - Raw predictions, e.g. from getTopK()
 * @param minConfidence - Minimum probability to keep (default DEFAULT_MIN_CONFIDENCE, 0.15)
 * @returns A new array containing the surviving predictions
 */
export function filterByConfidence(
  predictions: RawPrediction[],
  minConfidence = DEFAULT_MIN_CONFIDENCE,
): RawPrediction[] {
  const meetsFloor = ({ probability }: RawPrediction): boolean =>
    probability >= minConfidence;
  return predictions.filter(meetsFloor);
}