Initial commit: Plant Disease Identification app
- Next.js 16 App Router project with Tailwind CSS - Plant disease knowledge base (93 diseases, 25 plants) - Image upload with client+server preprocessing - ML inference pipeline with mock/demo fallback - Responsive results page with disease cards and treatment - Full test suite (285 passing tests)
This commit is contained in:
378
apps/web/src/lib/ml/model-loader.ts
Normal file
378
apps/web/src/lib/ml/model-loader.ts
Normal file
@@ -0,0 +1,378 @@
|
||||
/**
|
||||
* Singleton model loader for the plant disease classifier.
|
||||
*
|
||||
* Lazy-loads the TF.js or ONNX model on first call and caches it in memory
|
||||
* via globalThis for subsequent requests. Supports graceful fallback to
|
||||
* mock mode when no model file is present.
|
||||
*
|
||||
* Model files expected at: public/models/plant-disease-classifier/model.json
|
||||
*/
|
||||
|
||||
import fs from "fs/promises";
|
||||
import fsSync from "fs";
|
||||
import path from "path";
|
||||
|
||||
// ─── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Model runtime backend */
|
||||
export type ModelBackend = "tfjs" | "onnx" | "mock";
|
||||
|
||||
/** Model loading status */
|
||||
export interface ModelStatus {
|
||||
/** Whether a real model is loaded */
|
||||
loaded: boolean;
|
||||
/** Backend being used */
|
||||
backend: ModelBackend;
|
||||
/** Model identifier string */
|
||||
modelId: string;
|
||||
/** Number of output classes */
|
||||
numClasses: number;
|
||||
/** Error message if loading failed */
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/** Result from running the model on input data */
|
||||
export interface ModelOutput {
|
||||
/** Raw logits or probabilities from the model */
|
||||
logits: Float32Array;
|
||||
/** Inference time in milliseconds */
|
||||
inferenceTimeMs: number;
|
||||
}
|
||||
|
||||
/** Model interface abstracted over TF.js / ONNX / mock */
|
||||
export interface PlantDiseaseModel {
|
||||
/** Run inference on a preprocessed image tensor */
|
||||
predict(tensor: Float32Array): Promise<ModelOutput>;
|
||||
/** Get model metadata */
|
||||
getStatus(): ModelStatus;
|
||||
}
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
/** Path to model files relative to project root */
|
||||
const MODEL_DIR = path.join(process.cwd(), "public", "models", "plant-disease-classifier");
|
||||
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");
|
||||
|
||||
/** Model identifier */
|
||||
export const MODEL_ID = "plant-classifier-v1";
|
||||
|
||||
/** Maximum model load time (ms) */
|
||||
const MODEL_LOAD_TIMEOUT = 30_000;
|
||||
|
||||
// ─── Global cache ────────────────────────────────────────────────────────────
|
||||
|
||||
declare global {
|
||||
var __plantDiseaseModel__: PlantDiseaseModel | undefined;
|
||||
var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
|
||||
}
|
||||
|
||||
// ─── Model Loader ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get the cached model instance, loading it lazily on first call.
|
||||
* Uses globalThis to persist across serverless invocations (within same container).
|
||||
*
|
||||
* @returns Promise resolving to the model (real or mock)
|
||||
*/
|
||||
export async function getModel(): Promise<PlantDiseaseModel> {
|
||||
// Return cached model if available
|
||||
if (globalThis.__plantDiseaseModel__) {
|
||||
return globalThis.__plantDiseaseModel__;
|
||||
}
|
||||
|
||||
// If already loading, wait for the existing promise
|
||||
if (globalThis.__plantDiseaseModelLoading__) {
|
||||
return globalThis.__plantDiseaseModelLoading__;
|
||||
}
|
||||
|
||||
// Start loading
|
||||
const loadingPromise = loadModel();
|
||||
globalThis.__plantDiseaseModelLoading__ = loadingPromise;
|
||||
|
||||
try {
|
||||
const model = await Promise.race([
|
||||
loadingPromise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)), MODEL_LOAD_TIMEOUT),
|
||||
),
|
||||
]);
|
||||
|
||||
globalThis.__plantDiseaseModel__ = model;
|
||||
return model;
|
||||
} finally {
|
||||
globalThis.__plantDiseaseModelLoading__ = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the model, attempting TF.js first, then ONNX, then falling back to mock.
|
||||
*/
|
||||
async function loadModel(): Promise<PlantDiseaseModel> {
|
||||
// Check if model files exist
|
||||
const modelExists = await checkModelFiles();
|
||||
|
||||
if (!modelExists) {
|
||||
console.warn(
|
||||
`[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
|
||||
`Place TF.js model (model.json + weight shards) in public/models/plant-disease-classifier/`,
|
||||
);
|
||||
return createMockModel();
|
||||
}
|
||||
|
||||
// Try TF.js first
|
||||
try {
|
||||
const tfModel = await tryLoadTFJS();
|
||||
if (tfModel) {
|
||||
console.info(`[model-loader] Loaded TF.js model: ${MODEL_ID}`);
|
||||
return tfModel;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
`[model-loader] TF.js load failed (${err instanceof Error ? err.message : "unknown"}). Trying ONNX...`,
|
||||
);
|
||||
}
|
||||
|
||||
// Try ONNX Runtime
|
||||
try {
|
||||
const onnxModel = await tryLoadONNX();
|
||||
if (onnxModel) {
|
||||
console.info(`[model-loader] Loaded ONNX model: ${MODEL_ID}`);
|
||||
return onnxModel;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
`[model-loader] ONNX load failed (${err instanceof Error ? err.message : "unknown"}). Falling back to mock.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Fall back to mock
|
||||
console.warn(`[model-loader] All backends failed. Using mock model.`);
|
||||
return createMockModel();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model files exist on disk.
|
||||
*/
|
||||
async function checkModelFiles(): Promise<boolean> {
|
||||
try {
|
||||
await fs.access(MODEL_JSON_PATH);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── TensorFlow.js Backend ───────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Try to load the model using TensorFlow.js.
|
||||
* Attempts @tensorflow/tfjs-node first (server), falls back to @tensorflow/tfjs.
|
||||
*/
|
||||
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let tf: any;
|
||||
|
||||
// Try tfjs-node first (server-side, uses native bindings).
|
||||
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
|
||||
// as required dependencies — they are truly optional.
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
const tfjsNode = await import("@tensorflow/tfjs-node" + "");
|
||||
tf = tfjsNode;
|
||||
} catch {
|
||||
// Fall back to browser tfjs
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
tf = await import("@tensorflow/tfjs" + "");
|
||||
} catch {
|
||||
return null; // Neither tfjs package available
|
||||
}
|
||||
}
|
||||
|
||||
// Load the model from file path
|
||||
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
|
||||
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Reshape to [1, 3, 224, 224] NCHW → [1, 224, 224, 3] NHWC for TF.js
|
||||
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 224, 224])
|
||||
.transpose([1, 2, 0])
|
||||
.expandDims(0);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const outputTensor = (await model.predict(inputTensor)) as any;
|
||||
const logits = new Float32Array(await outputTensor.data());
|
||||
|
||||
inputTensor.dispose();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
|
||||
outputTensor.dispose();
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: performance.now() - startTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: true,
|
||||
backend: "tfjs",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95, // Will be updated after model loads
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ─── ONNX Runtime Backend ────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Try to load the model using ONNX Runtime.
|
||||
*/
|
||||
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let ort: any;
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
ort = await import("onnxruntime-node" + "");
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Look for .onnx file in model directory
|
||||
const onnxPath = path.join(MODEL_DIR, "model.onnx");
|
||||
const onnxExists = fsSync.existsSync(onnxPath);
|
||||
|
||||
if (!onnxExists) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const session = await ort.InferenceSession.create(onnxPath);
|
||||
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// ONNX expects NCHW format: [1, 3, 224, 224]
|
||||
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 224, 224]);
|
||||
const feeds = { [session.inputNames[0]]: inputTensor };
|
||||
const results = await session.run(feeds);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const outputValues = Object.values(results) as any[];
|
||||
const logits = new Float32Array(outputValues[0].data);
|
||||
|
||||
inputTensor.dispose();
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: performance.now() - startTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: true,
|
||||
backend: "onnx",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95,
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Mock Model ──────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a deterministic mock model for development/demo mode.
|
||||
*
|
||||
* Generates reproducible predictions based on input tensor hash.
|
||||
* This allows the UI to work without a real model file.
|
||||
*/
|
||||
function createMockModel(): PlantDiseaseModel {
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
// Simulate inference time (50-200ms)
|
||||
const simulatedTime = 50 + Math.random() * 150;
|
||||
await sleep(simulatedTime);
|
||||
|
||||
// Generate deterministic logits from input hash
|
||||
const logits = generateMockLogits(tensor);
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: simulatedTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: false,
|
||||
backend: "mock",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95,
|
||||
error: "Model files not found. Running in demo mode with mock predictions.",
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate deterministic mock logits from input tensor.
|
||||
* Uses a simple hash of the first few tensor values to create
|
||||
* reproducible but varied predictions.
|
||||
*/
|
||||
function generateMockLogits(tensor: Float32Array): Float32Array {
|
||||
const numClasses = 95;
|
||||
const logits = new Float32Array(numClasses);
|
||||
|
||||
// Simple hash of input for deterministic output
|
||||
let hash = 0;
|
||||
const sampleSize = Math.min(100, tensor.length);
|
||||
for (let i = 0; i < sampleSize; i++) {
|
||||
hash = ((hash << 5) - hash + Math.floor(tensor[i] * 1000)) | 0;
|
||||
}
|
||||
|
||||
// Generate logits using hash as seed
|
||||
// Class 0 (healthy) gets a moderate score
|
||||
logits[0] = (Math.abs(hash % 10) / 10) * 2;
|
||||
|
||||
// Give some disease classes higher scores
|
||||
// This creates a realistic-looking distribution
|
||||
for (let i = 1; i < numClasses - 1; i++) {
|
||||
const seed = ((hash * (i + 1) * 7) % 1000) / 1000;
|
||||
logits[i] = seed * 4 - 1; // Range roughly -1 to 3
|
||||
}
|
||||
|
||||
// Make the top prediction more confident
|
||||
const topIndex = Math.abs(hash % (numClasses - 2)) + 1;
|
||||
logits[topIndex] = 3.5;
|
||||
|
||||
// Second highest
|
||||
const secondIndex = (topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1) + 1;
|
||||
logits[secondIndex] = 2.5;
|
||||
|
||||
logits[numClasses - 1] = -2; // "unknown" gets low score
|
||||
|
||||
return logits;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep for a given number of milliseconds.
|
||||
*/
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
// ─── Reset (for testing) ─────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Reset the model cache. Useful for testing.
|
||||
*/
|
||||
export function resetModelCache(): void {
|
||||
globalThis.__plantDiseaseModel__ = undefined;
|
||||
globalThis.__plantDiseaseModelLoading__ = undefined;
|
||||
}
|
||||
Reference in New Issue
Block a user