/**
 * Singleton model loader for the plant disease classifier.
 *
 * Lazy-loads the TF.js or ONNX model on first call and caches it in memory
 * via globalThis for subsequent requests. Falls back to a deterministic mock
 * model when no model file is present or no backend can load it.
 *
 * Model files expected in public/models/plant-disease-classifier/:
 *   - model.json (+ weight shards) for TF.js, and/or
 *   - model.onnx for ONNX Runtime.
 */

import fs from "fs/promises";
import fsSync from "fs";
import path from "path";

// ─── Types ───────────────────────────────────────────────────────────────────

/** Model runtime backend */
export type ModelBackend = "tfjs" | "onnx" | "mock";

/** Model loading status */
export interface ModelStatus {
  /** Whether a real model is loaded */
  loaded: boolean;
  /** Backend being used */
  backend: ModelBackend;
  /** Model identifier string */
  modelId: string;
  /** Number of output classes */
  numClasses: number;
  /** Error message if loading failed */
  error?: string;
}

/** Result from running the model on input data */
export interface ModelOutput {
  /** Raw logits or probabilities from the model */
  logits: Float32Array;
  /** Inference time in milliseconds */
  inferenceTimeMs: number;
}

/** Model interface abstracted over TF.js / ONNX / mock */
export interface PlantDiseaseModel {
  /**
   * Run inference on a preprocessed image tensor.
   * The real backends expect a flat NCHW array of 3 × 160 × 160 values.
   */
  predict(tensor: Float32Array): Promise<ModelOutput>;
  /** Get model metadata */
  getStatus(): ModelStatus;
}

// ─── Constants ───────────────────────────────────────────────────────────────

/** Directory holding the model files, resolved against the process working directory. */
const MODEL_DIR = path.join(process.cwd(), "public", "models", "plant-disease-classifier");
/** TF.js graph-model manifest. */
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");
/** ONNX model file. */
const MODEL_ONNX_PATH = path.join(MODEL_DIR, "model.onnx");

/** Model identifier */
export const MODEL_ID = "plant-classifier-v1";

/** Number of output classes (original PlantVillage model). */
const NUM_CLASSES = 38;

/** Channels and spatial size of the preprocessed input expected by the real backends. */
const INPUT_CHANNELS = 3;
const INPUT_SIZE = 160;
/** Flat length of one preprocessed input (C × H × W). */
const INPUT_LENGTH = INPUT_CHANNELS * INPUT_SIZE * INPUT_SIZE;

/** Maximum model load time (ms) */
const MODEL_LOAD_TIMEOUT = 30_000;

/** Status messages reported by the mock model, depending on why it is in use. */
const MOCK_REASON_NO_FILES = "Model files not found. Running in demo mode with mock predictions.";
const MOCK_REASON_LOAD_FAILED =
  "Model files found but no backend could load them. Running in demo mode with mock predictions.";

// ─── Global cache ────────────────────────────────────────────────────────────

declare global {
  var __plantDiseaseModel__: PlantDiseaseModel | undefined;
  var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
}

// ─── Model Loader ────────────────────────────────────────────────────────────

/**
 * Get the cached model instance, loading it lazily on first call.
 * Uses globalThis to persist across serverless invocations (within same container).
 *
 * Concurrent callers share a single in-flight load (including its timeout).
 * If loading times out, the in-flight slot is cleared so the next call retries;
 * the timed-out load keeps running in the background and its result is discarded.
 *
 * @returns Promise resolving to the model (real or mock)
 * @throws Error if loading takes longer than MODEL_LOAD_TIMEOUT milliseconds
 */
export async function getModel(): Promise<PlantDiseaseModel> {
  const cached = globalThis.__plantDiseaseModel__;
  if (cached) {
    return cached;
  }

  // Another caller is already loading: share its promise (and its timeout).
  const inFlight = globalThis.__plantDiseaseModelLoading__;
  if (inFlight) {
    return inFlight;
  }

  const loading = loadAndCacheModel();
  globalThis.__plantDiseaseModelLoading__ = loading;
  return loading;
}

/** Run loadModel() under the load timeout, cache the result, and clear the in-flight slot. */
async function loadAndCacheModel(): Promise<PlantDiseaseModel> {
  try {
    const model = await withTimeout(
      loadModel(),
      MODEL_LOAD_TIMEOUT,
      `Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`,
    );
    globalThis.__plantDiseaseModel__ = model;
    return model;
  } finally {
    globalThis.__plantDiseaseModelLoading__ = undefined;
  }
}

/**
 * Race `promise` against a timer; the timer is always cleared once the race settles,
 * so a successful load does not leave a pending timeout keeping the event loop alive.
 *
 * @throws Error with `message` if `promise` does not settle within `ms` milliseconds
 */
function withTimeout<T>(promise: Promise<T>, ms: number, message: string): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(() => reject(new Error(message)), ms);
  });
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
}

/**
 * Load the model, attempting TF.js first, then ONNX, then falling back to mock.
 * Never rejects: backend failures are logged and result in the mock model.
 */
async function loadModel(): Promise<PlantDiseaseModel> {
  const modelExists = await checkModelFiles();
  if (!modelExists) {
    console.warn(
      `[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
        `Place a TF.js model (model.json + weight shards) or an ONNX model (model.onnx) ` +
        `in public/models/plant-disease-classifier/`,
    );
    return createMockModel(MOCK_REASON_NO_FILES);
  }

  // Try TF.js first (returns null when model.json or the tfjs packages are absent).
  try {
    const tfModel = await tryLoadTFJS();
    if (tfModel) {
      console.info(`[model-loader] Loaded TF.js model: ${MODEL_ID}`);
      return tfModel;
    }
  } catch (err: unknown) {
    console.warn(
      `[model-loader] TF.js load failed (${err instanceof Error ? err.message : "unknown"}). Trying ONNX...`,
    );
  }

  // Try ONNX Runtime (returns null when model.onnx or onnxruntime-node is absent).
  try {
    const onnxModel = await tryLoadONNX();
    if (onnxModel) {
      console.info(`[model-loader] Loaded ONNX model: ${MODEL_ID}`);
      return onnxModel;
    }
  } catch (err: unknown) {
    console.warn(
      `[model-loader] ONNX load failed (${err instanceof Error ? err.message : "unknown"}). Falling back to mock.`,
    );
  }

  console.warn(`[model-loader] All backends failed. Using mock model.`);
  return createMockModel(MOCK_REASON_LOAD_FAILED);
}

/**
 * Check whether any supported model file (TF.js manifest or ONNX model) exists on disk.
 */
async function checkModelFiles(): Promise<boolean> {
  const found = await Promise.all([MODEL_JSON_PATH, MODEL_ONNX_PATH].map(fileExists));
  return found.some(Boolean);
}

/** Resolve true if `filePath` is accessible, false on any fs error. */
async function fileExists(filePath: string): Promise<boolean> {
  try {
    await fs.access(filePath);
    return true;
  } catch {
    return false;
  }
}

/**
 * Ensure a preprocessed input has exactly C × H × W values before handing it to a real backend.
 *
 * @throws Error describing the expected and actual length
 */
function assertInputLength(tensor: Float32Array): void {
  if (tensor.length !== INPUT_LENGTH) {
    throw new Error(
      `[model-loader] Expected input of length ${INPUT_LENGTH} ` +
        `(${INPUT_CHANNELS}x${INPUT_SIZE}x${INPUT_SIZE}), got ${tensor.length}`,
    );
  }
}

// ─── TensorFlow.js Backend ───────────────────────────────────────────────────

/**
 * Try to load the model using TensorFlow.js.
 * Attempts @tensorflow/tfjs-node first (server), falls back to @tensorflow/tfjs.
 *
 * @returns the wrapped model, or null when model.json or both tfjs packages are missing
 */
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
  if (!fsSync.existsSync(MODEL_JSON_PATH)) {
    return null; // No TF.js manifest; let the ONNX backend try instead.
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let tf: any;

  // Compatibility shim: @tensorflow/tfjs-node references util.isNullOrUndefined,
  // a long-deprecated helper that newer Node.js releases no longer provide.
  // eslint-disable-next-line @typescript-eslint/no-require-imports
  const nodeUtil = require("util");
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
      return x === null || x === undefined;
    };
  }

  // Try tfjs-node first (server-side, uses native bindings).
  // Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
  // as required dependencies — they are truly optional.
  try {
    // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
    tf = await import("@tensorflow/tfjs-node" + "");
  } catch {
    // NOTE(review): the pure-JS build may lack a file:// IO handler, in which case
    // loadGraphModel below throws; loadModel catches that and moves on to ONNX.
    try {
      // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
      tf = await import("@tensorflow/tfjs" + "");
    } catch {
      return null; // Neither tfjs package available
    }
  }

  // Load the model from file path
  const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);

  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      assertInputLength(tensor);
      const startTime = performance.now();

      // Flat NCHW [3*160*160] → [3, 160, 160] → NHWC [160, 160, 3] → batch [1, 160, 160, 3].
      // tf.tidy disposes the intermediate tensors; only the returned batch tensor survives.
      const inputTensor = tf.tidy(() =>
        tf
          .tensor3d(tensor, [INPUT_CHANNELS, INPUT_SIZE, INPUT_SIZE])
          .transpose([1, 2, 0])
          .expandDims(0),
      );

      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      let outputTensor: any;
      try {
        outputTensor = await model.predict(inputTensor);
        const logits = new Float32Array(await outputTensor.data());
        return {
          logits,
          inferenceTimeMs: performance.now() - startTime,
        };
      } finally {
        // Release tensor memory even when inference fails.
        inputTensor.dispose();
        // eslint-disable-next-line @typescript-eslint/no-unsafe-call
        outputTensor?.dispose();
      }
    },

    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "tfjs",
        modelId: MODEL_ID,
        numClasses: NUM_CLASSES,
      };
    },
  };
}

// ─── ONNX Runtime Backend ────────────────────────────────────────────────────

/**
 * Try to load the model using ONNX Runtime.
 *
 * @returns the wrapped model, or null when model.onnx or onnxruntime-node is missing
 */
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
  if (!fsSync.existsSync(MODEL_ONNX_PATH)) {
    return null;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let ort: any;
  try {
    // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
    ort = await import("onnxruntime-node" + "");
  } catch {
    return null;
  }

  const session = await ort.InferenceSession.create(MODEL_ONNX_PATH);

  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      assertInputLength(tensor);
      const startTime = performance.now();

      // ONNX expects NCHW format: [1, 3, 160, 160]
      const inputTensor = new ort.Tensor("float32", tensor, [1, INPUT_CHANNELS, INPUT_SIZE, INPUT_SIZE]);
      try {
        const results = await session.run({ [session.inputNames[0]]: inputTensor });
        // Read the first declared output by name instead of relying on object key order.
        const logits = new Float32Array(results[session.outputNames[0]].data);
        return {
          logits,
          inferenceTimeMs: performance.now() - startTime,
        };
      } finally {
        // NOTE(review): Tensor.dispose() may be absent in older onnxruntime-node releases.
        inputTensor.dispose?.();
      }
    },

    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "onnx",
        modelId: MODEL_ID,
        numClasses: NUM_CLASSES,
      };
    },
  };
}

// ─── Mock Model ──────────────────────────────────────────────────────────────

/**
 * Create a deterministic mock model for development/demo mode.
 *
 * Generates reproducible predictions based on input tensor hash.
 * This allows the UI to work without a real model file.
 *
 * @param reason message reported as `error` by getStatus(); defaults to the
 *   "files not found" message
 */
function createMockModel(reason: string = MOCK_REASON_NO_FILES): PlantDiseaseModel {
  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      // Simulate inference time (50-200ms)
      const simulatedTime = 50 + Math.random() * 150;
      await sleep(simulatedTime);

      return {
        logits: generateMockLogits(tensor),
        inferenceTimeMs: simulatedTime,
      };
    },

    getStatus(): ModelStatus {
      return {
        loaded: false,
        backend: "mock",
        modelId: MODEL_ID,
        numClasses: NUM_CLASSES,
        error: reason,
      };
    },
  };
}

/**
 * Generate deterministic mock logits from input tensor.
 * Uses a simple hash of the first 100 tensor values to create
 * reproducible but varied predictions.
 *
 * Layout: index 0 = "healthy", 1..NUM_CLASSES-2 = disease classes,
 * NUM_CLASSES-1 = "unknown". One disease class gets 3.5 (top), a different
 * disease class gets 2.5 (second), the rest lie in [-1, 3).
 */
function generateMockLogits(tensor: Float32Array): Float32Array {
  const numClasses = NUM_CLASSES;
  const logits = new Float32Array(numClasses);
  const unknownIndex = numClasses - 1;
  const numDiseaseClasses = numClasses - 2;

  // Simple hash of input for deterministic output
  let hash = 0;
  const sampleSize = Math.min(100, tensor.length);
  for (let i = 0; i < sampleSize; i++) {
    hash = ((hash << 5) - hash + Math.floor(tensor[i] * 1000)) | 0;
  }

  // Class 0 (healthy) gets a moderate score
  logits[0] = (Math.abs(hash % 10) / 10) * 2;

  // Disease classes: Math.abs keeps the seed in [0, 1) even for a negative hash,
  // so every logit stays in [-1, 3).
  for (let i = 1; i < unknownIndex; i++) {
    const seed = Math.abs((hash * (i + 1) * 7) % 1000) / 1000;
    logits[i] = seed * 4 - 1;
  }

  // Make the top prediction more confident
  const topIndex = Math.abs(hash % numDiseaseClasses) + 1;
  logits[topIndex] = 3.5;

  // Second highest: offset 2..11 from the top, wrapped within the disease range, so it
  // never coincides with the top class and never lands on "unknown" (overwritten below).
  const offset = Math.abs(hash % 10) + 2;
  const secondIndex = ((topIndex - 1 + offset) % numDiseaseClasses) + 1;
  logits[secondIndex] = 2.5;

  logits[unknownIndex] = -2; // "unknown" gets low score

  return logits;
}

/**
 * Sleep for a given number of milliseconds.
 */
function sleep(ms: number): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

// ─── Reset (for testing) ─────────────────────────────────────────────────────

/**
 * Reset the model cache. Useful for testing.
 */
export function resetModelCache(): void {
  globalThis.__plantDiseaseModel__ = undefined;
  globalThis.__plantDiseaseModelLoading__ = undefined;
}