- Next.js 16 App Router project with Tailwind CSS - Plant disease knowledge base (93 diseases, 25 plants) - Image upload with client+server preprocessing - ML inference pipeline with mock/demo fallback - Responsive results page with disease cards and treatment - Full test suite (285 passing tests)
205 lines
6.4 KiB
TypeScript
205 lines
6.4 KiB
TypeScript
/**
 * Confidence calibration and threshold logic for ML predictions.
 *
 * Provides softmax conversion, confidence calibration and labeling, top-K
 * extraction, and threshold-based filtering of predictions.
 */
|
||
|
||
import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";
|
||
|
||
// ─── Configuration ───────────────────────────────────────────────────────────
|
||
|
||
/**
 * Default probability floor for filterByConfidence(): predictions whose
 * probability is below this value are dropped.
 */
export const DEFAULT_MIN_CONFIDENCE = 0.15;

/**
 * Inclusive lower bounds used by getConfidenceLabel(): a score at or above
 * HIGH is labeled "high", at or above MEDIUM is "medium", anything else "low".
 */
const CONFIDENCE_THRESHOLDS = { HIGH: 0.8, MEDIUM: 0.5 } as const;
|
||
|
||
// ─── Softmax ─────────────────────────────────────────────────────────────────
|
||
|
||
/**
 * Apply softmax to a vector of logits, converting them to probabilities.
 *
 * Uses the numerically stable formulation: the maximum logit is subtracted
 * before exponentiation, so `Math.exp` never overflows for finite inputs and
 * the largest term is exactly exp(0) = 1 (hence the sum is always ≥ 1).
 *
 * Non-finite inputs are handled explicitly instead of collapsing to NaN:
 * - empty input → empty output;
 * - every logit is -Infinity → uniform distribution (the limit of equal logits);
 * - one or more logits are +Infinity → those entries share the probability
 *   mass equally and every other entry is 0.
 * Any NaN logit still makes every output NaN.
 *
 * @param logits - Array of raw model output values
 * @returns Array of probabilities that sum to ~1.0
 */
export function softmax(logits: number[]): number[] {
  // Manual scan instead of Math.max(...logits): spreading a very large array
  // can exceed the engine's argument-count limit and throw a RangeError.
  let maxLogit = -Infinity;
  for (const logit of logits) {
    if (Number.isNaN(logit)) {
      maxLogit = NaN; // mirror Math.max: NaN poisons the whole result
      break;
    }
    if (logit > maxLogit) maxLogit = logit;
  }

  if (maxLogit === -Infinity) {
    // Empty input or all logits -Infinity. Subtracting the max here would
    // compute exp(-Infinity - -Infinity) = exp(NaN), so return uniform instead.
    const uniform = 1 / logits.length;
    return logits.map(() => uniform);
  }

  if (maxLogit === Infinity) {
    // Infinite logits dominate every finite one; split the mass among them.
    const infiniteCount = logits.filter((l) => l === Infinity).length;
    const share = 1 / infiniteCount;
    return logits.map((l) => (l === Infinity ? share : 0));
  }

  const expValues = logits.map((l) => Math.exp(l - maxLogit));
  const sumExp = expValues.reduce((a, b) => a + b, 0);
  return expValues.map((e) => e / sumExp);
}
|
||
|
||
/**
 * Apply softmax to a Float32Array of logits.
 *
 * The maximum logit is subtracted before exponentiation, so `Math.exp`
 * never overflows for finite inputs. The exponentials are rounded to float32
 * (exactly what the output array stores) and those stored values are summed
 * before normalizing in place.
 *
 * Non-finite inputs are handled explicitly instead of collapsing to NaN:
 * - empty input → empty output;
 * - every logit is -Infinity → uniform distribution;
 * - one or more logits are +Infinity → those entries share the probability
 *   mass equally and every other entry is 0.
 * Any NaN logit still makes every output NaN.
 *
 * @param logits - Float32Array of raw model output values
 * @returns New Float32Array of probabilities that sum to ~1.0
 */
export function softmaxFloat32(logits: Float32Array): Float32Array {
  let maxLogit = -Infinity;
  for (const logit of logits) {
    if (Number.isNaN(logit)) {
      maxLogit = NaN; // NaN poisons the whole result, as with Math.max
      break;
    }
    if (logit > maxLogit) maxLogit = logit;
  }

  if (maxLogit === -Infinity) {
    // Empty input or all logits -Infinity: exp(-Infinity - -Infinity) would
    // be exp(NaN), so fall back to a uniform distribution.
    return new Float32Array(logits.length).fill(1 / logits.length);
  }

  if (maxLogit === Infinity) {
    // Infinite logits dominate every finite one; split the mass among them.
    let infiniteCount = 0;
    for (const logit of logits) {
      if (logit === Infinity) infiniteCount++;
    }
    const share = 1 / infiniteCount;
    return logits.map((logit) => (logit === Infinity ? share : 0));
  }

  const probabilities = new Float32Array(logits.length);
  // Finite case: the max term is exp(0) = 1, so sumExp >= 1 and never 0.
  let sumExp = 0;
  logits.forEach((logit, i) => {
    // Round to float32 up front so the sum matches the values actually stored.
    const e = Math.fround(Math.exp(logit - maxLogit));
    probabilities[i] = e;
    sumExp += e;
  });

  for (let i = 0; i < probabilities.length; i++) {
    probabilities[i] /= sumExp;
  }
  return probabilities;
}
|
||
|
||
// ─── Confidence Calibration ──────────────────────────────────────────────────
|
||
|
||
/**
 * Scale a raw probability by a linear calibration factor, clamp it to [0, 1],
 * and attach a confidence label.
 *
 *   adjusted = round4(clamp(rawProb * calibrationFactor, 0, 1))
 *
 * With the default factor of 1.02 every score is nudged upward by 2%, so any
 * rawProb ≥ 1/1.02 (≈ 0.98) saturates at 1. A factor below 1 shrinks scores
 * instead, which is the direction that counteracts an overconfident model.
 *
 * The label is derived from the rounded `adjusted` value, so the returned
 * score and label always agree with the thresholds: high ≥ 0.8,
 * medium ≥ 0.5, low < 0.5.
 *
 * @param rawProb - Raw softmax probability (0–1)
 * @param calibrationFactor - Multiplier applied to rawProb (default 1.02)
 * @returns { raw, adjusted, label } with raw and adjusted rounded to 4 decimals
 */
export function calibrateConfidence(
  rawProb: number,
  calibrationFactor = 1.02,
): ConfidenceResult {
  const clamped = Math.min(1, Math.max(0, rawProb * calibrationFactor));
  const adjusted = roundToDecimals(clamped, 4);

  return {
    raw: roundToDecimals(rawProb, 4),
    adjusted,
    // Label the rounded score so `label` matches what callers see in
    // `adjusted` (e.g. 0.79996 → 0.8 → "high", not "medium").
    label: getConfidenceLabel(adjusted),
  };
}
|
||
|
||
/**
 * Map a confidence score to its label using CONFIDENCE_THRESHOLDS.
 *
 * Bounds are inclusive: a score ≥ 0.8 is "high", a score ≥ 0.5 is "medium",
 * and anything else is "low". That includes NaN, because NaN fails every
 * comparison.
 *
 * @param score - Confidence score, normally in [0, 1]
 * @returns The matching confidence label
 */
export function getConfidenceLabel(score: number): ConfidenceLabel {
  const { HIGH, MEDIUM } = CONFIDENCE_THRESHOLDS;
  return score >= HIGH ? "high" : score >= MEDIUM ? "medium" : "low";
}
|
||
|
||
/**
 * Logistic sigmoid: σ(x) = 1 / (1 + e^(−x)).
 *
 * Mathematically the result lies in (0, 1). In floating point it saturates to
 * exactly 1 for large positive x and to exactly 0 for large negative x, where
 * Math.exp(-x) overflows to Infinity.
 *
 * NOTE(review): this helper is not exported and is not called anywhere in
 * this module. Either wire it into calibration or remove it; `noUnusedLocals`
 * would flag it.
 */
function sigmoid(x: number): number {
  const expNegX = Math.exp(-x);
  return 1 / (expNegX + 1);
}
|
||
|
||
/**
 * Round `value` to `decimals` digits after the decimal point.
 *
 * Scales by 10^decimals, rounds with Math.round (exact halves go toward +∞),
 * then scales back. It inherits binary floating-point representation limits.
 * For example, 1.005 is stored slightly below 1.005, so it rounds to 1 at
 * 2 decimals.
 */
function roundToDecimals(value: number, decimals: number): number {
  const scale = Math.pow(10, decimals);
  const rounded = Math.round(value * scale);
  return rounded / scale;
}
|
||
|
||
// ─── Top-K Extraction ────────────────────────────────────────────────────────
|
||
|
||
/**
 * Extract the top-K predictions from a probability array.
 *
 * Ordering is by probability, highest first:
 * - Equal probabilities keep their original class order, because
 *   Array.prototype.sort is stable.
 * - NaN probabilities rank after every real number, so a malformed entry
 *   cannot scramble the ordering (a comparator that returns NaN is
 *   inconsistent).
 * - A `k` of zero, a negative `k`, or NaN yields an empty array, rather than
 *   letting `slice(0, k)` treat a negative `k` as an offset from the end.
 *
 * The input array is not modified.
 *
 * @param probabilities - Array of probabilities (from softmax)
 * @param k - Number of top predictions to return (default 5)
 * @returns Array of { classIndex, probability } sorted by probability descending
 */
export function getTopK(
  probabilities: readonly number[],
  k = 5,
): RawPrediction[] {
  const indexed = probabilities.map((probability, classIndex) => ({
    classIndex,
    probability,
  }));

  indexed.sort((a, b) => {
    const aIsNaN = Number.isNaN(a.probability);
    const bIsNaN = Number.isNaN(b.probability);
    if (aIsNaN || bIsNaN) return Number(aIsNaN) - Number(bIsNaN); // NaN last
    if (a.probability === b.probability) return 0;
    return a.probability > b.probability ? -1 : 1;
  });

  // Clamp so a negative k cannot become "all but the last |k|" via slice().
  return indexed.slice(0, Math.max(0, k));
}
|
||
|
||
/**
 * Extract top-K predictions from a Float32Array of probabilities.
 *
 * This is a thin adapter over getTopK(). The typed array is copied into a
 * plain number array and ranked there. Both entry points therefore share one
 * implementation of ordering, tie-breaking and `k` handling instead of two
 * hand-maintained copies.
 *
 * @param probabilities - Float32Array of probabilities
 * @param k - Number of top predictions (default 5)
 * @returns Array of { classIndex, probability } sorted descending
 */
export function getTopKFloat32(
  probabilities: Float32Array,
  k = 5,
): RawPrediction[] {
  return getTopK(Array.from(probabilities), k);
}
|
||
|
||
// ─── Filtering ───────────────────────────────────────────────────────────────
|
||
|
||
/**
 * Keep only the predictions whose probability reaches a threshold.
 *
 * The comparison is inclusive (`probability >= minConfidence`). Input order
 * is preserved and a new array is returned; the input is left untouched.
 * Entries with a NaN probability never pass.
 *
 * @param predictions - Raw predictions, typically the output of getTopK()
 * @param minConfidence - Inclusive probability floor (default DEFAULT_MIN_CONFIDENCE, i.e. 0.15)
 * @returns The predictions that meet the threshold
 */
export function filterByConfidence(
  predictions: RawPrediction[],
  minConfidence = DEFAULT_MIN_CONFIDENCE,
): RawPrediction[] {
  const meetsThreshold = ({ probability }: RawPrediction): boolean =>
    probability >= minConfidence;
  return predictions.filter(meetsThreshold);
}
|