Files
plant-disease-id/apps/web/src/lib/ml/inference.test.ts
Michael Freno 820a872f07 Initial commit: Plant Disease Identification app
- Next.js 16 App Router project with Tailwind CSS
- Plant disease knowledge base (93 diseases, 25 plants)
- Image upload with client+server preprocessing
- ML inference pipeline with mock/demo fallback
- Responsive results page with disease cards and treatment
- Full test suite (285 passing tests)
2026-06-05 19:21:16 -04:00

245 lines
7.6 KiB
TypeScript

/**
* Unit tests for lib/ml/inference.ts
*
* Tests:
* - validateInput rejects non-Float32Array
* - validateInput rejects wrong-length arrays
* - validateInput rejects NaN/Infinity values
* - validateInput accepts correct tensor
* - createZeroTensor produces correct shape
* - createRandomTensor produces correct shape with finite values
* - runInference returns InferenceResult with predictions array
* - runInference returns exactly top-K predictions
* - runInference predictions are sorted descending
* - runInference includes inferenceTimeMs
* - runInference completes under 3 seconds
* - runBatchInference processes multiple images
*/
import { describe, it, expect, beforeEach } from "vitest";
import {
runInference,
validateInput,
createZeroTensor,
createRandomTensor,
runBatchInference,
INPUT_SIZE,
INPUT_SHAPE,
DEFAULT_TOP_K,
} from "@/lib/ml/inference";
import { resetModelCache } from "@/lib/ml/model-loader";
/**
 * validateInput must reject anything that is not a Float32Array of exactly
 * INPUT_SIZE finite values, and accept such arrays whatever the sign or
 * magnitude of those (finite) values.
 */
describe("validateInput", () => {
  it("rejects non-Float32Array", () => {
    // A plain number[] is passed on purpose to exercise the runtime type check.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const notATypedArray = [1, 2, 3] as any;
    expect(() => validateInput(notATypedArray)).toThrow("Expected Float32Array input");
  });

  it("rejects wrong-length arrays", () => {
    const tooShort = new Float32Array(100);
    expect(() => validateInput(tooShort)).toThrow(`Expected tensor of length ${INPUT_SIZE}`);
  });

  // One registered test per non-finite value; index 50 is an arbitrary slot.
  const nonFiniteCases: ReadonlyArray<readonly [string, number]> = [
    ["NaN", NaN],
    ["Infinity", Infinity],
    ["-Infinity", -Infinity],
  ];
  for (const [label, badValue] of nonFiniteCases) {
    it(`rejects ${label} values`, () => {
      const poisoned = new Float32Array(INPUT_SIZE);
      poisoned[50] = badValue;
      expect(() => validateInput(poisoned)).toThrow("non-finite value");
    });
  }

  it("accepts correct tensor", () => {
    expect(() => validateInput(createZeroTensor())).not.toThrow();
  });

  it("accepts tensor with negative values", () => {
    const allNegative = new Float32Array(INPUT_SIZE).fill(-2);
    expect(() => validateInput(allNegative)).not.toThrow();
  });

  it("accepts tensor with values near zero", () => {
    const nearZero = new Float32Array(INPUT_SIZE).fill(0.0001);
    expect(() => validateInput(nearZero)).not.toThrow();
  });
});
/**
 * createZeroTensor must allocate a Float32Array whose length matches the
 * model input shape and whose every element is zero.
 */
describe("createZeroTensor", () => {
  it("produces Float32Array", () => {
    const tensor = createZeroTensor();
    expect(tensor).toBeInstanceOf(Float32Array);
  });
  it("has correct length", () => {
    const tensor = createZeroTensor();
    expect(tensor.length).toBe(INPUT_SIZE);
  });
  it("has correct shape dimensions", () => {
    // Check the tensor actually produced (not just the exported constants)
    // against the element count implied by dims 1..3 of INPUT_SHAPE
    // (dim 0 is 1 per the INPUT_SHAPE suite, so it does not affect the count).
    const tensor = createZeroTensor();
    const expectedLength = INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3];
    expect(tensor.length).toBe(expectedLength);
  });
  it("all values are zero", () => {
    const tensor = createZeroTensor();
    expect(tensor.every(v => v === 0)).toBe(true);
  });
});
/**
 * createRandomTensor must produce an INPUT_SIZE Float32Array of finite,
 * non-constant values that validateInput accepts.
 */
describe("createRandomTensor", () => {
  it("produces Float32Array", () => {
    const tensor = createRandomTensor();
    expect(tensor).toBeInstanceOf(Float32Array);
  });
  it("has correct length", () => {
    const tensor = createRandomTensor();
    expect(tensor.length).toBe(INPUT_SIZE);
  });
  it("all values are finite", () => {
    const tensor = createRandomTensor();
    expect(tensor.every(v => Number.isFinite(v))).toBe(true);
  });
  it("produces varied values", () => {
    const tensor = createRandomTensor();
    // Float32Array.prototype.map returns another Float32Array, which would
    // coerce the toFixed() strings back into numbers (and does not type-check
    // under strict TS). Array.from with a mapper yields a genuine string[].
    const uniqueValues = new Set(Array.from(tensor, v => v.toFixed(4)));
    expect(uniqueValues.size).toBeGreaterThan(100);
  });
  it("passes validateInput", () => {
    const tensor = createRandomTensor();
    expect(() => validateInput(tensor)).not.toThrow();
  });
});
/**
 * Pins the model-input constants exported by the inference module so that an
 * accidental change to the expected tensor layout or default top-K fails here.
 */
describe("INPUT_SHAPE and INPUT_SIZE", () => {
  const expectedShape = [1, 3, 224, 224];
  const expectedSize = 3 * 224 * 224;
  const expectedTopK = 5;

  it("INPUT_SHAPE is [1, 3, 224, 224]", () => {
    expect(INPUT_SHAPE).toEqual(expectedShape);
  });

  it("INPUT_SIZE equals 3 * 224 * 224", () => {
    expect(INPUT_SIZE).toBe(expectedSize);
  });

  it("DEFAULT_TOP_K is 5", () => {
    expect(DEFAULT_TOP_K).toBe(expectedTopK);
  });
});
/**
 * runInference: result shape, top-K handling, ordering, timing metadata,
 * latency budget and rejection of invalid input.
 *
 * resetModelCache() runs before every test, presumably so each case loads the
 * model afresh (TODO confirm against model-loader). The per-test timeout is
 * kept well above the 3 s latency budget asserted below.
 */
describe("runInference", () => {
  // Vitest per-test timeout for cases that actually run inference.
  const TIMEOUT_MS = 10000;

  beforeEach(() => {
    resetModelCache();
  });

  it("returns InferenceResult with predictions array", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor);
    expect(result.predictions).toBeDefined();
    expect(Array.isArray(result.predictions)).toBe(true);
  }, TIMEOUT_MS);

  it("returns exactly top-K predictions by default", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor);
    expect(result.predictions.length).toBe(DEFAULT_TOP_K);
  }, TIMEOUT_MS);

  it("returns custom top-K predictions", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor, 3);
    expect(result.predictions.length).toBe(3);
  }, TIMEOUT_MS);

  it("predictions are sorted by probability descending", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor);
    const probabilities = result.predictions.map(p => p.probability);
    // Compare with a descending copy: no unchecked indexing, and a failure
    // shows the whole sequence rather than a single offending pair.
    const sortedDescending = [...probabilities].sort((a, b) => b - a);
    expect(probabilities).toEqual(sortedDescending);
  }, TIMEOUT_MS);

  it("includes inferenceTimeMs", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor);
    expect(typeof result.inferenceTimeMs).toBe("number");
    expect(Number.isFinite(result.inferenceTimeMs)).toBe(true);
    // Non-negative rather than strictly positive: a fast (e.g. mock) backend
    // can report 0 ms when the clock it times with is coarser than the work.
    expect(result.inferenceTimeMs).toBeGreaterThanOrEqual(0);
  }, TIMEOUT_MS);

  it("completes under 3 seconds", async () => {
    // Tensor creation is kept outside the measured window.
    const tensor = createRandomTensor();
    const start = performance.now();
    const result = await runInference(tensor);
    const elapsed = performance.now() - start;
    expect(elapsed).toBeLessThan(3000);
    expect(result.inferenceTimeMs).toBeLessThan(3000);
  }, TIMEOUT_MS);

  it("each prediction has classIndex and probability", async () => {
    const tensor = createRandomTensor();
    const result = await runInference(tensor);
    for (const pred of result.predictions) {
      expect(pred.classIndex).toBeDefined();
      expect(typeof pred.classIndex).toBe("number");
      expect(pred.probability).toBeDefined();
      expect(typeof pred.probability).toBe("number");
      expect(pred.probability).toBeGreaterThanOrEqual(0);
      expect(pred.probability).toBeLessThanOrEqual(1);
    }
  }, TIMEOUT_MS);

  it("throws on invalid input", async () => {
    const badTensor = new Float32Array(100);
    await expect(runInference(badTensor)).rejects.toThrow();
  });
});
/**
 * runBatchInference: one result per input tensor, each result owning its own
 * predictions, and forwarding of a custom top-K.
 */
describe("runBatchInference", () => {
  beforeEach(() => {
    resetModelCache();
  });

  it("processes multiple images", async () => {
    const tensors = [createRandomTensor(), createRandomTensor(), createRandomTensor()];
    const results = await runBatchInference(tensors);
    expect(results).toHaveLength(3);
    for (const result of results) {
      expect(result.predictions.length).toBe(DEFAULT_TOP_K);
      expect(typeof result.inferenceTimeMs).toBe("number");
      // Non-negative rather than strictly positive: a fast backend can report
      // 0 ms when the clock it times with is coarser than the work done.
      expect(result.inferenceTimeMs).toBeGreaterThanOrEqual(0);
    }
  }, 30000);

  it("each result is independent", async () => {
    const tensors = [createRandomTensor(), createRandomTensor()];
    const results = await runBatchInference(tensors);
    expect(results).toHaveLength(2);
    // Different random inputs may still yield the same top class, so compare
    // identity, not content: each input must get its own result object and
    // its own predictions array (no shared or cached instance between them).
    expect(results[0]).not.toBe(results[1]);
    expect(results[0].predictions).not.toBe(results[1].predictions);
    expect(results[0].predictions[0].classIndex).toBeDefined();
    expect(results[1].predictions[0].classIndex).toBeDefined();
  }, 15000);

  it("accepts custom top-K", async () => {
    const tensors = [createRandomTensor()];
    const results = await runBatchInference(tensors, 3);
    expect(results[0].predictions.length).toBe(3);
  }, 15000);
});