scripting

This commit is contained in:
2026-06-06 15:45:21 -04:00
parent 06295c83ca
commit 47609e5e42
11 changed files with 4411 additions and 205 deletions

View File

@@ -173,14 +173,14 @@ describe("imageToTensor", () => {
describe("tensorToBase64 / base64ToTensor", () => {
it("round-trips tensor data correctly", () => {
const imageData = createMockImageData(224, 224, 100, 150, 200);
const imageData = createMockImageData(160, 160, 100, 150, 200);
const original = imageToTensor(imageData);
const base64 = tensorToBase64(original);
const decoded = base64ToTensor(base64);
expect(decoded.tensor.length).toBe(original.length);
expect(decoded.shape).toEqual([3, 224, 224]);
expect(decoded.shape).toEqual([3, 160, 160]);
// Check a few values match
for (let i = 0; i < 10; i++) {
@@ -197,9 +197,9 @@ describe("tensorToBase64 / base64ToTensor", () => {
});
describe("getTensorShape", () => {
it("returns [1, 3, 224, 224] by default", () => {
it("returns [1, 3, 160, 160] by default", () => {
const shape = getTensorShape();
expect(shape).toEqual([1, 3, 224, 224]);
expect(shape).toEqual([1, 3, 160, 160]);
});
it("returns NCHW layout", () => {
@@ -207,8 +207,8 @@ describe("getTensorShape", () => {
expect(shape.length).toBe(4);
expect(shape[0]).toBe(1); // batch
expect(shape[1]).toBe(3); // channels
expect(shape[2]).toBe(224); // height
expect(shape[3]).toBe(224); // width
expect(shape[2]).toBe(160); // height (model input size)
expect(shape[3]).toBe(160); // width (model input size)
});
});

View File

@@ -17,7 +17,7 @@
const DEFAULT_MODEL_SIZE = 160;
const DEFAULT_MEAN = [0.485, 0.456, 0.406] as const; // ImageNet RGB means
const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds
const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds
function getConfig(): {
size: number;
@@ -48,12 +48,7 @@ export const MAX_FILE_SIZE = 10 * 1024 * 1024;
export const MIN_DIMENSION = 150;
/** Allowed MIME types */
export const ALLOWED_MIME_TYPES = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
] as const;
export const ALLOWED_MIME_TYPES = ["image/png", "image/jpeg", "image/jpg", "image/webp"] as const;
export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number];
@@ -66,9 +61,7 @@ export const MAX_UPLOADS = 100;
* Validate that a file is an acceptable image for upload.
* Returns `{ ok: true }` or `{ ok: false, error: string }`.
*/
export function validateImageFile(file: File):
| { ok: true }
| { ok: false; error: string } {
export function validateImageFile(file: File): { ok: true } | { ok: false; error: string } {
// MIME type check
if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) {
return {
@@ -127,10 +120,7 @@ export function validateImageDimensions(
* @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224.
* @returns ImageData at exactly `size × size`
*/
export async function resizeImage(
file: File,
size: number = getConfig().size,
): Promise<ImageData> {
export async function resizeImage(file: File, size: number = getConfig().size): Promise<ImageData> {
return new Promise((resolve, reject) => {
const img = new Image();
img.onload = () => {
@@ -193,8 +183,7 @@ export function imageToTensor(imageData: ImageData): Float32Array {
// Normalize with ImageNet mean/std
for (let c = 0; c < 3; c++) {
const channel =
c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
const channel = c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
const m = mean[c];
const s = std[c];
for (let i = 0; i < totalPixels; i++) {
@@ -253,5 +242,3 @@ export function base64ToTensor(base64: string): {
shape: envelope.shape as [number, number, number],
};
}

View File

@@ -97,7 +97,7 @@ describe("createZeroTensor", () => {
it("all values are zero", () => {
const tensor = createZeroTensor();
expect(tensor.every(v => v === 0)).toBe(true);
expect(tensor.every((v) => v === 0)).toBe(true);
});
});
@@ -114,12 +114,12 @@ describe("createRandomTensor", () => {
it("all values are finite", () => {
const tensor = createRandomTensor();
expect(tensor.every(v => Number.isFinite(v))).toBe(true);
expect(tensor.every((v) => Number.isFinite(v))).toBe(true);
});
it("produces varied values", () => {
const tensor = createRandomTensor();
const uniqueValues = new Set(tensor.map(v => v.toFixed(4)));
const uniqueValues = new Set(tensor.map((v) => v.toFixed(4)));
expect(uniqueValues.size).toBeGreaterThan(100);
});
@@ -172,7 +172,7 @@ describe("runInference", () => {
const result = await runInference(tensor);
for (let i = 0; i < result.predictions.length - 1; i++) {
expect(result.predictions[i].probability).toBeGreaterThanOrEqual(
result.predictions[i + 1].probability
result.predictions[i + 1].probability,
);
}
}, 10000);

View File

@@ -69,9 +69,7 @@ export async function runInference(
*/
export function validateInput(tensor: Float32Array): void {
if (!(tensor instanceof Float32Array)) {
throw new Error(
`Expected Float32Array input, got ${typeof tensor}`,
);
throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
}
if (tensor.length !== INPUT_SIZE) {
@@ -84,9 +82,7 @@ export function validateInput(tensor: Float32Array): void {
// Check for NaN/Infinity values
for (let i = 0; i < tensor.length; i++) {
if (!Number.isFinite(tensor[i])) {
throw new Error(
`Tensor contains non-finite value at index ${i}: ${tensor[i]}`,
);
throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
}
}
}

View File

@@ -1,17 +1,21 @@
/**
* Unit tests for lib/ml/labels.ts
*
* Tests:
* - INDEX_TO_DISEASE_ID maps index 0 to "healthy"
* - INDEX_TO_DISEASE_ID maps last index to "unknown"
* - INDEX_TO_DISEASE_ID maps intermediate indices to disease IDs
* - DISEASE_ID_TO_INDEX is inverse of INDEX_TO_DISEASE_ID
* - getDiseaseIdForIndex returns "unknown" for out-of-range
* - getIndexForDiseaseId returns -1 for unknown ID
* - isRealDisease correctly classifies healthy/unknown vs real diseases
* - getAllDiseaseIds returns all disease IDs from knowledge base
* - NUM_CLASSES equals expected count (diseases + healthy + unknown)
* - Bidirectional mapping consistency
* The model has 38 PlantVillage classes. Some map to the app's
* knowledge base disease IDs, others map to "unknown".
*
* Known mappings:
* - indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37 → "healthy"
* - index 20 (Potato___Early_blight) → "early-blight"
* - index 21 (Potato___Late_blight) → "late-blight"
* - index 25 (Squash___Powdery_mildew) → "squash-powdery-mildew"
* - index 26 (Strawberry___Leaf_scorch) → "strawberry-leaf-scorch"
* - index 28 (Tomato___Bacterial_spot) → "bacterial-leaf-spot-tomato"
* - index 29 (Tomato___Early_blight) → "early-blight" (duplicate)
* - index 30 (Tomato___Late_blight) → "late-blight" (duplicate)
* - index 32 (Tomato___Septoria_leaf_spot) → "septoria-leaf-spot"
* - index 37 (Tomato___healthy) → "healthy"
* - all others → "unknown"
*/
import { describe, it, expect } from "vitest";
@@ -23,143 +27,105 @@ import {
isRealDisease,
getAllDiseaseIds,
NUM_CLASSES,
HEALTHY_INDEX,
FIRST_DISEASE_INDEX,
UNKNOWN_INDEX,
getPlantVillageClassName,
} from "@/lib/ml/labels";
import rawDiseases from "@/data/diseases.json";
import type { Disease } from "@/lib/types";
const diseases: Disease[] = rawDiseases as Disease[];
describe("Constants", () => {
it("HEALTHY_INDEX is 0", () => {
expect(HEALTHY_INDEX).toBe(0);
it("NUM_CLASSES is 38 (PlantVillage)", () => {
expect(NUM_CLASSES).toBe(38);
});
it("FIRST_DISEASE_INDEX is 1", () => {
expect(FIRST_DISEASE_INDEX).toBe(1);
});
it("UNKNOWN_INDEX is 1 + number of diseases", () => {
expect(UNKNOWN_INDEX).toBe(1 + diseases.length);
});
it("NUM_CLASSES is UNKNOWN_INDEX + 1", () => {
expect(NUM_CLASSES).toBe(UNKNOWN_INDEX + 1);
});
it("NUM_CLASSES equals diseases.length + 2 (healthy + unknown)", () => {
expect(NUM_CLASSES).toBe(diseases.length + 2);
it("all 38 indices are mapped", () => {
const keys = Object.keys(INDEX_TO_DISEASE_ID).map(Number);
expect(keys.length).toBe(38);
for (let i = 0; i < 38; i++) {
expect(keys).toContain(i);
}
});
});
describe("INDEX_TO_DISEASE_ID", () => {
it("maps index 0 to 'healthy'", () => {
expect(INDEX_TO_DISEASE_ID[0]).toBe("healthy");
});
describe("INDEX_TO_DISEASE_ID — healthy indices", () => {
const healthyIndices = [3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37];
it("maps last index to 'unknown'", () => {
expect(INDEX_TO_DISEASE_ID[NUM_CLASSES - 1]).toBe("unknown");
});
for (const idx of healthyIndices) {
it(`index ${idx} maps to "healthy"`, () => {
expect(INDEX_TO_DISEASE_ID[idx]).toBe("healthy");
});
}
});
it("maps intermediate indices to disease IDs", () => {
// Index 1 should be the first disease
expect(INDEX_TO_DISEASE_ID[1]).toBe(diseases[0].id);
// Index 2 should be the second disease
expect(INDEX_TO_DISEASE_ID[2]).toBe(diseases[1].id);
// Last disease index
expect(INDEX_TO_DISEASE_ID[diseases.length]).toBe(diseases[diseases.length - 1].id);
});
describe("INDEX_TO_DISEASE_ID — known disease mappings", () => {
const cases: Array<{ index: number; expected: string; name: string }> = [
{ index: 20, expected: "early-blight", name: "Potato___Early_blight" },
{ index: 21, expected: "late-blight", name: "Potato___Late_blight" },
{ index: 25, expected: "squash-powdery-mildew", name: "Squash___Powdery_mildew" },
{ index: 26, expected: "strawberry-leaf-scorch", name: "Strawberry___Leaf_scorch" },
{ index: 28, expected: "bacterial-leaf-spot-tomato", name: "Tomato___Bacterial_spot" },
{ index: 29, expected: "early-blight", name: "Tomato___Early_blight" },
{ index: 30, expected: "late-blight", name: "Tomato___Late_blight" },
{ index: 32, expected: "septoria-leaf-spot", name: "Tomato___Septoria_leaf_spot" },
];
it("has exactly NUM_CLASSES entries", () => {
const keys = Object.keys(INDEX_TO_DISEASE_ID);
expect(keys.length).toBe(NUM_CLASSES);
});
for (const { index, expected, name } of cases) {
it(`index ${index} (${name}) maps to "${expected}"`, () => {
expect(INDEX_TO_DISEASE_ID[index]).toBe(expected);
});
}
});
it("all mapped IDs are valid strings", () => {
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
expect(typeof id).toBe("string");
expect(id.length).toBeGreaterThan(0);
}
});
describe("INDEX_TO_DISEASE_ID — unknown (unmapped) indices", () => {
const unknownIndices = [0, 1, 2, 5, 7, 8, 9, 11, 12, 13, 15, 16, 18, 31, 33, 34, 35, 36];
for (const idx of unknownIndices) {
it(`index ${idx} maps to "unknown"`, () => {
expect(INDEX_TO_DISEASE_ID[idx]).toBe("unknown");
});
}
});
describe("DISEASE_ID_TO_INDEX", () => {
it("maps 'healthy' to index 0", () => {
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(0);
it("maps 'early-blight' to first occurrence (index 20)", () => {
expect(DISEASE_ID_TO_INDEX["early-blight"]).toBe(20);
});
it("maps 'unknown' to last index", () => {
expect(DISEASE_ID_TO_INDEX["unknown"]).toBe(NUM_CLASSES - 1);
it("maps 'late-blight' to first occurrence (index 21)", () => {
expect(DISEASE_ID_TO_INDEX["late-blight"]).toBe(21);
});
it("maps disease IDs to correct indices", () => {
for (let i = 0; i < diseases.length; i++) {
expect(DISEASE_ID_TO_INDEX[diseases[i].id]).toBe(FIRST_DISEASE_INDEX + i);
}
it("maps 'septoria-leaf-spot' to index 32", () => {
expect(DISEASE_ID_TO_INDEX["septoria-leaf-spot"]).toBe(32);
});
it("has exactly NUM_CLASSES entries", () => {
const keys = Object.keys(DISEASE_ID_TO_INDEX);
expect(keys.length).toBe(NUM_CLASSES);
it("maps 'healthy' to index 3 (first healthy index)", () => {
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(3);
});
});
describe("Bidirectional mapping", () => {
it("INDEX_TO_DISEASE_ID and DISEASE_ID_TO_INDEX are inverses", () => {
for (const [idxStr, id] of Object.entries(INDEX_TO_DISEASE_ID)) {
const idx = parseInt(idxStr);
expect(DISEASE_ID_TO_INDEX[id]).toBe(idx);
}
});
it("round-trips for all disease IDs", () => {
for (const [id, idx] of Object.entries(DISEASE_ID_TO_INDEX)) {
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
}
});
it("round-trips for all indices", () => {
it("every index round-trips correctly", () => {
for (let i = 0; i < NUM_CLASSES; i++) {
const id = INDEX_TO_DISEASE_ID[i];
expect(DISEASE_ID_TO_INDEX[id]).toBe(i);
const idx = DISEASE_ID_TO_INDEX[id];
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
}
});
});
describe("getDiseaseIdForIndex", () => {
it("returns 'healthy' for index 0", () => {
expect(getDiseaseIdForIndex(0)).toBe("healthy");
});
it("returns disease ID for valid disease index", () => {
expect(getDiseaseIdForIndex(1)).toBe(diseases[0].id);
});
it("returns 'unknown' for out-of-range positive index", () => {
expect(getDiseaseIdForIndex(1000)).toBe("unknown");
expect(getDiseaseIdForIndex(100)).toBe("unknown");
});
it("returns 'unknown' for negative index", () => {
expect(getDiseaseIdForIndex(-1)).toBe("unknown");
});
it("returns 'unknown' for index past NUM_CLASSES", () => {
expect(getDiseaseIdForIndex(NUM_CLASSES + 10)).toBe("unknown");
it("returns correct ID for valid index", () => {
expect(getDiseaseIdForIndex(20)).toBe("early-blight");
});
});
describe("getIndexForDiseaseId", () => {
it("returns 0 for 'healthy'", () => {
expect(getIndexForDiseaseId("healthy")).toBe(0);
});
it("returns correct index for known disease", () => {
const idx = getIndexForDiseaseId(diseases[0].id);
expect(idx).toBe(1);
});
it("returns -1 for unknown disease ID", () => {
expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1);
});
@@ -169,9 +135,7 @@ describe("getIndexForDiseaseId", () => {
});
it("is case-insensitive", () => {
const lowerIdx = getIndexForDiseaseId(diseases[0].id);
const upperIdx = getIndexForDiseaseId(diseases[0].id.toUpperCase());
expect(upperIdx).toBe(lowerIdx);
expect(getIndexForDiseaseId("EARLY-BLIGHT")).toBe(20);
});
});
@@ -184,10 +148,9 @@ describe("isRealDisease", () => {
expect(isRealDisease("unknown")).toBe(false);
});
it("returns true for actual disease IDs", () => {
for (const disease of diseases) {
expect(isRealDisease(disease.id)).toBe(true);
}
it("returns true for known disease IDs", () => {
expect(isRealDisease("early-blight")).toBe(true);
expect(isRealDisease("septoria-leaf-spot")).toBe(true);
});
it("returns true for arbitrary non-special strings", () => {
@@ -195,27 +158,37 @@ describe("isRealDisease", () => {
});
});
describe("getPlantVillageClassName", () => {
it("returns correct class name for tomato healthy", () => {
expect(getPlantVillageClassName(37)).toBe("Tomato___healthy");
});
it("returns correct class name for potato early blight", () => {
expect(getPlantVillageClassName(20)).toBe("Potato___Early_blight");
});
it("returns 'unknown' for out-of-range index", () => {
expect(getPlantVillageClassName(100)).toBe("unknown");
});
});
describe("getAllDiseaseIds", () => {
it("returns array of all disease IDs", () => {
it("returns only mapped disease IDs", () => {
const ids = getAllDiseaseIds();
expect(ids.length).toBe(diseases.length);
expect(ids).toContain("early-blight");
expect(ids).toContain("late-blight");
expect(ids).toContain("squash-powdery-mildew");
expect(ids).toContain("strawberry-leaf-scorch");
expect(ids).toContain("bacterial-leaf-spot-tomato");
expect(ids).toContain("septoria-leaf-spot");
});
it("excludes 'healthy'", () => {
const ids = getAllDiseaseIds();
expect(ids).not.toContain("healthy");
expect(getAllDiseaseIds()).not.toContain("healthy");
});
it("excludes 'unknown'", () => {
const ids = getAllDiseaseIds();
expect(ids).not.toContain("unknown");
});
it("includes all disease IDs from knowledge base", () => {
const ids = getAllDiseaseIds();
for (const disease of diseases) {
expect(ids).toContain(disease.id);
}
expect(getAllDiseaseIds()).not.toContain("unknown");
});
it("has no duplicates", () => {

View File

@@ -1,74 +1,197 @@
/**
* Class label mapping for the plant disease classifier model.
*
* Maps model output index → disease ID string.
* The model has classes for each disease in the knowledge base,
* plus "healthy" and "unknown" catch-all classes.
* This model is a MobileNetV2 trained on the PlantVillage dataset
* with 38 classes (14 crops × diseases/healthy).
*
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 95
* (93 diseases + "healthy" + "unknown")
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 38
*
* Index layout:
* 0 → "healthy"
* 193 → disease IDs (order matches diseases.json)
* 94"unknown"
* Index layout (from labels_pv_original.json):
* 0 → Apple___Apple_scab
* 1 → Apple___Black_rot
* 2 Apple___Cedar_apple_rust
* 3 → Apple___healthy
* 4 → Blueberry___healthy
* 5 → Cherry_(including_sour)___Powdery_mildew
* 6 → Cherry_(including_sour)___healthy
* 7 → Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
* 8 → Corn_(maize)___Common_rust_
* 9 → Corn_(maize)___Northern_Leaf_Blight
* 10 → Corn_(maize)___healthy
* 11 → Grape___Black_rot
* 12 → Grape___Esca_(Black_Measles)
* 13 → Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
* 14 → Grape___healthy
* 15 → Orange___Haunglongbing_(Citrus_greening)
* 16 → Peach___Bacterial_spot
* 17 → Peach___healthy
* 18 → Pepper,_bell___Bacterial_spot
* 19 → Pepper,_bell___healthy
* 20 → Potato___Early_blight
* 21 → Potato___Late_blight
* 22 → Potato___healthy
* 23 → Raspberry___healthy
* 24 → Soybean___healthy
* 25 → Squash___Powdery_mildew
* 26 → Strawberry___Leaf_scorch
* 27 → Strawberry___healthy
* 28 → Tomato___Bacterial_spot
* 29 → Tomato___Early_blight
* 30 → Tomato___Late_blight
* 31 → Tomato___Leaf_Mold
* 32 → Tomato___Septoria_leaf_spot
* 33 → Tomato___Spider_mites Two-spotted_spider_mite
* 34 → Tomato___Target_Spot
* 35 → Tomato___Tomato_Yellow_Leaf_Curl_Virus
* 36 → Tomato___Tomato_mosaic_virus
* 37 → Tomato___healthy
*
* Some PlantVillage classes overlap with this app's knowledge base.
* Exact class name → disease ID mappings:
* Potato___Early_blight → "early-blight"
* Potato___Late_blight → "late-blight"
* Squash___Powdery_mildew → "squash-powdery-mildew"
* Strawberry___Leaf_scorch → "strawberry-leaf-scorch"
* Tomato___Bacterial_spot → "bacterial-leaf-spot-tomato"
* Tomato___Early_blight → "early-blight"
* Tomato___Late_blight → "late-blight"
* Tomato___Septoria_leaf_spot → "septoria-leaf-spot"
* All other classes map to "unknown" and are filtered out during enrichment.
*
* After fine-tuning to the app's 93 disease classes, this file will be
* rewritten to match the new model's output layer.
*/
import rawDiseases from "@/data/diseases.json";
import type { Disease } from "@/lib/types";
// ─── PlantVillage class names (in model output order) ────────────────────
const diseases: Disease[] = rawDiseases as Disease[];
const PLANTVILLAGE_CLASSES: string[] = [
"Apple___Apple_scab",
"Apple___Black_rot",
"Apple___Cedar_apple_rust",
"Apple___healthy",
"Blueberry___healthy",
"Cherry_(including_sour)___Powdery_mildew",
"Cherry_(including_sour)___healthy",
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
"Corn_(maize)___Common_rust_",
"Corn_(maize)___Northern_Leaf_Blight",
"Corn_(maize)___healthy",
"Grape___Black_rot",
"Grape___Esca_(Black_Measles)",
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
"Grape___healthy",
"Orange___Haunglongbing_(Citrus_greening)",
"Peach___Bacterial_spot",
"Peach___healthy",
"Pepper,_bell___Bacterial_spot",
"Pepper,_bell___healthy",
"Potato___Early_blight",
"Potato___Late_blight",
"Potato___healthy",
"Raspberry___healthy",
"Soybean___healthy",
"Squash___Powdery_mildew",
"Strawberry___Leaf_scorch",
"Strawberry___healthy",
"Tomato___Bacterial_spot",
"Tomato___Early_blight",
"Tomato___Late_blight",
"Tomato___Leaf_Mold",
"Tomato___Septoria_leaf_spot",
"Tomato___Spider_mites Two-spotted_spider_mite",
"Tomato___Target_Spot",
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
"Tomato___Tomato_mosaic_virus",
"Tomato___healthy",
] as const;
// ─── Constants ───────────────────────────────────────────────────────────────
/** Index for the "healthy" class */
export const HEALTHY_INDEX = 0;
/** First index for actual disease classes */
export const FIRST_DISEASE_INDEX = 1;
/** Index for the "unknown" catch-all class */
export const UNKNOWN_INDEX = 1 + diseases.length;
/** Total number of output classes */
export const NUM_CLASSES = UNKNOWN_INDEX + 1;
// ─── Index → Disease ID mapping ──────────────────────────────────────────────
// ─── PlantVillage → App disease ID mapping ──────────────────────────────
/**
* Map from model output index to disease ID string.
* Index 0 = "healthy", indices 1..N = disease IDs, last = "unknown".
* Maps PlantVillage class names (in the form "Plant___Disease") to
* this app's disease IDs. Unmapped classes resolve to "unknown".
*/
function plantVillageNameToDiseaseId(pvName: string): string {
const parts = pvName.split("___");
if (parts.length !== 2) {
return "unknown";
}
const disease = parts[1];
// Detect "healthy" variants
if (disease === "healthy") {
return "healthy";
}
// Map exact PlantVillage class names to our disease IDs.
// Only map classes where we're confident the correspondence holds.
const exactMap: Record<string, string> = {
Squash___Powdery_mildew: "squash-powdery-mildew",
Strawberry___Leaf_scorch: "strawberry-leaf-scorch",
Potato___Early_blight: "early-blight",
Potato___Late_blight: "late-blight",
Tomato___Bacterial_spot: "bacterial-leaf-spot-tomato",
Tomato___Early_blight: "early-blight",
Tomato___Late_blight: "late-blight",
Tomato___Septoria_leaf_spot: "septoria-leaf-spot",
};
return exactMap[pvName] ?? "unknown";
}
// ─── Constants ──────────────────────────────────────────────────────────
/** Total number of model output classes */
export const NUM_CLASSES = PLANTVILLAGE_CLASSES.length; // 38
/** Index for the "healthy" class — multiple PV indices map to this */
export const HEALTHY_INDEX = 0; // First PV healthy class, others also map to this string
/** First disease index (unused in PV mapping, kept for compatibility) */
export const FIRST_DISEASE_INDEX = 0;
/** Index for the "unknown" catch-all — PV classes we can't map */
export const UNKNOWN_INDEX = NUM_CLASSES - 1; // 37 (Tomato___healthy maps to "healthy", not unknown)
// ─── Index → Disease ID mapping ─────────────────────────────────────────
/**
* Map from model output index to app disease ID string.
* Built dynamically from PlantVillage class names.
*/
export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze(
(() => {
const map: Record<number, string> = {};
map[HEALTHY_INDEX] = "healthy";
for (let i = 0; i < diseases.length; i++) {
map[FIRST_DISEASE_INDEX + i] = diseases[i].id;
for (let i = 0; i < NUM_CLASSES; i++) {
map[i] = plantVillageNameToDiseaseId(PLANTVILLAGE_CLASSES[i]);
}
map[UNKNOWN_INDEX] = "unknown";
return map;
})(),
);
// ─── Disease ID → Index mapping ──────────────────────────────────────────────
// ─── Disease ID → Index mapping ─────────────────────────────────────────
/**
* Map from disease ID string to model output index.
* For duplicates (e.g., both potato and tomato "Early_blight" → "early-blight"),
* returns the first matching index.
*/
export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze(
(() => {
const map: Record<string, number> = {};
map["healthy"] = HEALTHY_INDEX;
for (let i = 0; i < diseases.length; i++) {
map[diseases[i].id] = FIRST_DISEASE_INDEX + i;
for (let i = 0; i < NUM_CLASSES; i++) {
const id = INDEX_TO_DISEASE_ID[i];
// First occurrence wins (potato before tomato for early/late blight)
if (map[id] === undefined) {
map[id] = i;
}
}
map["unknown"] = UNKNOWN_INDEX;
return map;
})(),
);
// ─── Lookup helpers ──────────────────────────────────────────────────────────
// ─── Lookup helpers ─────────────────────────────────────────────────────
/**
* Get the disease ID for a given model output index.
@@ -93,9 +216,22 @@ export function isRealDisease(diseaseId: string): boolean {
return diseaseId !== "healthy" && diseaseId !== "unknown";
}
/**
* Get the PlantVillage display name for a given model output index.
*/
export function getPlantVillageClassName(index: number): string {
return PLANTVILLAGE_CLASSES[index] ?? "unknown";
}
/**
* Get all known disease IDs (excluding "healthy" and "unknown").
*/
export function getAllDiseaseIds(): string[] {
return diseases.map((d) => d.id);
const ids = new Set<string>();
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
if (id !== "healthy" && id !== "unknown") {
ids.add(id);
}
}
return Array.from(ids);
}

View File

@@ -93,7 +93,10 @@ export async function getModel(): Promise<PlantDiseaseModel> {
const model = await Promise.race([
loadingPromise,
new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)), MODEL_LOAD_TIMEOUT),
setTimeout(
() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
MODEL_LOAD_TIMEOUT,
),
),
]);
@@ -172,6 +175,18 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let tf: any;
// Monkey-patch: add util.isNullOrUndefined for Node.js 26 compatibility.
// @tensorflow/tfjs-node references this function which was removed in Node 15+.
// eslint-disable-next-line @typescript-eslint/no-require-imports
const nodeUtil = require("util");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
return x === null || x === undefined;
};
}
// Try tfjs-node first (server-side, uses native bindings).
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
// as required dependencies — they are truly optional.
@@ -197,7 +212,9 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
const startTime = performance.now();
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 160, 160])
// Reshape NCHW flat array [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3]
const inputTensor = tf
.tensor3d(Array.from(tensor), [3, 160, 160])
.transpose([1, 2, 0])
.expandDims(0);
@@ -352,7 +369,7 @@ function generateMockLogits(tensor: Float32Array): Float32Array {
logits[topIndex] = 3.5;
// Second highest
const secondIndex = (topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1) + 1;
const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
logits[secondIndex] = 2.5;
logits[numClasses - 1] = -2; // "unknown" gets low score