396 lines
12 KiB
TypeScript
396 lines
12 KiB
TypeScript
/**
 * Singleton model loader for the plant disease classifier.
 *
 * Lazy-loads the TF.js or ONNX model on first call and caches it in memory
 * via globalThis for subsequent requests. Falls back gracefully to a mock
 * model when no usable model file is present.
 *
 * Model files are looked up in public/models/plant-disease-classifier/:
 *   - TF.js: model.json (+ weight shards)
 *   - ONNX:  model.onnx
 */
|
|
|
|
import fs from "fs/promises";
|
|
import fsSync from "fs";
|
|
import path from "path";
|
|
|
|
// ─── Types ───────────────────────────────────────────────────────────────────
|
|
|
|
/**
 * Model runtime backend.
 * "mock" means no real model is loaded and predictions are synthetic.
 */
export type ModelBackend = "tfjs" | "onnx" | "mock";
|
|
|
|
/** Model loading status, as reported by `PlantDiseaseModel.getStatus()`. */
export interface ModelStatus {
  /** Whether a real model is loaded (`false` for the mock backend). */
  loaded: boolean;
  /** Backend being used. */
  backend: ModelBackend;
  /** Model identifier string (every backend in this module reports MODEL_ID). */
  modelId: string;
  /** Number of output classes (hard-coded to 38 by every backend in this module). */
  numClasses: number;
  /** Explanation when no real model is in use; the real backends leave this unset. */
  error?: string;
}
|
|
|
|
/** Result from running the model on input data. */
export interface ModelOutput {
  /**
   * Raw values from the model's first output. Whether these are logits or
   * probabilities depends on the exported model; the real backends copy the
   * output tensor verbatim.
   */
  logits: Float32Array;
  /** Inference time in milliseconds (simulated rather than measured for the mock backend). */
  inferenceTimeMs: number;
}
|
|
|
|
/** Model interface abstracted over TF.js / ONNX / mock. */
export interface PlantDiseaseModel {
  /**
   * Run inference on a preprocessed image tensor.
   * The TF.js and ONNX backends treat `tensor` as a flat NCHW buffer for shape
   * [1, 3, 160, 160] (3 * 160 * 160 values); the mock only hashes its leading values.
   */
  predict(tensor: Float32Array): Promise<ModelOutput>;
  /** Report which backend is active plus basic model metadata. */
  getStatus(): ModelStatus;
}
|
|
|
|
// ─── Constants ───────────────────────────────────────────────────────────────
|
|
|
|
/** Model identifier reported by every backend's getStatus(). */
export const MODEL_ID = "plant-classifier-v1";

/**
 * Directory holding the model artifacts, resolved against the process working
 * directory (the project root when the server is started from there).
 */
const MODEL_DIR = path.join(
  process.cwd(),
  "public",
  "models",
  "plant-disease-classifier",
);

/** TF.js graph-model manifest inside MODEL_DIR. */
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");

/** Upper bound on a single model load before getModel() rejects, in milliseconds. */
const MODEL_LOAD_TIMEOUT = 30 * 1000;
|
|
|
|
// ─── Global cache ────────────────────────────────────────────────────────────
|
|
|
|
// Process-wide cache slots. Kept on globalThis (rather than module-level
// variables) so they persist across serverless invocations within the same
// container, per getModel's contract; presumably also so they survive module
// re-evaluation — confirm if that matters for your bundler.
declare global {
  // The loaded model (real or mock); set once loading succeeds, cleared by resetModelCache().
  var __plantDiseaseModel__: PlantDiseaseModel | undefined;
  // In-flight load that concurrent getModel() callers wait on; cleared when it settles.
  var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
}
|
|
|
|
// ─── Model Loader ────────────────────────────────────────────────────────────
|
|
|
|
/**
 * Get the cached model instance, loading it lazily on first call.
 * Uses globalThis to persist across serverless invocations (within same container).
 *
 * Concurrent callers share one in-flight load, including its timeout, so every
 * caller either receives the same model or the same timeout error.
 *
 * @returns Promise resolving to the model (real or mock)
 * @throws Error when loading exceeds MODEL_LOAD_TIMEOUT ms. The underlying load
 *   is not cancelled; its late result is discarded and the next call starts over.
 */
export async function getModel(): Promise<PlantDiseaseModel> {
  // Fast path: model already loaded.
  const cached = globalThis.__plantDiseaseModel__;
  if (cached) {
    return cached;
  }

  // Another caller already started a load; share it (and its timeout).
  const inFlight = globalThis.__plantDiseaseModelLoading__;
  if (inFlight) {
    return inFlight;
  }

  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(
      () => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
      MODEL_LOAD_TIMEOUT,
    );
  });

  const loading: Promise<PlantDiseaseModel> = Promise.race([loadModel(), timeout])
    .then((model) => {
      globalThis.__plantDiseaseModel__ = model;
      return model;
    })
    .finally(() => {
      // Stop the timer so it does not keep the event loop alive after success.
      clearTimeout(timer);
      // Only clear the slot if it still holds this load: after resetModelCache()
      // a newer load may have taken its place.
      if (globalThis.__plantDiseaseModelLoading__ === loading) {
        globalThis.__plantDiseaseModelLoading__ = undefined;
      }
    });

  globalThis.__plantDiseaseModelLoading__ = loading;
  return loading;
}
|
|
|
|
/**
 * Load the model, attempting TF.js first, then ONNX, then falling back to mock.
 *
 * A backend that returns null (package or file unavailable) is skipped quietly;
 * one that throws is logged and skipped. Backend errors are therefore never
 * propagated — the mock model is the final fallback.
 */
async function loadModel(): Promise<PlantDiseaseModel> {
  const hasModelFiles = await checkModelFiles();
  if (!hasModelFiles) {
    console.warn(
      `[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
        `Place TF.js model (model.json + weight shards) in public/models/plant-disease-classifier/`,
    );
    return createMockModel();
  }

  // Backends in priority order; `next` is logged when this backend throws.
  const backends: ReadonlyArray<{
    label: string;
    load: () => Promise<PlantDiseaseModel | null>;
    next: string;
  }> = [
    { label: "TF.js", load: tryLoadTFJS, next: "Trying ONNX..." },
    { label: "ONNX", load: tryLoadONNX, next: "Falling back to mock." },
  ];

  for (const { label, load, next } of backends) {
    try {
      const loaded = await load();
      if (loaded) {
        console.info(`[model-loader] Loaded ${label} model: ${MODEL_ID}`);
        return loaded;
      }
    } catch (err) {
      const reason = err instanceof Error ? err.message : "unknown";
      console.warn(`[model-loader] ${label} load failed (${reason}). ${next}`);
    }
  }

  console.warn(`[model-loader] All backends failed. Using mock model.`);
  return createMockModel();
}
|
|
|
|
/**
 * Check whether any supported model artifact exists on disk: the TF.js
 * manifest (model.json) or an ONNX export (model.onnx) in MODEL_DIR.
 *
 * Both are checked because loadModel returns the mock without trying any
 * backend when this returns false, so an ONNX-only deployment must count as
 * "present" for the ONNX backend to ever be reached.
 *
 * @returns true if at least one artifact is accessible, false otherwise.
 */
async function checkModelFiles(): Promise<boolean> {
  const candidates = [MODEL_JSON_PATH, path.join(MODEL_DIR, "model.onnx")];
  for (const candidate of candidates) {
    try {
      await fs.access(candidate);
      return true;
    } catch {
      // Missing or inaccessible — try the next artifact.
    }
  }
  return false;
}
|
|
|
|
// ─── TensorFlow.js Backend ───────────────────────────────────────────────────
|
|
|
|
/**
 * Try to load the model using TensorFlow.js.
 *
 * Attempts @tensorflow/tfjs-node first (server, native bindings) and falls back
 * to @tensorflow/tfjs. Returns null — letting the caller move on to the next
 * backend — when model.json is absent or neither package is installed; throws
 * if a package is available but the graph model fails to load.
 */
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
  // Without a graph-model manifest there is nothing to load, so skip the
  // (expensive) tfjs import. Mirrors the model.onnx check in tryLoadONNX.
  try {
    await fs.access(MODEL_JSON_PATH);
  } catch {
    return null;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let tf: any;

  // Shim util.isNullOrUndefined before importing tfjs-node, which still calls it.
  // The util.is*() helpers are deprecated (DEP0051) and have been removed from
  // recent Node.js releases. Note: this patches the shared `util` module process-wide.
  // eslint-disable-next-line @typescript-eslint/no-require-imports
  const nodeUtil = require("util");
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
      return x === null || x === undefined;
    };
  }

  // Try tfjs-node first, then browser tfjs. The `+ ""` keeps bundlers
  // (Turbopack/webpack) from tracing these as required dependencies — both
  // packages are optional.
  try {
    tf = await import("@tensorflow/tfjs-node" + "");
  } catch {
    try {
      tf = await import("@tensorflow/tfjs" + "");
    } catch {
      return null; // Neither tfjs package available
    }
  }

  // Load the model from file path
  const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);

  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      const startTime = performance.now();

      // Flat NCHW buffer [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3].
      // tidy() disposes the intermediate tensor3d/transpose results and keeps
      // only the returned tensor; without it those intermediates leak per call.
      // tensor3d accepts the Float32Array directly, so no Array.from copy.
      const inputTensor = tf.tidy(() =>
        tf.tensor3d(tensor, [3, 160, 160]).transpose([1, 2, 0]).expandDims(0),
      );

      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      let outputTensor: any;
      try {
        outputTensor = await model.predict(inputTensor);
        const logits = new Float32Array(await outputTensor.data());
        return {
          logits,
          inferenceTimeMs: performance.now() - startTime,
        };
      } finally {
        // Release tensor memory even when predict() or data() throws.
        inputTensor.dispose();
        outputTensor?.dispose();
      }
    },

    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "tfjs",
        modelId: MODEL_ID,
        numClasses: 38, // Original PlantVillage model
      };
    },
  };
}
|
|
|
|
// ─── ONNX Runtime Backend ────────────────────────────────────────────────────
|
|
|
|
/**
 * Try to load the model using ONNX Runtime (onnxruntime-node).
 *
 * Returns null when the package is not installed or model.onnx is missing from
 * MODEL_DIR; throws if a session cannot be created from an existing file.
 */
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let ort: any;

  try {
    // `+ ""` keeps bundlers from tracing this optional dependency.
    ort = await import("onnxruntime-node" + "");
  } catch {
    return null;
  }

  // Look for .onnx file in model directory
  const onnxPath = path.join(MODEL_DIR, "model.onnx");
  if (!fsSync.existsSync(onnxPath)) {
    return null;
  }

  const session = await ort.InferenceSession.create(onnxPath);

  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      const startTime = performance.now();

      // ONNX expects NCHW format: [1, 3, 160, 160]
      const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      let results: Record<string, any> | undefined;
      try {
        results = await session.run({ [session.inputNames[0]]: inputTensor });
        // Select the output by name instead of relying on object key order.
        const output = results[session.outputNames[0]] ?? Object.values(results)[0];
        // Copy so the logits stay valid after the tensors are released below.
        const logits = new Float32Array(output.data);
        return {
          logits,
          inferenceTimeMs: performance.now() - startTime,
        };
      } finally {
        // Tensor.dispose() may be missing on older onnxruntime releases, hence ?.()
        inputTensor.dispose?.();
        for (const value of Object.values(results ?? {})) {
          value?.dispose?.();
        }
      }
    },

    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "onnx",
        modelId: MODEL_ID,
        numClasses: 38,
      };
    },
  };
}
|
|
|
|
// ─── Mock Model ──────────────────────────────────────────────────────────────
|
|
|
|
/**
 * Create a deterministic mock model for development/demo mode.
 *
 * Predictions come from a hash of the input tensor (see generateMockLogits),
 * so the same image always gets the same result and the UI can work without a
 * real model file. Each predict() call waits a random 50–200 ms to imitate
 * inference latency and reports that simulated duration.
 *
 * @param reason Explanation surfaced via getStatus().error. The default is
 *   accurate both when model files are missing and when every backend failed
 *   to load them (loadModel uses the mock in both cases).
 */
function createMockModel(
  reason = "Model could not be loaded (files missing or no backend available). Running in demo mode with mock predictions.",
): PlantDiseaseModel {
  return {
    async predict(tensor: Float32Array): Promise<ModelOutput> {
      // Simulate inference time (50-200ms)
      const simulatedTime = 50 + Math.random() * 150;
      await sleep(simulatedTime);

      // Generate deterministic logits from input hash
      const logits = generateMockLogits(tensor);

      return {
        logits,
        inferenceTimeMs: simulatedTime,
      };
    },

    getStatus(): ModelStatus {
      return {
        loaded: false,
        backend: "mock",
        modelId: MODEL_ID,
        numClasses: 38,
        error: reason,
      };
    },
  };
}
|
|
|
|
/**
 * Generate deterministic mock logits from an input tensor.
 *
 * Hashes at most the first 100 tensor values and derives a 38-entry vector
 * from that hash, so identical inputs always yield identical output while
 * different inputs yield varied predictions. Exactly one of indices 1..36 is
 * set to 3.5 (top) and a different one to 2.5; index 37 is pinned to -2.
 *
 * NOTE(review): the "healthy" (index 0) / "unknown" (index 37) roles are this
 * mock's own convention — confirm they match the real label list.
 */
function generateMockLogits(tensor: Float32Array): Float32Array {
  const numClasses = 38;
  const logits = new Float32Array(numClasses);

  // 31x rolling hash over the first (up to) 100 values, kept in int32 range.
  let hash = 0;
  const sampleSize = Math.min(100, tensor.length);
  for (let i = 0; i < sampleSize; i++) {
    hash = ((hash << 5) - hash + Math.floor(tensor[i] * 1000)) | 0;
  }

  // Index 0 ("healthy" in this mock) gets a moderate score in [0, 1.8].
  logits[0] = (Math.abs(hash % 10) / 10) * 2;

  // Indices 1..numClasses-2 get pseudo-random scores in (-5, 3): `hash` may be
  // negative, so the remainder (and thus the score) can be negative as well.
  for (let i = 1; i < numClasses - 1; i++) {
    const seed = ((hash * (i + 1) * 7) % 1000) / 1000;
    logits[i] = seed * 4 - 1;
  }

  // Top prediction: one index in 1..numClasses-2.
  const topIndex = Math.abs(hash % (numClasses - 2)) + 1;
  logits[topIndex] = 3.5;

  // Runner-up: 2..11 positions after topIndex, wrapping. The wrap can land on
  // the reserved last slot, which is overwritten below and would silently drop
  // the runner-up, so step past it to index 1 (never topIndex in that case,
  // since landing there requires topIndex >= 26).
  let secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
  if (secondIndex === numClasses - 1) {
    secondIndex = 1;
  }
  logits[secondIndex] = 2.5;

  // Last index ("unknown" in this mock) is pinned low.
  logits[numClasses - 1] = -2;

  return logits;
}
|
|
|
|
/**
 * Resolve (with no value) once a `setTimeout` of `ms` milliseconds fires.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((wake) => {
    setTimeout(() => wake(), ms);
  });
}
|
|
|
|
// ─── Reset (for testing) ─────────────────────────────────────────────────────
|
|
|
|
/**
 * Forget both the cached model and the in-flight load slot so the next
 * getModel() call starts from scratch. Does not cancel a load already running.
 * Intended for tests.
 */
export function resetModelCache(): void {
  const slots = globalThis;
  slots.__plantDiseaseModelLoading__ = undefined;
  slots.__plantDiseaseModel__ = undefined;
}
|