/**
 * ML inference pipeline for plant disease classification.
 *
 * Accepts a preprocessed image tensor, runs it through the model,
 * applies softmax, extracts top-K predictions, and returns results
 * with timing metadata.
 */

import type { InferenceResult, RawPrediction } from "@/lib/types";
import { getModel } from "./model-loader";
import { softmaxFloat32, getTopKFloat32 } from "./confidence";

// ─── Configuration ───────────────────────────────────────────────────────────

/** Number of top predictions to return when the caller does not specify one. */
export const DEFAULT_TOP_K = 5;

/** Input tensor shape: [batch=1, channels=3, height=160, width=160] */
export const INPUT_SHAPE: [number, number, number, number] = [1, 3, 160, 160];

/**
 * Expected input tensor length: the product of every dimension in
 * INPUT_SHAPE, batch included (1 * 3 * 160 * 160 = 76800).
 */
export const INPUT_SIZE = INPUT_SHAPE.reduce((total, dim) => total * dim, 1);

// ─── Main Inference ──────────────────────────────────────────────────────────

/**
 * Run the full inference pipeline on a preprocessed image tensor.
 *
 * Steps: validate the input, obtain the model, run the forward pass, convert
 * the logits to probabilities with softmax, and keep the top-K predictions.
 *
 * @param imageTensor - Normalized Float32Array of shape [1, 3, 160, 160] (NCHW)
 * @param topK - Number of top predictions to return (default 5)
 * @returns InferenceResult with top-K predictions and timing
 * @throws Error if the tensor fails validation
 */
export async function runInference(
  imageTensor: Float32Array,
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult> {
  // The clock starts before validation and model retrieval, so the reported
  // time covers the whole pipeline, not just the forward pass.
  const startTime = performance.now();

  // Validate input
  validateInput(imageTensor);

  // Get model (lazy loads on first call)
  const model = await getModel();

  // Run model forward pass. Only the logits are used; the time returned
  // below is the end-to-end duration measured in this function.
  const { logits } = await model.predict(imageTensor);

  // Apply softmax to convert logits to probabilities
  const probabilities = softmaxFloat32(logits);

  // Extract top-K predictions
  const predictions = getTopKFloat32(probabilities, topK);

  const totalTime = performance.now() - startTime;

  return {
    predictions,
    inferenceTimeMs: Math.round(totalTime),
  };
}

// ─── Input Validation ────────────────────────────────────────────────────────

/**
 * Validate that the input tensor has the expected type, length and contents.
 *
 * @param tensor - Input tensor to validate
 * @throws Error if tensor is not a Float32Array, if its length differs from
 *   INPUT_SIZE, or if it contains a NaN or infinite value
 */
export function validateInput(tensor: Float32Array): void {
  // The static type is erased at runtime, so check the actual value.
  if (!(tensor instanceof Float32Array)) {
    throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
  }

  if (tensor.length !== INPUT_SIZE) {
    throw new Error(
      `Expected tensor of length ${INPUT_SIZE} (shape ${INPUT_SHAPE.join("×")}), ` +
        `got ${tensor.length}`,
    );
  }

  // Check for NaN/Infinity values; the first offending index is reported.
  for (let i = 0; i < tensor.length; i++) {
    if (!Number.isFinite(tensor[i])) {
      throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
    }
  }
}

// ─── Batch Inference ─────────────────────────────────────────────────────────

/**
 * Run inference on multiple images.
 *
 * Currently runs sequentially. For true batching, the model itself would need
 * to support batch input.
 *
 * If inference on any image rejects, the rejection propagates to the caller
 * and the results collected so far are not returned.
 *
 * @param tensors - Array of preprocessed image tensors
 * @param topK - Number of top predictions per image
 * @returns Array of inference results, in the same order as `tensors`
 */
export async function runBatchInference(
  tensors: Float32Array[],
  topK = DEFAULT_TOP_K,
): Promise<InferenceResult[]> {
  const results: InferenceResult[] = [];

  // Each image is awaited before the next one starts.
  for (const tensor of tensors) {
    results.push(await runInference(tensor, topK));
  }

  return results;
}

// ─── Utility ─────────────────────────────────────────────────────────────────

/**
 * Create a zero-filled input tensor for testing.
 *
 * @returns Float32Array of length INPUT_SIZE (shape [1, 3, 160, 160]), all zeros
 */
export function createZeroTensor(): Float32Array {
  // A freshly constructed Float32Array is zero-initialized.
  return new Float32Array(INPUT_SIZE);
}

/**
 * Create a random input tensor for testing.
 *
 * @returns Float32Array of length INPUT_SIZE (shape [1, 3, 160, 160]) with
 *   random values
 */
export function createRandomTensor(): Float32Array {
  const tensor = new Float32Array(INPUT_SIZE);
  for (let i = 0; i < tensor.length; i++) {
    tensor[i] = (Math.random() * 2 - 1) * 2; // Range roughly -2 to 2
  }
  return tensor;
}