Files
plant-disease-id/apps/web/src/lib/ml/model-loader.ts
2026-06-06 15:45:21 -04:00

396 lines
12 KiB
TypeScript

/**
* Singleton model loader for the plant disease classifier.
*
* Lazy-loads the TF.js or ONNX model on first call and caches it in memory
* via globalThis for subsequent requests. Supports graceful fallback to
* mock mode when no model file is present.
*
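* Model files expected at: public/models/plant-disease-classifier/model.json
* (TF.js graph model plus weight shards); the ONNX backend reads model.onnx
* from the same directory.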
*/
import fs from "fs/promises";
import fsSync from "fs";
import path from "path";
// ─── Types ───────────────────────────────────────────────────────────────────
/**
* Runtime backend behind a model instance: TensorFlow.js ("tfjs"),
* ONNX Runtime ("onnx"), or the built-in stand-in ("mock").
*/
export type ModelBackend = "tfjs" | "onnx" | "mock";
/**
* Snapshot describing which model is in use, as returned by
* PlantDiseaseModel.getStatus().
*/
export interface ModelStatus {
/** Whether a real model is loaded (the mock backend reports false) */
loaded: boolean;
/** Backend being used */
backend: ModelBackend;
/** Model identifier string (every backend in this file reports MODEL_ID) */
modelId: string;
/** Number of output classes, i.e. the expected length of ModelOutput.logits */
numClasses: number;
/** Explanation of why no real model is available; in this file only the mock backend sets it */
error?: string;
}
/** Result from running the model on input data */
export interface ModelOutput {
/** Raw logits or probabilities from the model, one entry per output class */
logits: Float32Array;
/**
* Inference time in milliseconds. The TF.js and ONNX backends measure it
* with performance.now(); the mock backend reports its simulated delay.
*/
inferenceTimeMs: number;
}
/** Model interface abstracted over TF.js / ONNX / mock */
export interface PlantDiseaseModel {
/**
* Run inference on a preprocessed image tensor.
* The TF.js and ONNX backends interpret the array as flat NCHW data of
* shape 3 x 160 x 160 (76,800 values); the mock backend only hashes the
* first values and accepts any length.
*/
predict(tensor: Float32Array): Promise<ModelOutput>;
/** Get model metadata (backend, identifier, class count, load state) */
getStatus(): ModelStatus;
}
// ─── Constants ───────────────────────────────────────────────────────────────
/**
* Directory holding the model files. Resolved against process.cwd() once,
* when this module is evaluated, so the process must be started from the
* app root for the path to point at public/models/plant-disease-classifier.
*/
const MODEL_DIR = path.join(process.cwd(), "public", "models", "plant-disease-classifier");
/** TF.js model manifest inside MODEL_DIR; probed by checkModelFiles() and loaded by the TF.js backend */
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");
/** Model identifier reported by every backend's getStatus() and used in log lines */
export const MODEL_ID = "plant-classifier-v1";
/** Maximum time (ms) getModel() waits for loading before rejecting */
const MODEL_LOAD_TIMEOUT = 30_000;
// ─── Global cache ────────────────────────────────────────────────────────────
// Cache slots stored on globalThis. `var` is required here: only `var`
// declarations inside `declare global` become typed properties of globalThis.
declare global {
// Fully loaded model (real or mock); set by getModel(), cleared by resetModelCache().
var __plantDiseaseModel__: PlantDiseaseModel | undefined;
// In-flight load, so concurrent getModel() calls share one load instead of starting another.
var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
}
// ─── Model Loader ────────────────────────────────────────────────────────────
/**
 * Get the cached model instance, loading it lazily on first call.
 *
 * The resolved model is stored on `globalThis` so later calls in the same
 * process return it immediately. While a load is in flight, concurrent
 * callers share the same promise (including its timeout) instead of
 * starting a second load.
 *
 * @returns Promise resolving to the model (real or mock)
 * @throws Error when loading does not finish within MODEL_LOAD_TIMEOUT ms;
 *   nothing is cached in that case, so the next call starts a fresh load.
 */
export async function getModel(): Promise<PlantDiseaseModel> {
  // Fast path: a previous call already finished loading.
  const cached = globalThis.__plantDiseaseModel__;
  if (cached) {
    return cached;
  }

  // A load is already running: piggy-back on it.
  const inFlight = globalThis.__plantDiseaseModelLoading__;
  if (inFlight) {
    return inFlight;
  }

  // Race the load against a timer. The handle is kept so the timer can be
  // cancelled once the race settles; otherwise it would stay pending for the
  // full timeout after every successful load and keep the event loop alive.
  let timeoutHandle: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timeoutHandle = setTimeout(
      () => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
      MODEL_LOAD_TIMEOUT,
    );
  });

  // Publish the raced promise (not the bare load) so callers that join an
  // in-flight load are bounded by the same timeout as the first caller.
  const loadingPromise = Promise.race([loadModel(), timeout]);
  globalThis.__plantDiseaseModelLoading__ = loadingPromise;

  try {
    const model = await loadingPromise;
    globalThis.__plantDiseaseModel__ = model;
    return model;
  } finally {
    clearTimeout(timeoutHandle);
    globalThis.__plantDiseaseModelLoading__ = undefined;
  }
}
/**
 * Resolve a usable model.
 *
 * Returns the mock immediately when no model files are on disk. Otherwise
 * tries the TF.js backend, then ONNX Runtime, and returns the first one that
 * yields a model. A backend that returns null or throws is skipped (throws
 * are logged); if both are skipped the mock is returned.
 */
async function loadModel(): Promise<PlantDiseaseModel> {
  if (!(await checkModelFiles())) {
    console.warn(
      `[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
        `Place TF.js model (model.json + weight shards) in public/models/plant-disease-classifier/`,
    );
    return createMockModel();
  }

  const reasonOf = (err: unknown): string => (err instanceof Error ? err.message : "unknown");

  // Backends in priority order, each with its own success / failure log line.
  const attempts: ReadonlyArray<{
    load: () => Promise<PlantDiseaseModel | null>;
    announce: () => void;
    complain: (reason: string) => void;
  }> = [
    {
      load: tryLoadTFJS,
      announce: () => console.info(`[model-loader] Loaded TF.js model: ${MODEL_ID}`),
      complain: (reason) =>
        console.warn(`[model-loader] TF.js load failed (${reason}). Trying ONNX...`),
    },
    {
      load: tryLoadONNX,
      announce: () => console.info(`[model-loader] Loaded ONNX model: ${MODEL_ID}`),
      complain: (reason) =>
        console.warn(`[model-loader] ONNX load failed (${reason}). Falling back to mock.`),
    },
  ];

  for (const attempt of attempts) {
    try {
      const candidate = await attempt.load();
      if (candidate) {
        attempt.announce();
        return candidate;
      }
    } catch (err) {
      attempt.complain(reasonOf(err));
    }
  }

  // Neither backend produced a model.
  console.warn(`[model-loader] All backends failed. Using mock model.`);
  return createMockModel();
}
/**
 * Check whether any loadable model artifact exists on disk.
 *
 * Probes both the TF.js manifest (model.json) and the ONNX file (model.onnx)
 * in MODEL_DIR. Checking only model.json would send an ONNX-only directory
 * straight to the mock model without ever attempting the ONNX backend.
 *
 * @returns true when at least one of the two files is accessible
 */
async function checkModelFiles(): Promise<boolean> {
  const candidates = [MODEL_JSON_PATH, path.join(MODEL_DIR, "model.onnx")];
  // fs.access rejects when a file is missing or unreadable; allSettled lets
  // one missing file coexist with one present file.
  const probes = await Promise.allSettled(candidates.map((file) => fs.access(file)));
  return probes.some((probe) => probe.status === "fulfilled");
}
// ─── TensorFlow.js Backend ───────────────────────────────────────────────────
/**
 * Try to load the model using TensorFlow.js.
 * Attempts @tensorflow/tfjs-node first (server), falls back to @tensorflow/tfjs.
 *
 * @returns the TF.js-backed model, or null when neither tfjs package can be imported
 * @throws whatever `tf.loadGraphModel` throws when the model files cannot be read or parsed
 */
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let tf: any;

  // Compatibility shim: @tensorflow/tfjs-node calls util.isNullOrUndefined,
  // a long-deprecated helper that newer Node.js releases no longer ship
  // (end-of-life in Node.js 23, DEP0051). Re-add it before loading the package.
  // eslint-disable-next-line @typescript-eslint/no-require-imports
  const nodeUtil = require("util");
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
      return x === null || x === undefined;
    };
  }

  // Try tfjs-node first (server-side, uses native bindings).
  // Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
  // as required dependencies — they are truly optional.
  try {
    // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
    tf = await import("@tensorflow/tfjs-node" + "");
  } catch {
    // Fall back to the pure-JS tfjs package
    try {
      // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
      tf = await import("@tensorflow/tfjs" + "");
    } catch {
      return null; // Neither tfjs package available
    }
  }

  // Load the model from file path
  const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);

  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      const startTime = performance.now();

      // Flat NCHW [3*160*160] → [3, 160, 160] → NHWC [160, 160, 3] → [1, 160, 160, 3].
      // Each op allocates a new tensor; tf.tidy disposes the two intermediates
      // and keeps only the returned one, so repeated calls do not leak memory.
      // The typed array is passed directly — no need to copy it into a JS array.
      const inputTensor = tf.tidy(() =>
        tf.tensor3d(tensor, [3, 160, 160]).transpose([1, 2, 0]).expandDims(0),
      );

      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      let rawOutput: any;
      try {
        rawOutput = await model.predict(inputTensor);
        // A graph model with several outputs returns an array; use the first.
        const outputTensor = Array.isArray(rawOutput) ? rawOutput[0] : rawOutput;
        const logits = new Float32Array(await outputTensor.data());
        return {
          logits,
          inferenceTimeMs: performance.now() - startTime,
        };
      } finally {
        // Release tensors even when predict() or data() throws.
        inputTensor.dispose();
        if (rawOutput !== undefined) {
          tf.dispose(rawOutput);
        }
      }
    },
    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "tfjs",
        modelId: MODEL_ID,
        numClasses: 38, // Original PlantVillage model
      };
    },
  };
}
// ─── ONNX Runtime Backend ────────────────────────────────────────────────────
/**
 * Try to load the model using ONNX Runtime.
 *
 * @returns the ONNX-backed model, or null when onnxruntime-node cannot be
 *   imported or MODEL_DIR contains no model.onnx
 */
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
  // Optional dependency: the concatenated specifier keeps bundlers from
  // tracing it, and a failed import simply disables this backend.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const ort: any = await import("onnxruntime-node" + "").catch(() => null);
  if (ort === null) {
    return null;
  }

  const onnxPath = path.join(MODEL_DIR, "model.onnx");
  if (!fsSync.existsSync(onnxPath)) {
    return null;
  }

  const session = await ort.InferenceSession.create(onnxPath);

  const predict = async (tensor: Float32Array): Promise<ModelOutput> => {
    const startedAt = performance.now();

    // ONNX expects NCHW format: [1, 3, 160, 160]
    const input = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
    const results = await session.run({ [session.inputNames[0]]: input });

    // The first output of the session carries the class scores.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const [firstOutput] = Object.values(results) as any[];
    const logits = new Float32Array(firstOutput.data);
    input.dispose();

    return { logits, inferenceTimeMs: performance.now() - startedAt };
  };

  const getStatus = (): ModelStatus => ({
    loaded: true,
    backend: "onnx",
    modelId: MODEL_ID,
    numClasses: 38,
  });

  return { predict, getStatus };
}
// ─── Mock Model ──────────────────────────────────────────────────────────────
/**
 * Build the stand-in model used when no real model can be loaded.
 *
 * Its logits come from generateMockLogits(), so the same input always yields
 * the same prediction; only the simulated latency is random. This lets the UI
 * work without a real model file.
 */
function createMockModel(): PlantDiseaseModel {
  const predict = async (tensor: Float32Array): Promise<ModelOutput> => {
    // Pretend inference takes 50-200ms, and report that same figure.
    const inferenceTimeMs = 50 + Math.random() * 150;
    await sleep(inferenceTimeMs);
    return { logits: generateMockLogits(tensor), inferenceTimeMs };
  };

  const getStatus = (): ModelStatus => ({
    loaded: false,
    backend: "mock",
    modelId: MODEL_ID,
    numClasses: 38,
    error: "Model files not found. Running in demo mode with mock predictions.",
  });

  return { predict, getStatus };
}
/**
 * Generate deterministic mock logits from an input tensor.
 *
 * A 32-bit hash of the first (up to) 100 tensor values seeds every score, so
 * equal inputs always produce equal outputs.
 *
 * Layout of the 38 returned logits:
 * - index 0 ("healthy"): moderate score in [0, 1.8]
 * - indices 1..36 (diseases): pseudo-random scores in [-1, 3), except one
 *   index forced to 3.5 (top prediction) and a different one forced to 2.5
 * - index 37 ("unknown"): always -2
 *
 * @param tensor input values; only the first 100 are read, any length is accepted
 * @returns a new Float32Array of length 38
 */
function generateMockLogits(tensor: Float32Array): Float32Array {
  const numClasses = 38;
  const unknownIndex = numClasses - 1;
  const diseaseCount = numClasses - 2; // disease classes occupy indices 1..diseaseCount
  const logits = new Float32Array(numClasses);

  // Simple rolling hash (hash * 31 + value), truncated to a signed 32-bit int.
  let hash = 0;
  for (const value of tensor.subarray(0, 100)) {
    hash = ((hash << 5) - hash + Math.floor(value * 1000)) | 0;
  }

  // Class 0 (healthy) gets a moderate score
  logits[0] = (Math.abs(hash % 10) / 10) * 2;

  // Spread the disease classes over [-1, 3). `%` keeps the sign of the
  // dividend, so take the absolute value: a negative hash would otherwise
  // produce negative seeds and scores as low as -5.
  for (let i = 1; i <= diseaseCount; i++) {
    const seed = Math.abs((hash * (i + 1) * 7) % 1000) / 1000;
    logits[i] = seed * 4 - 1;
  }

  // Make the top prediction more confident
  const topIndex = Math.abs(hash % diseaseCount) + 1;
  logits[topIndex] = 3.5;

  // Second highest: 1..10 slots after the top one, wrapping within the
  // disease range so it can never land on the "unknown" slot (where it would
  // be overwritten below) and never coincides with topIndex.
  const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % diseaseCount) + 1;
  logits[secondIndex] = 2.5;

  logits[unknownIndex] = -2; // "unknown" gets low score
  return logits;
}
/**
 * Return a promise that resolves (with no value) after roughly `ms` milliseconds.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((wake) => {
    setTimeout(wake, ms);
  });
}
// ─── Reset (for testing) ─────────────────────────────────────────────────────
/**
 * Forget both the loaded model and any in-flight load, so the next
 * getModel() call starts from scratch. Intended for tests.
 */
export function resetModelCache(): void {
  Object.assign(globalThis, {
    __plantDiseaseModel__: undefined,
    __plantDiseaseModelLoading__: undefined,
  });
}