Initial commit: Plant Disease Identification app

- Next.js 16 App Router project with Tailwind CSS
- Plant disease knowledge base (93 diseases, 25 plants)
- Image upload with client+server preprocessing
- ML inference pipeline with mock/demo fallback
- Responsive results page with disease cards and treatment
- Full test suite (285 passing tests)
This commit is contained in:
2026-06-05 19:21:16 -04:00
commit 820a872f07
100 changed files with 23271 additions and 0 deletions

View File

@@ -0,0 +1,204 @@
/**
* Confidence calibration and threshold logic for ML predictions.
*
* Provides softmax conversion, confidence calibration, and threshold-based
* filtering of predictions.
*/
import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";
// ─── Configuration ───────────────────────────────────────────────────────────
/** Minimum confidence threshold — predictions below this are filtered out */
export const DEFAULT_MIN_CONFIDENCE = 0.15;
/** Confidence label thresholds */
const CONFIDENCE_THRESHOLDS = {
HIGH: 0.8,
MEDIUM: 0.5,
} as const;
// ─── Softmax ─────────────────────────────────────────────────────────────────
/**
* Apply softmax to a vector of logits, converting them to probabilities.
*
* Uses numerically stable softmax: subtracts max before exp() to avoid overflow.
*
* @param logits - Array of raw model output values
* @returns Array of probabilities that sum to ~1.0
*/
export function softmax(logits: number[]): number[] {
const maxLogit = Math.max(...logits);
const expValues = logits.map((l) => Math.exp(l - maxLogit));
const sumExp = expValues.reduce((a, b) => a + b, 0);
if (sumExp === 0) {
// Degenerate case: all logits are -Infinity
const uniform = 1 / logits.length;
return logits.map(() => uniform);
}
return expValues.map((e) => e / sumExp);
}
/**
* Apply softmax to a Float32Array of logits.
*
* @param logits - Float32Array of raw model output values
* @returns Float32Array of probabilities that sum to ~1.0
*/
export function softmaxFloat32(logits: Float32Array): Float32Array {
const maxLogit = -Infinity;
let actualMax = maxLogit;
for (let i = 0; i < logits.length; i++) {
if (logits[i] > actualMax) actualMax = logits[i];
}
const expValues = new Float32Array(logits.length);
let sumExp = 0;
for (let i = 0; i < logits.length; i++) {
expValues[i] = Math.exp(logits[i] - actualMax);
sumExp += expValues[i];
}
if (sumExp === 0) {
const uniform = 1 / logits.length;
return new Float32Array(logits.length).fill(uniform);
}
for (let i = 0; i < expValues.length; i++) {
expValues[i] /= sumExp;
}
return expValues;
}
// ─── Confidence Calibration ──────────────────────────────────────────────────
/**
* Calibrate a raw probability into an adjusted confidence score with a label.
*
* Applies a mild calibration that slightly adjusts raw softmax probabilities
* to account for model overconfidence. Uses a linear calibration:
* adjusted = rawProb * calibrationFactor
* where calibrationFactor ≈ 1.0 (default 1.02) to slightly boost
* well-separated predictions while keeping the value in [0, 1].
*
* The calibrated value is clamped to [0, 1] and labeled using thresholds:
* high ≥ 0.8
* medium ≥ 0.5
* low < 0.5
*
* @param rawProb - Raw softmax probability (01)
* @param calibrationFactor - Linear calibration factor (default 1.02)
* @returns { adjusted, label }
*/
export function calibrateConfidence(
rawProb: number,
calibrationFactor = 1.02,
): ConfidenceResult {
const adjusted = Math.min(1, Math.max(0, rawProb * calibrationFactor));
const label = getConfidenceLabel(adjusted);
return {
raw: roundToDecimals(rawProb, 4),
adjusted: roundToDecimals(adjusted, 4),
label,
};
}
/**
* Get the confidence label for a given score.
*
* Thresholds:
* high ≥ 0.8
* medium ≥ 0.5
* low < 0.5
*
* @param score - Confidence score (01)
* @returns Confidence label
*/
export function getConfidenceLabel(score: number): ConfidenceLabel {
if (score >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
if (score >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
return "low";
}
/**
* Apply sigmoid function: 1 / (1 + exp(-x))
*/
function sigmoid(x: number): number {
return 1 / (1 + Math.exp(-x));
}
/**
* Round a number to a given number of decimal places.
*/
function roundToDecimals(value: number, decimals: number): number {
const factor = Math.pow(10, decimals);
return Math.round(value * factor) / factor;
}
// ─── Top-K Extraction ────────────────────────────────────────────────────────
/**
* Extract the top-K predictions from a probability array.
*
* @param probabilities - Array of probabilities (from softmax)
* @param k - Number of top predictions to return (default 5)
* @returns Array of { classIndex, probability } sorted by probability descending
*/
export function getTopK(
probabilities: number[],
k = 5,
): RawPrediction[] {
// Create indexed pairs
const indexed = probabilities.map((prob, index) => ({
classIndex: index,
probability: prob,
}));
// Sort by probability descending
indexed.sort((a, b) => b.probability - a.probability);
// Take top K
return indexed.slice(0, k);
}
/**
* Extract top-K predictions from a Float32Array of probabilities.
*
* @param probabilities - Float32Array of probabilities
* @param k - Number of top predictions (default 5)
* @returns Array of { classIndex, probability } sorted descending
*/
export function getTopKFloat32(
probabilities: Float32Array,
k = 5,
): RawPrediction[] {
const indexed: Array<{ classIndex: number; probability: number }> = [];
for (let i = 0; i < probabilities.length; i++) {
indexed.push({ classIndex: i, probability: probabilities[i] });
}
indexed.sort((a, b) => b.probability - a.probability);
return indexed.slice(0, k);
}
// ─── Filtering ───────────────────────────────────────────────────────────────
/**
* Filter predictions by minimum confidence threshold.
*
* @param predictions - Raw predictions from getTopK()
* @param minConfidence - Minimum probability threshold (default 0.15)
* @returns Filtered predictions array
*/
export function filterByConfidence(
predictions: RawPrediction[],
minConfidence = DEFAULT_MIN_CONFIDENCE,
): RawPrediction[] {
return predictions.filter((p) => p.probability >= minConfidence);
}