/**
 * ML inference pipeline for plant disease classification.
 *
 * Accepts a preprocessed image tensor, runs it through the model,
 * applies softmax, extracts top-K predictions, and returns results
 * with timing metadata.
 */

import type { InferenceResult, RawPrediction } from "@/lib/types";
import { getModel } from "./model-loader";
import { softmaxFloat32, getTopKFloat32 } from "./confidence";

// ─── Configuration ───────────────────────────────────────────────────────────

/** Number of top predictions to return */
export const DEFAULT_TOP_K = 5;

/** Input tensor shape (NCHW): [batch=1, channels=3, height=160, width=160] */
export const INPUT_SHAPE: [number, number, number, number] = [1, 3, 160, 160];

/** Expected input tensor length: total element count of INPUT_SHAPE (1 * 3 * 160 * 160 = 76800) */
export const INPUT_SIZE =
  INPUT_SHAPE[0] * INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3];

// ─── Main Inference ──────────────────────────────────────────────────────────

/**
 * Run the full inference pipeline on a preprocessed image tensor.
 *
 * Steps: validate input → obtain model → forward pass → softmax → top-K.
 *
 * @param imageTensor - Normalized Float32Array of shape INPUT_SHAPE (NCHW)
 * @param topK - Number of top predictions to return (default DEFAULT_TOP_K)
 * @returns InferenceResult with the top-K predictions and the end-to-end
 *   pipeline latency in whole milliseconds
 * @throws Error if `imageTensor` fails {@link validateInput}
 */
export async function runInference(
  imageTensor: Float32Array,
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult> {
  const startTime = performance.now();

  // Fail fast before touching the model.
  validateInput(imageTensor);

  // Get model (lazy loads on first call)
  const model = await getModel();

  // model.predict also returns its own `inferenceTimeMs`, but the value
  // reported below is intentionally end-to-end: validation, model
  // acquisition, forward pass, softmax and top-K selection.
  const { logits } = await model.predict(imageTensor);

  // Convert raw logits into a probability distribution.
  const probabilities = softmaxFloat32(logits);

  // Keep only the K most likely classes.
  const predictions = getTopKFloat32(probabilities, topK);

  return {
    predictions,
    inferenceTimeMs: Math.round(performance.now() - startTime),
  };
}

// ─── Input Validation ────────────────────────────────────────────────────────

/**
 * Describe a runtime value's type for error messages.
 * Objects report their constructor name (e.g. "Float64Array", "Array");
 * primitives report their `typeof`.
 */
function describeType(value: unknown): string {
  if (value === null) return "null";
  if (typeof value !== "object") return typeof value;
  const ctorName: unknown = Object.getPrototypeOf(value)?.constructor?.name;
  return typeof ctorName === "string" && ctorName !== "" ? ctorName : "object";
}

/**
 * Validate that the input tensor has the expected type, length and contents.
 *
 * @param tensor - Input tensor to validate
 * @throws Error if `tensor` is not a Float32Array, its length is not
 *   INPUT_SIZE, or it contains NaN/±Infinity
 */
export function validateInput(tensor: Float32Array): void {
  // Runtime check: the static type is erased, so JS callers can pass anything.
  if (!(tensor instanceof Float32Array)) {
    throw new Error(`Expected Float32Array input, got ${describeType(tensor)}`);
  }

  if (tensor.length !== INPUT_SIZE) {
    throw new Error(
      `Expected tensor of length ${INPUT_SIZE} (shape ${INPUT_SHAPE.join("×")}), ` +
        `got ${tensor.length}`,
    );
  }

  // Check for NaN/Infinity values; these would poison softmax.
  for (let i = 0; i < tensor.length; i++) {
    if (!Number.isFinite(tensor[i])) {
      throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
    }
  }
}

// ─── Batch Inference ─────────────────────────────────────────────────────────

/**
 * Run inference on multiple images.
 *
 * Runs sequentially on purpose: each call goes through the single shared
 * model, and true batching would require the model itself to accept batch
 * input.
 *
 * @param tensors - Array of preprocessed image tensors
 * @param topK - Number of top predictions per image
 * @returns Array of inference results, in the same order as `tensors`
 * @throws Error from the first tensor that fails validation or inference
 */
export async function runBatchInference(
  tensors: Float32Array[],
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult[]> {
  const results: InferenceResult[] = [];
  for (const tensor of tensors) {
    results.push(await runInference(tensor, topK));
  }
  return results;
}

// ─── Utility ─────────────────────────────────────────────────────────────────

/**
 * Create a zero-filled input tensor for testing.
 *
 * @returns Float32Array of length INPUT_SIZE, i.e. shape [1, 3, 160, 160]
 */
export function createZeroTensor(): Float32Array {
  return new Float32Array(INPUT_SIZE);
}

/**
 * Create a random input tensor for testing.
 *
 * @returns Float32Array of length INPUT_SIZE, i.e. shape [1, 3, 160, 160],
 *   with values uniformly distributed in [-2, 2)
 */
export function createRandomTensor(): Float32Array {
  const tensor = new Float32Array(INPUT_SIZE);
  for (let i = 0; i < tensor.length; i++) {
    tensor[i] = (Math.random() * 2 - 1) * 2; // Range roughly -2 to 2
  }
  return tensor;
}