Files
plant-disease-id/apps/web/src/lib/ml/inference.ts
2026-06-06 15:45:21 -04:00

138 lines
4.6 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* ML inference pipeline for plant disease classification.
*
* Accepts a preprocessed image tensor, runs it through the model,
* applies softmax, extracts top-K predictions, and returns results
* with timing metadata.
*/
import type { InferenceResult, RawPrediction } from "@/lib/types";
import { getModel } from "./model-loader";
import { softmaxFloat32, getTopKFloat32 } from "./confidence";
// ─── Configuration ───────────────────────────────────────────────────────────
/** How many of the highest-ranked predictions are returned when the caller does not say otherwise. */
export const DEFAULT_TOP_K = 5;

/** Model input layout in NCHW order: batch of 1, 3 channels, 160 rows, 160 columns. */
export const INPUT_SHAPE: [number, number, number, number] = [1, 3, 160, 160];

/**
 * Number of floats in one input image: the product of every dimension after
 * the batch axis (3 * 160 * 160 = 76800).
 */
export const INPUT_SIZE = INPUT_SHAPE.slice(1).reduce((product, dim) => product * dim, 1);
// ─── Main Inference ──────────────────────────────────────────────────────────
/**
 * Run the full inference pipeline on a preprocessed image tensor.
 *
 * Steps, in order: validate the tensor, obtain the model, run the forward
 * pass, turn the logits into probabilities with softmax, and keep the `topK`
 * highest-probability entries.
 *
 * @param imageTensor - Normalized Float32Array of shape [1, 3, 160, 160] (NCHW)
 * @param topK - Number of top predictions to return (default 5)
 * @returns InferenceResult holding the top-K predictions and `inferenceTimeMs`,
 *   the wall-clock duration of this whole call rounded to the nearest
 *   millisecond. The clock starts before validation and model retrieval, so
 *   the figure covers more than the forward pass alone.
 * @throws Error (surfaced as a rejected promise) when `validateInput` rejects
 *   the tensor
 */
export async function runInference(
  imageTensor: Float32Array,
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult> {
  const startTime = performance.now();

  // Fail fast on malformed input before the model is touched.
  validateInput(imageTensor);

  // Get model (lazy loads on first call)
  const model = await getModel();

  // Run model forward pass. predict() also reports its own timing, but only
  // the logits are consumed here; the timing returned below is end-to-end.
  const { logits } = await model.predict(imageTensor);

  // Apply softmax to convert logits to probabilities
  const probabilities = softmaxFloat32(logits);

  // Extract top-K predictions
  const predictions = getTopKFloat32(probabilities, topK);

  const totalTime = performance.now() - startTime;
  return {
    predictions,
    inferenceTimeMs: Math.round(totalTime),
  };
}
// ─── Input Validation ────────────────────────────────────────────────────────
/**
 * Reject any tensor the pipeline cannot consume.
 *
 * Three checks run in order, and the first failure throws:
 *   1. the value is a Float32Array,
 *   2. its length equals INPUT_SIZE,
 *   3. every element is finite (no NaN, Infinity or -Infinity).
 *
 * @param tensor - Input tensor to validate
 * @throws Error describing the first check that failed
 */
export function validateInput(tensor: Float32Array): void {
  // Runtime guard: callers outside the type checker may pass anything.
  if (!(tensor instanceof Float32Array)) {
    throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
  }

  const actualLength = tensor.length;
  if (actualLength !== INPUT_SIZE) {
    const shapeLabel = INPUT_SHAPE.join("×");
    throw new Error(
      `Expected tensor of length ${INPUT_SIZE} (shape ${shapeLabel}), got ${actualLength}`,
    );
  }

  // Locate the first NaN/Infinity entry, if any, so the error can point at it.
  const badIndex = tensor.findIndex((value) => !Number.isFinite(value));
  if (badIndex !== -1) {
    throw new Error(`Tensor contains non-finite value at index ${badIndex}: ${tensor[badIndex]}`);
  }
}
// ─── Batch Inference ─────────────────────────────────────────────────────────
/**
 * Classify several images, one after another.
 *
 * Each tensor is passed to `runInference` and awaited before the next one
 * starts, so results come back in input order. The model is not given a
 * batched input; true batching would require model-side support.
 *
 * @param tensors - Array of preprocessed image tensors
 * @param topK - Number of top predictions per image
 * @returns One InferenceResult per input tensor, in the same order
 */
export async function runBatchInference(
  tensors: Float32Array[],
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult[]> {
  const collected: InferenceResult[] = [];

  for (const imageTensor of tensors) {
    // Await inside the loop on purpose: images are processed sequentially.
    const outcome = await runInference(imageTensor, topK);
    collected.push(outcome);
  }

  return collected;
}
// ─── Utility ─────────────────────────────────────────────────────────────────
/**
 * Create a zero-filled input tensor for testing.
 *
 * The length is taken from INPUT_SIZE, so the tensor always matches the
 * model's input shape (INPUT_SHAPE, currently [1, 3, 160, 160]).
 *
 * @returns Float32Array of length INPUT_SIZE with every element set to 0
 */
export function createZeroTensor(): Float32Array {
  // A freshly constructed typed array is zero-initialized.
  return new Float32Array(INPUT_SIZE);
}
/**
 * Create a random input tensor for testing.
 *
 * The length is taken from INPUT_SIZE, so the tensor always matches the
 * model's input shape (INPUT_SHAPE, currently [1, 3, 160, 160]). Values come
 * from Math.random(), so the output is not reproducible between calls.
 *
 * @returns Float32Array of length INPUT_SIZE with values spread roughly
 *   uniformly between -2 and 2
 */
export function createRandomTensor(): Float32Array {
  const tensor = new Float32Array(INPUT_SIZE);
  for (let i = 0; i < tensor.length; i++) {
    // Map [0, 1) to [-1, 1), then scale by 2 to cover roughly -2 to 2.
    tensor[i] = (Math.random() * 2 - 1) * 2;
  }
  return tensor;
}