/**
 * Confidence calibration and threshold logic for ML predictions.
 *
 * Provides softmax conversion, confidence calibration, and threshold-based
 * filtering of predictions.
 */

import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";

// ─── Configuration ───────────────────────────────────────────────────────────

/** Minimum confidence threshold — predictions below this are filtered out. */
export const DEFAULT_MIN_CONFIDENCE = 0.15;

/** Lower bounds (inclusive) for the "high" and "medium" confidence labels. */
const CONFIDENCE_THRESHOLDS = {
  HIGH: 0.8,
  MEDIUM: 0.5,
} as const;

// ─── Softmax ─────────────────────────────────────────────────────────────────

/**
 * Find the largest logit without spreading the input into `Math.max`, which
 * can exceed the engine's argument limit on large vectors.
 *
 * @returns The maximum value, `-Infinity` for an empty input, or `NaN` as
 *   soon as any element is `NaN`.
 */
function findMaxLogit(logits: Iterable<number>): number {
  let maxLogit = -Infinity;
  for (const logit of logits) {
    if (Number.isNaN(logit)) return NaN;
    if (logit > maxLogit) maxLogit = logit;
  }
  return maxLogit;
}

/**
 * Build the per-element probability function for inputs whose maximum is not
 * finite, where `exp(logit - max)` would evaluate to `NaN`.
 *
 * @returns A mapper from logit to probability, or `undefined` when the
 *   maximum is finite and the regular softmax formula applies.
 */
function degenerateSoftmax(
  logits: Iterable<number>,
  maxLogit: number,
): ((logit: number) => number) | undefined {
  if (Number.isFinite(maxLogit)) return undefined;

  // NaN in, NaN out: there is no meaningful distribution to return.
  if (Number.isNaN(maxLogit)) return () => NaN;

  let total = 0;
  let atMax = 0;
  for (const logit of logits) {
    total++;
    if (logit === maxLogit) atMax++;
  }

  if (maxLogit === -Infinity) {
    // Every logit is -Infinity: no class is preferred, so spread uniformly.
    const uniform = 1 / total;
    return () => uniform;
  }

  // At least one +Infinity logit: those classes share all of the mass.
  const share = 1 / atMax;
  return (logit) => (logit === Infinity ? share : 0);
}

/**
 * Apply softmax to a vector of logits, converting them to probabilities.
 *
 * Uses numerically stable softmax: subtracts the max before `exp()` to avoid
 * overflow. Non-finite inputs are handled explicitly:
 * - all logits `-Infinity` → uniform distribution
 * - one or more `+Infinity` logits → mass split evenly between them
 * - any `NaN` logit → every output is `NaN`
 *
 * @param logits - Array of raw model output values
 * @returns Array of probabilities that sum to ~1.0 (empty for empty input)
 */
export function softmax(logits: number[]): number[] {
  if (logits.length === 0) return [];

  const maxLogit = findMaxLogit(logits);
  const degenerate = degenerateSoftmax(logits, maxLogit);
  if (degenerate) return logits.map((logit) => degenerate(logit));

  const expValues = logits.map((logit) => Math.exp(logit - maxLogit));
  // The max element contributes exp(0) = 1, so the sum is always >= 1 here.
  const sumExp = expValues.reduce((acc, value) => acc + value, 0);
  return expValues.map((value) => value / sumExp);
}

/**
 * Apply softmax to a Float32Array of logits.
 *
 * Same contract as {@link softmax}, including the handling of non-finite
 * logits; intermediate exponentials are stored at float32 precision.
 *
 * @param logits - Float32Array of raw model output values
 * @returns A new Float32Array of probabilities that sum to ~1.0
 */
export function softmaxFloat32(logits: Float32Array): Float32Array {
  if (logits.length === 0) return new Float32Array(0);

  const maxLogit = findMaxLogit(logits);
  const degenerate = degenerateSoftmax(logits, maxLogit);
  if (degenerate) return logits.map((logit) => degenerate(logit));

  const expValues = logits.map((logit) => Math.exp(logit - maxLogit));
  let sumExp = 0;
  for (const value of expValues) sumExp += value;
  return expValues.map((value) => value / sumExp);
}

// ─── Confidence Calibration ──────────────────────────────────────────────────

/**
 * Calibrate a raw probability into an adjusted confidence score with a label.
 *
 * Uses a linear calibration:
 *   adjusted = rawProb * calibrationFactor
 * A factor above 1 boosts the score and a factor below 1 dampens it; the
 * default of 1.02 is a slight boost.
 *
 * The calibrated value is clamped to [0, 1], rounded to 4 decimals, and then
 * labeled, so the returned `label` always agrees with the returned `adjusted`:
 *   high   ≥ 0.8
 *   medium ≥ 0.5
 *   low    < 0.5
 *
 * @param rawProb - Raw softmax probability (0–1)
 * @param calibrationFactor - Linear calibration factor (default 1.02)
 * @returns The rounded raw value, the rounded adjusted value, and its label
 */
export function calibrateConfidence(
  rawProb: number,
  calibrationFactor = 1.02,
): ConfidenceResult {
  const clamped = Math.min(1, Math.max(0, rawProb * calibrationFactor));
  // Label the rounded value: labeling the unrounded one could return e.g.
  // adjusted = 0.8 together with label = "medium".
  const adjusted = roundToDecimals(clamped, 4);

  return {
    raw: roundToDecimals(rawProb, 4),
    adjusted,
    label: getConfidenceLabel(adjusted),
  };
}

/**
 * Get the confidence label for a given score.
 *
 * Thresholds:
 *   high   ≥ 0.8
 *   medium ≥ 0.5
 *   low    < 0.5
 *
 * @param score - Confidence score (0–1)
 * @returns Confidence label
 */
export function getConfidenceLabel(score: number): ConfidenceLabel {
  if (score >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
  if (score >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
  return "low";
}

/**
 * Apply sigmoid function: 1 / (1 + exp(-x))
 */
function sigmoid(x: number): number {
  return 1 / (1 + Math.exp(-x));
}

/**
 * Round a number to a given number of decimal places.
 */
function roundToDecimals(value: number, decimals: number): number {
  const factor = 10 ** decimals;
  return Math.round(value * factor) / factor;
}

// ─── Top-K Extraction ────────────────────────────────────────────────────────

/**
 * Pair every probability with its index, order by probability descending and
 * keep the first `k` entries. Shared by both public top-K variants.
 */
function rankTopK(probabilities: ArrayLike<number>, k: number): RawPrediction[] {
  const ranked = Array.from(probabilities, (probability, classIndex) => ({
    classIndex,
    probability,
  }));
  ranked.sort((a, b) => b.probability - a.probability);

  // Clamp at 0: a negative end index would make slice() count from the end
  // of the array instead of returning nothing.
  return ranked.slice(0, Math.max(0, k));
}

/**
 * Extract the top-K predictions from a probability array.
 *
 * @param probabilities - Array of probabilities (from softmax)
 * @param k - Number of top predictions to return (default 5); values ≤ 0
 *   yield an empty array
 * @returns Array of { classIndex, probability } sorted by probability descending
 */
export function getTopK(
  probabilities: number[],
  k = 5,
): RawPrediction[] {
  return rankTopK(probabilities, k);
}

/**
 * Extract top-K predictions from a Float32Array of probabilities.
 *
 * @param probabilities - Float32Array of probabilities
 * @param k - Number of top predictions (default 5); values ≤ 0 yield an
 *   empty array
 * @returns Array of { classIndex, probability } sorted descending
 */
export function getTopKFloat32(
  probabilities: Float32Array,
  k = 5,
): RawPrediction[] {
  return rankTopK(probabilities, k);
}

// ─── Filtering ───────────────────────────────────────────────────────────────

/**
 * Filter predictions by minimum confidence threshold.
 *
 * @param predictions - Raw predictions from getTopK()
 * @param minConfidence - Minimum probability threshold, inclusive (default 0.15)
 * @returns A new array holding only the predictions at or above the threshold
 */
export function filterByConfidence(
  predictions: RawPrediction[],
  minConfidence = DEFAULT_MIN_CONFIDENCE,
): RawPrediction[] {
  return predictions.filter((p) => p.probability >= minConfidence);
}