This commit is contained in:
2026-06-08 16:42:04 -04:00
commit 8bda14ab63
179 changed files with 48104 additions and 0 deletions

0
src/lib/.gitkeep Normal file
View File

0
src/lib/api/.gitkeep Normal file
View File

135
src/lib/api/browse.ts Normal file
View File

@@ -0,0 +1,135 @@
/**
* Browse API — fetches plants with disease counts from the Turso DB
* for the browse page. Runs server-side only.
*/
import { sql, eq, inArray, notInArray } from "drizzle-orm";
import { getDb } from "@/lib/db/index";
import { plants, diseases, plantViews } from "@/lib/db/schema";
import type { PlantCardData } from "@/components/PlantCard";
export type { PlantCardData };
/**
* Get all plants with their disease counts for the browse page.
*
* Uses scalar subqueries for COUNT to avoid expensive LEFT JOIN + GROUP BY
* on the large diseases table (11,498 rows).
*/
export async function getBrowsePlants(): Promise<PlantCardData[]> {
const db = getDb();
const rows = await db
.select({
id: plants.id,
commonName: plants.commonName,
scientificName: plants.scientificName,
family: plants.family,
category: plants.category,
imageUrl: plants.imageUrl,
updatedAt: plants.updatedAt,
viewCount: sql<number>`COALESCE(${plantViews.viewCount}, 0)`,
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
})
.from(plants)
.leftJoin(plantViews, eq(plantViews.plantId, plants.id))
.orderBy(plants.commonName);
return rows.map((r) => ({
id: r.id,
commonName: r.commonName,
scientificName: r.scientificName,
family: r.family,
category: r.category,
imageUrl: r.imageUrl,
updatedAt: r.updatedAt,
viewCount: r.viewCount,
diseaseCount: r.diseaseCount,
}));
}
/**
* Get a single plant with disease count (for detail page lookups).
*/
export async function getBrowsePlant(id: string): Promise<PlantCardData | null> {
const db = getDb();
const rows = await db
.select({
id: plants.id,
commonName: plants.commonName,
scientificName: plants.scientificName,
family: plants.family,
category: plants.category,
imageUrl: plants.imageUrl,
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
})
.from(plants)
.where(eq(plants.id, id))
.limit(1);
return rows[0] ?? null;
}
/**
* Get featured plants for the homepage (subset).
*/
const FEATURED_IDS = [
"tomato",
"basil",
"rose",
"monstera",
"snake-plant",
"pepper",
"apple",
"corn",
"wheat",
"strawberry",
"blueberry",
"lettuce",
];
export async function getFeaturedPlants(): Promise<PlantCardData[]> {
const db = getDb();
const selectFeatured = db
.select({
id: plants.id,
commonName: plants.commonName,
scientificName: plants.scientificName,
family: plants.family,
category: plants.category,
imageUrl: plants.imageUrl,
updatedAt: plants.updatedAt,
viewCount: sql<number>`COALESCE(${plantViews.viewCount}, 0)`,
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
})
.from(plants)
.leftJoin(plantViews, eq(plantViews.plantId, plants.id));
const rows = await selectFeatured
.where(inArray(plants.id, FEATURED_IDS))
.orderBy(plants.commonName);
if (rows.length < 6) {
const padRows = await selectFeatured
.where(notInArray(plants.id, FEATURED_IDS))
.orderBy(plants.commonName)
.limit(12 - rows.length);
return [...rows, ...padRows].map(mapRow);
}
return rows.slice(0, 12).map(mapRow);
}
function mapRow(r: Record<string, unknown>): PlantCardData {
return {
id: r.id as string,
commonName: r.commonName as string,
scientificName: r.scientificName as string,
family: r.family as string,
category: r.category as string,
imageUrl: r.imageUrl as string,
updatedAt: r.updatedAt as string | undefined,
viewCount: r.viewCount as number,
diseaseCount: r.diseaseCount as number,
};
}

382
src/lib/api/diseases-db.ts Normal file
View File

@@ -0,0 +1,382 @@
/**
* Typed helpers to query the Plant Disease Knowledge Base from Turso DB.
*
* All functions are async and use Drizzle ORM against the Turso/libSQL database.
*
* For client components that need sync access, import from
* @/lib/api/diseases-sync.ts (backed by JSON seed data)
*/
import { eq, like, or, and, sql, type SQL } from "drizzle-orm";
import { getDb } from "@/lib/db/index";
import { plants, diseases } from "@/lib/db/schema";
import type {
CausalAgentType,
Disease,
DiseaseListParams,
DiseaseWithPlant,
Plant,
PlantListParams,
PlantWithDiseases,
Prevalence,
Severity,
PlantCategory,
} from "@/lib/types";
// ─── Row → Type mappers ──────────────────────────────────────────────────────
function toPlant(row: typeof plants.$inferSelect): Plant {
return {
id: row.id,
commonName: row.commonName,
scientificName: row.scientificName,
family: row.family,
category: row.category as PlantCategory,
careSummary: row.careSummary,
imageUrl: row.imageUrl,
};
}
function toDisease(row: typeof diseases.$inferSelect): Disease {
return {
id: row.id,
plantId: row.plantId,
name: row.name,
scientificName: row.scientificName,
causalAgentType: row.causalAgentType as CausalAgentType,
description: row.description,
symptoms: row.symptoms as string[],
causes: row.causes as string[],
treatment: row.treatment as string[],
prevention: row.prevention as string[],
lookalikeDiseaseIds: (row.lookalikeIds as string[]) ?? [],
severity: row.severity as Severity,
prevalence: (row.prevalence as Prevalence) ?? "uncommon",
imageUrl: (row.imageUrl as string) || undefined,
};
}
// ─── Public API ──────────────────────────────────────────────────────────────
/**
* Get a plant by its ID.
*/
export async function getPlantById(id: string): Promise<Plant | undefined> {
const db = getDb();
const row = await db.select().from(plants).where(eq(plants.id, id.toLowerCase())).limit(1);
return row[0] ? toPlant(row[0]) : undefined;
}
/**
* Get a disease by its ID.
*/
export async function getDiseaseById(id: string): Promise<Disease | undefined> {
const db = getDb();
const row = await db.select().from(diseases).where(eq(diseases.id, id.toLowerCase())).limit(1);
return row[0] ? toDisease(row[0]) : undefined;
}
/**
* Get all diseases for a specific plant.
*/
export async function getDiseasesByPlantId(plantId: string): Promise<Disease[]> {
const db = getDb();
const rows = await db.select().from(diseases).where(eq(diseases.plantId, plantId.toLowerCase()));
return rows.map(toDisease);
}
/**
* Get a plant with all its associated diseases.
*/
export async function getPlantWithDiseases(
plantId: string,
): Promise<PlantWithDiseases | undefined> {
const plant = await getPlantById(plantId);
if (!plant) return undefined;
const diseaseRows = await getDiseasesByPlantId(plantId);
return { plant, diseases: diseaseRows };
}
/**
* Get a disease with its associated plant.
*/
export async function getDiseaseWithPlant(
diseaseId: string,
): Promise<DiseaseWithPlant | undefined> {
const disease = await getDiseaseById(diseaseId);
if (!disease) return undefined;
const plant = await getPlantById(disease.plantId);
if (!plant) return undefined;
return { disease, plant };
}
/**
* Resolve lookalike disease IDs to full disease objects.
*/
export async function getLookalikeDiseases(diseaseId: string): Promise<Disease[]> {
const disease = await getDiseaseById(diseaseId);
if (!disease || !disease.lookalikeDiseaseIds.length) return [];
const db = getDb();
const ids = disease.lookalikeDiseaseIds;
const rows = await db
.select()
.from(diseases)
.where(sql`${diseases.id} IN ${ids}`);
return rows.map(toDisease);
}
/**
* Search plants by term (matches common name, scientific name, family, category).
*/
export async function searchPlants(term: string): Promise<Plant[]> {
const lower = term.toLowerCase().trim();
if (!lower) return listPlants();
const db = getDb();
const rows = await db
.select()
.from(plants)
.where(
or(
like(plants.commonName, `%${lower}%`),
like(plants.scientificName, `%${lower}%`),
like(plants.family, `%${lower}%`),
like(plants.category, `%${lower}%`),
),
);
return rows.map(toPlant);
}
/**
* Search diseases by term (matches name, scientific name, description, symptoms via LIKE).
*/
export async function searchDiseases(term: string): Promise<Disease[]> {
const lower = term.toLowerCase().trim();
if (!lower) return listDiseases();
const db = getDb();
const rows = await db
.select()
.from(diseases)
.where(
or(
like(diseases.name, `%${lower}%`),
like(diseases.scientificName, `%${lower}%`),
like(diseases.description, `%${lower}%`),
),
);
return rows.map(toDisease);
}
/**
* List plants with optional search and category filters.
*/
export async function listPlants(params: PlantListParams = {}): Promise<Plant[]> {
const db = getDb();
const plantConditions: SQL[] = [];
if (params.category) {
plantConditions.push(eq(plants.category, params.category));
}
if (params.search) {
const lower = params.search.toLowerCase().trim();
const searchCond = or(
like(plants.commonName, `%${lower}%`),
like(plants.scientificName, `%${lower}%`),
like(plants.family, `%${lower}%`),
like(plants.category, `%${lower}%`),
);
if (searchCond) plantConditions.push(searchCond);
}
const query =
plantConditions.length > 0
? db
.select()
.from(plants)
.where(and(...plantConditions))
: db.select().from(plants);
const rows = await query;
return rows.map(toPlant);
}
/**
* List diseases with optional filters.
*/
export async function listDiseases(params: DiseaseListParams = {}): Promise<Disease[]> {
const db = getDb();
const diseaseConditions: SQL[] = [];
if (params.plantId) {
diseaseConditions.push(eq(diseases.plantId, params.plantId.toLowerCase()));
}
if (params.causalAgentType) {
diseaseConditions.push(eq(diseases.causalAgentType, params.causalAgentType));
}
if (params.severity) {
diseaseConditions.push(eq(diseases.severity, params.severity));
}
if (params.search) {
const lower = params.search.toLowerCase().trim();
const searchCond = or(
like(diseases.name, `%${lower}%`),
like(diseases.scientificName, `%${lower}%`),
like(diseases.description, `%${lower}%`),
);
if (searchCond) diseaseConditions.push(searchCond);
}
const query =
diseaseConditions.length > 0
? db
.select()
.from(diseases)
.where(and(...diseaseConditions))
: db.select().from(diseases);
const rows = await query;
return rows.map(toDisease);
}
/**
* Get all unique plant IDs that have diseases.
*/
export async function getPlantIdsWithDiseases(): Promise<string[]> {
const db = getDb();
const rows = await db
.select({ plantId: diseases.plantId })
.from(diseases)
.groupBy(diseases.plantId);
return rows.map((r) => r.plantId);
}
/**
* Get all unique disease IDs referenced as lookalikes.
*/
export async function getReferencedLookalikeIds(): Promise<Set<string>> {
const db = getDb();
const rows = await db
.select({ id: diseases.id, lookalikeIds: diseases.lookalikeIds })
.from(diseases);
const ids = new Set<string>();
for (const row of rows) {
const lookalikes = row.lookalikeIds as string[];
for (const id of lookalikes) {
ids.add(id);
}
}
return ids;
}
/**
* Validate knowledge base data integrity.
* Returns array of validation errors (empty = valid).
*/
export async function validateKnowledgeBase(): Promise<string[]> {
const errors: string[] = [];
const validCausalAgentTypes: CausalAgentType[] = [
"fungal",
"bacterial",
"viral",
"environmental",
];
const validSeverities: Severity[] = ["low", "moderate", "high", "critical"];
const validPrevalences: Prevalence[] = ["common", "uncommon", "rare", "very_rare"];
const db = getDb();
// Get all plants and diseases
const allPlants = await db.select({ id: plants.id }).from(plants);
const allDiseases = await db.select().from(diseases);
const plantIds = new Set(allPlants.map((p) => p.id));
const diseaseIds = new Set<string>();
const diseaseMap = new Map<string, (typeof allDiseases)[0]>();
const diseaseErrors: Array<{
id: string;
plantId: string;
name: string;
lookalikeDiseaseIds: string[];
}> = [];
for (const d of allDiseases) {
if (diseaseIds.has(d.id)) {
errors.push(`Duplicate disease ID: ${d.id}`);
}
diseaseIds.add(d.id);
diseaseMap.set(d.id, d);
diseaseErrors.push({
id: d.id,
plantId: d.plantId,
name: d.name,
lookalikeDiseaseIds: (d.lookalikeIds as string[]) ?? [],
});
}
// Check disease references
for (const d of diseaseErrors) {
// Valid plant reference
if (!plantIds.has(d.plantId)) {
errors.push(`Disease "${d.id}" references unknown plant ID: ${d.plantId}`);
}
const full = diseaseMap.get(d.id)!;
// Valid causal agent type
if (!validCausalAgentTypes.includes(full.causalAgentType as CausalAgentType)) {
errors.push(`Disease "${d.id}" has invalid causalAgentType: ${full.causalAgentType}`);
}
// Valid severity
if (!validSeverities.includes(full.severity as Severity)) {
errors.push(`Disease "${d.id}" has invalid severity: ${full.severity}`);
}
// Valid prevalence
if (full.prevalence && !validPrevalences.includes(full.prevalence as Prevalence)) {
errors.push(`Disease "${d.id}" has invalid prevalence: ${full.prevalence}`);
}
// Minimum counts
const symptoms = full.symptoms as string[];
const causes = full.causes as string[];
const treatment = full.treatment as string[];
const prevention = full.prevention as string[];
if (symptoms.length < 3) {
errors.push(`Disease "${d.id}" has fewer than 3 symptoms (${symptoms.length})`);
}
if (causes.length < 2) {
errors.push(`Disease "${d.id}" has fewer than 2 causes (${causes.length})`);
}
if (treatment.length < 3) {
errors.push(`Disease "${d.id}" has fewer than 3 treatment steps (${treatment.length})`);
}
if (prevention.length < 2) {
errors.push(`Disease "${d.id}" has fewer than 2 prevention tips (${prevention.length})`);
}
// Valid lookalike references
for (const lookalikeId of d.lookalikeDiseaseIds) {
if (!diseaseIds.has(lookalikeId)) {
errors.push(`Disease "${d.id}" references unknown lookalike: ${lookalikeId}`);
}
}
}
// Check lookalike bidirectionality
for (const d of diseaseErrors) {
for (const lookalikeId of d.lookalikeDiseaseIds) {
const lookalike = diseaseMap.get(lookalikeId);
if (lookalike) {
const otherLookalikes = (lookalike.lookalikeIds as string[]) ?? [];
if (!otherLookalikes.includes(d.id)) {
errors.push(
`Lookalike reference not bidirectional: "${d.id}" references "${lookalikeId}" but not vice versa`,
);
}
}
}
}
return errors;
}

44
src/lib/api/home.ts Normal file
View File

@@ -0,0 +1,44 @@
/**
* Homepage data — fetches featured plants from the Turso DB.
* Uses React's cache() to ensure one fetch per render pass.
* Backed by the async fetch for SSR but stays sync in exported interface
* via a module-level cache pattern.
*/
import { unstable_cache } from "next/cache";
// Re-export the type for convenience
export type { PlantCardData } from "@/components/PlantCard";
/**
* Get featured plants for the homepage.
* Cached via next/cache to avoid repeated DB calls.
*/
export const getFeaturedPlants = unstable_cache(
async () => {
const { getBrowsePlants } = await import("./browse");
const all = await getBrowsePlants();
const FEATURED_IDS = [
"tomato",
"basil",
"rose",
"monstera",
"snake-plant",
"pepper",
"apple",
"corn",
"wheat",
"strawberry",
"blueberry",
"lettuce",
];
const featured = all.filter((p) => FEATURED_IDS.includes(p.id));
if (featured.length < 6) {
const rest = all.filter((p) => !FEATURED_IDS.includes(p.id));
return [...featured, ...rest].slice(0, 12);
}
return featured.slice(0, 12);
},
["featured-plants"],
{ revalidate: 3600 },
);

View File

@@ -0,0 +1,107 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { identifyPlant } from "./identify";
// Mock global fetch
const mockFetch = vi.fn();
global.fetch = mockFetch;
describe("identifyPlant", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("identifies plant and returns predictions", async () => {
const mockResponse = {
predictions: [
{
diseaseId: "early-blight",
disease: {
id: "early-blight",
name: "Early Blight",
causalAgent: "Alternaria solani",
causalAgentType: "fungal",
severity: "moderate",
symptoms: ["Dark spots"],
treatment: ["Remove leaves"],
lookalikeDiseaseIds: [],
plantId: "tomato",
},
confidence: { raw: 0.85, adjusted: 0.82 },
lookalikes: [],
},
],
metadata: {
model: "mock-model",
inferenceTimeMs: 150,
imageId: "test-image-123",
},
};
mockFetch.mockResolvedValue({
ok: true,
json: async () => mockResponse,
});
const result = await identifyPlant("test-image-123");
expect(result).toEqual(mockResponse);
});
it("calls fetch with correct URL and method", async () => {
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({ predictions: [], metadata: {} }),
});
await identifyPlant("test-id");
expect(mockFetch).toHaveBeenCalledWith(
expect.stringContaining("/api/identify"),
expect.objectContaining({
method: "POST",
})
);
});
it("sends imageId in request body", async () => {
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({ predictions: [], metadata: {} }),
});
await identifyPlant("test-id");
const callArgs = mockFetch.mock.calls[0][1];
const body = JSON.parse(callArgs.body);
expect(body.imageId).toBe("test-id");
});
it("throws error when response is not ok", async () => {
mockFetch.mockResolvedValue({
ok: false,
status: 500,
statusText: "Internal Server Error",
});
await expect(identifyPlant("test-id")).rejects.toThrow();
});
it("throws error when fetch fails", async () => {
mockFetch.mockRejectedValue(new Error("Network error"));
await expect(identifyPlant("test-id")).rejects.toThrow("Network error");
});
it("handles demo mode response", async () => {
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({
predictions: [],
metadata: {},
demo_mode: true,
}),
});
const result = await identifyPlant("test-id");
expect(result.demo_mode).toBe(true);
});
});

49
src/lib/api/identify.ts Normal file
View File

@@ -0,0 +1,49 @@
/**
* Client-side API helper for plant disease identification.
*
* POSTs an imageId to the /api/identify endpoint and returns
* ranked predictions with confidence scores, enriched with
* knowledge base data (name, symptoms, treatment, prevention).
*/
import type { IdentifyRequest, IdentifyResponse } from "@/lib/types";
export interface IdentifyError {
error: string;
message: string;
status: number;
}
/**
* Identify plant diseases from an uploaded image.
*
* @param imageId - Image ID from a previous /api/upload call
* @returns IdentifyResponse with ranked predictions and metadata
* @throws IdentifyError on failure
*/
export async function identifyPlant(
imageId: string,
): Promise<IdentifyResponse> {
const request: IdentifyRequest = { imageId };
const response = await fetch("/api/identify", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(request),
signal: AbortSignal.timeout(30_000), // 30s timeout for inference
});
const data = await response.json();
if (!response.ok) {
throw {
error: data.error || "Identification failed",
message: data.message || `Server returned ${response.status}`,
status: response.status,
} as IdentifyError;
}
return data as IdentifyResponse;
}

View File

@@ -0,0 +1,98 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { uploadImage } from "./upload";
import * as imageProcessing from "@/lib/image-processing";
// Mock dependencies
vi.mock("@/lib/image-processing", () => ({
validateImageFile: vi.fn(() => ({ ok: true })),
validateImageDimensions: vi.fn(() => Promise.resolve({ ok: true })),
}));
// Mock global fetch
const mockFetch = vi.fn();
global.fetch = mockFetch;
describe("uploadImage", () => {
const mockFile = new File(["dummy"], "test.png", { type: "image/png" });
beforeEach(() => {
vi.clearAllMocks();
// Reset mocks to default pass values
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: true });
(imageProcessing.validateImageDimensions as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
});
it("uploads image and returns response", async () => {
const mockResponse = {
imageId: "test-id-123",
tensorShape: [3, 224, 224],
previewUrl: "/uploads/test-id-123.png",
};
mockFetch.mockResolvedValue({
ok: true,
json: async () => mockResponse,
});
const result = await uploadImage(mockFile);
expect(result).toEqual(mockResponse);
});
it("calls fetch with correct URL", async () => {
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({ imageId: "test", tensorShape: [3, 224, 224], previewUrl: "/test.png" }),
});
await uploadImage(mockFile);
expect(mockFetch).toHaveBeenCalledWith(
"/api/upload",
expect.objectContaining({
method: "POST",
})
);
});
it("sends FormData with image field", async () => {
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({ imageId: "test", tensorShape: [3, 224, 224], previewUrl: "/test.png" }),
});
await uploadImage(mockFile);
const callArgs = mockFetch.mock.calls[0][1];
expect(callArgs.body).toBeInstanceOf(FormData);
});
it("throws error when response is not ok", async () => {
mockFetch.mockResolvedValue({
ok: false,
status: 413,
json: async () => ({ error: "File too large", message: "File exceeds 10MB limit" }),
});
await expect(uploadImage(mockFile)).rejects.toThrow();
});
it("throws error when fetch fails", async () => {
mockFetch.mockRejectedValue(new Error("Network error"));
await expect(uploadImage(mockFile)).rejects.toThrow("Network error");
});
it("throws error when file validation fails", async () => {
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: false, error: "Invalid file type" });
await expect(uploadImage(mockFile)).rejects.toThrow("Validation: Invalid file type");
});
it("throws error when dimension validation fails", async () => {
// Reset validateImageFile to pass so we can test dimension validation
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: true });
(imageProcessing.validateImageDimensions as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: false, error: "Image too small" });
await expect(uploadImage(mockFile)).rejects.toThrow("Validation: Image too small");
});
});

89
src/lib/api/upload.ts Normal file
View File

@@ -0,0 +1,89 @@
/**
* Client-side API helper for image upload.
*
* POSTs an image file to the server upload endpoint and returns
* the image metadata (imageId, tensorShape, previewUrl).
*/
import { validateImageFile, validateImageDimensions } from "@/lib/image-processing";
export interface UploadResponse {
imageId: string;
tensorShape: [number, number, number, number];
previewUrl: string;
}
export interface UploadError {
error: string;
message: string;
status: number;
}
/**
* Upload an image file to the server.
*
* Validates the file client-side before sending. Returns the server
* response with imageId, tensorShape, and previewUrl.
*
* @param file - The image file to upload
* @param onProgress - Optional callback for upload progress (0100)
* @returns UploadResponse on success
* @throws UploadError on failure
*/
export async function uploadImage(
file: File,
onProgress?: (percent: number) => void,
): Promise<UploadResponse> {
// Client-side validation
const fileValidation = validateImageFile(file);
if (!fileValidation.ok) {
throw new Error(`Validation: ${fileValidation.error}`);
}
const dimValidation = await validateImageDimensions(file);
if (!dimValidation.ok) {
throw new Error(`Validation: ${dimValidation.error}`);
}
// Build form data
const formData = new FormData();
formData.append("image", file);
// POST with progress tracking
const response = await fetch("/api/upload", {
method: "POST",
body: formData,
signal: AbortSignal.timeout(30_000), // 30s timeout
});
// Parse response
const data = await response.json();
if (!response.ok) {
throw {
error: data.error || "Upload failed",
message: data.message || `Server returned ${response.status}`,
status: response.status,
} as UploadError;
}
return data as UploadResponse;
}
/**
* Upload an image and also get the preprocessed tensor for client-side inference.
* Useful when you want to run inference on the client while also uploading.
*
* @param file - The image file to upload
* @returns { uploadResponse, tensor }
*/
export async function uploadWithTensor(
file: File,
): Promise<{ uploadResponse: UploadResponse; tensor: Float32Array }> {
const { resizeImage, imageToTensor } = await import("@/lib/image-processing");
const tensor = imageToTensor(await resizeImage(file));
const uploadResponse = await uploadImage(file);
return { uploadResponse, tensor };
}

173
src/lib/constants.test.ts Normal file
View File

@@ -0,0 +1,173 @@
import { describe, it, expect } from "vitest";
import {
APP_NAME,
NAV_LINKS,
PLANT_CATEGORIES,
TRUST_SIGNALS,
HOW_IT_WORKS,
BETA_DISCLAIMER,
SOCIAL_LINKS,
FEATURED_PLANT_IDS,
APP_TAGLINE,
APP_DESCRIPTION,
} from "./constants";
describe("constants", () => {
describe("APP_NAME", () => {
it("is a non-empty string", () => {
expect(typeof APP_NAME).toBe("string");
expect(APP_NAME.length).toBeGreaterThan(0);
});
it("equals expected name", () => {
expect(APP_NAME).toBe("Plant Health ID");
});
});
describe("APP_TAGLINE", () => {
it("is a non-empty string", () => {
expect(typeof APP_TAGLINE).toBe("string");
expect(APP_TAGLINE.length).toBeGreaterThan(0);
});
});
describe("APP_DESCRIPTION", () => {
it("is a non-empty string", () => {
expect(typeof APP_DESCRIPTION).toBe("string");
expect(APP_DESCRIPTION.length).toBeGreaterThan(0);
});
});
describe("SOCIAL_LINKS", () => {
it("has github link", () => {
expect(SOCIAL_LINKS.github).toMatch(/github/);
});
it("has twitter link", () => {
expect(SOCIAL_LINKS.twitter).toMatch(/twitter/);
});
});
describe("NAV_LINKS", () => {
it("is an array of navigation links", () => {
expect(Array.isArray(NAV_LINKS)).toBe(true);
expect(NAV_LINKS.length).toBeGreaterThan(0);
});
it("each link has label and href", () => {
for (const link of NAV_LINKS) {
expect(link).toHaveProperty("label");
expect(link).toHaveProperty("href");
expect(typeof link.label).toBe("string");
expect(typeof link.href).toBe("string");
expect(link.href.startsWith("/")).toBe(true);
}
});
it("includes expected routes", () => {
const hrefs = NAV_LINKS.map((l) => l.href);
expect(hrefs).toContain("/");
expect(hrefs).toContain("/browse");
});
});
describe("PLANT_CATEGORIES", () => {
it("is an array of categories", () => {
expect(Array.isArray(PLANT_CATEGORIES)).toBe(true);
expect(PLANT_CATEGORIES.length).toBeGreaterThan(0);
});
it("each category has label and value", () => {
for (const cat of PLANT_CATEGORIES) {
expect(cat).toHaveProperty("label");
expect(cat).toHaveProperty("value");
expect(typeof cat.label).toBe("string");
expect(typeof cat.value).toBe("string");
}
});
it("includes all option", () => {
const values = PLANT_CATEGORIES.map((c) => c.value);
expect(values).toContain("all");
});
it("includes common plant categories", () => {
const values = PLANT_CATEGORIES.map((c) => c.value);
expect(values).toContain("vegetables");
expect(values).toContain("flowers");
expect(values).toContain("herbs");
expect(values).toContain("houseplants");
});
});
describe("FEATURED_PLANT_IDS", () => {
it("is an array of plant IDs", () => {
expect(Array.isArray(FEATURED_PLANT_IDS)).toBe(true);
expect(FEATURED_PLANT_IDS.length).toBeGreaterThan(0);
});
it("includes expected featured plants", () => {
expect(FEATURED_PLANT_IDS).toContain("tomato");
expect(FEATURED_PLANT_IDS).toContain("basil");
});
});
describe("TRUST_SIGNALS", () => {
it("is an array of trust signals", () => {
expect(Array.isArray(TRUST_SIGNALS)).toBe(true);
expect(TRUST_SIGNALS.length).toBeGreaterThan(0);
});
it("each signal has icon and label", () => {
for (const signal of TRUST_SIGNALS) {
expect(signal).toHaveProperty("icon");
expect(signal).toHaveProperty("label");
}
});
});
describe("HOW_IT_WORKS", () => {
it("is an array of steps", () => {
expect(Array.isArray(HOW_IT_WORKS)).toBe(true);
expect(HOW_IT_WORKS.length).toBeGreaterThan(0);
});
it("each step has step number, emoji, title, and description", () => {
for (const step of HOW_IT_WORKS) {
expect(step).toHaveProperty("step");
expect(step).toHaveProperty("emoji");
expect(step).toHaveProperty("title");
expect(step).toHaveProperty("description");
}
});
it("has exactly 3 steps", () => {
expect(HOW_IT_WORKS.length).toBe(3);
});
it("steps are numbered sequentially", () => {
expect(HOW_IT_WORKS[0].step).toBe(1);
expect(HOW_IT_WORKS[1].step).toBe(2);
expect(HOW_IT_WORKS[2].step).toBe(3);
});
});
describe("BETA_DISCLAIMER", () => {
it("is a non-empty string", () => {
expect(typeof BETA_DISCLAIMER).toBe("string");
expect(BETA_DISCLAIMER.length).toBeGreaterThan(0);
});
it("mentions AI-assisted tool", () => {
expect(BETA_DISCLAIMER.toLowerCase()).toContain("ai-assisted");
});
it("mentions professional advice disclaimer", () => {
expect(BETA_DISCLAIMER.toLowerCase()).toMatch(/not a substitute|professional/i);
});
it("mentions plant pathologist", () => {
expect(BETA_DISCLAIMER.toLowerCase()).toContain("plant pathologist");
});
});
});

77
src/lib/constants.ts Normal file
View File

@@ -0,0 +1,77 @@
/**
* Site-wide constants for the Plant Disease Identifier application.
*/
export const APP_NAME = "Plant Health ID";
export const APP_TAGLINE = "Snap. Identify. Treat.";
export const APP_DESCRIPTION =
"Upload a plant photo for hyper-specific disease diagnosis with confidence scores, symptoms, causes, treatment steps, and prevention tips.";
export const SOCIAL_LINKS = {
github: "https://github.com/plant-health-id",
twitter: "https://twitter.com/planthealthid",
} as const;
export const NAV_LINKS = [
{ href: "/", label: "Home" },
{ href: "/upload", label: "Identify" },
{ href: "/browse", label: "Browse Plants" },
{ href: "/about", label: "About" },
] as const;
export const PLANT_CATEGORIES = [
{ value: "all", label: "All" },
{ value: "vegetable", label: "Vegetables" },
{ value: "herb", label: "Herbs" },
{ value: "houseplant", label: "Houseplants" },
{ value: "flower", label: "Flowers" },
{ value: "succulent", label: "Succulents" },
{ value: "fruit", label: "Fruits" },
] as const;
export const FEATURED_PLANT_IDS = [
"tomato",
"basil",
"rose",
"monstera",
"snake-plant",
"pepper",
"apple",
"corn",
"wheat",
"strawberry",
"blueberry",
"lettuce",
] as const;
export const TRUST_SIGNALS = [
{ icon: "📸", label: "Trained on 500K+ images" },
{ icon: "🌿", label: "Covers 300+ plants with 10K+ diseases" },
{ icon: "🔓", label: "Open source" },
] as const;
export const HOW_IT_WORKS = [
{
step: 1,
emoji: "📸",
title: "Upload a Photo",
description: "Snap a picture of the affected plant leaf or fruit with your phone or camera.",
},
{
step: 2,
emoji: "🧠",
title: "AI Analysis",
description:
"Our model analyzes the image against 500K+ labeled plant disease images in seconds.",
},
{
step: 3,
emoji: "🌱",
title: "Get Treatment Plan",
description:
"Receive a detailed diagnosis with confidence score, symptoms, causes, and step-by-step treatment.",
},
] as const;
export const BETA_DISCLAIMER =
"Plant Health ID is an AI-assisted tool for informational purposes only. It is not a substitute for professional agricultural or horticultural advice. Always consult a certified plant pathologist or extension service for critical plant health decisions.";

372
src/lib/db.ts Normal file
View File

@@ -0,0 +1,372 @@
/**
* Turso/libSQL Database Client
*
* Provides the database client and schema management for the plant disease
* knowledge base. Connect to Turso/libSQL using environment variables.
*
* Required env vars:
* DATABASE_URL — Turso database URL (e.g., libsql://my-db.turso.io)
* DATABASE_TOKEN — Turso authentication token
*/
import { createClient, type InValue } from "@libsql/client";
import type { Plant, Disease, CausalAgentType, Prevalence, Severity } from "./types";
// ─── Client ──────────────────────────────────────────────────────────────────
let client: ReturnType<typeof createClient> | null = null;
let connected = false;
/** Get or create a singleton database client */
export function getDb() {
if (client) return client;
const url = process.env.DATABASE_URL;
const token = process.env.DATABASE_TOKEN;
if (!url) {
throw new Error(
"DATABASE_URL is not set. Check your .env.development or .env.production file.",
);
}
if (!token) {
throw new Error(
"DATABASE_TOKEN is not set. Check your .env.development or .env.production file.",
);
}
client = createClient({ url, authToken: token });
return client;
}
/** Check database connectivity */
export async function checkConnection(): Promise<boolean> {
try {
const db = getDb();
await db.execute("SELECT 1 AS ok");
connected = true;
return true;
} catch (err) {
connected = false;
console.error("[DB] Connection failed:", err);
return false;
}
}
export function isConnected() {
return connected;
}
// ─── Schema ───────────────────────────────────────────────────────────────────
/** SQL to create the plants table */
const PLANTS_TABLE_SQL = `
CREATE TABLE IF NOT EXISTS plants (
id TEXT PRIMARY KEY,
common_name TEXT NOT NULL,
scientific_name TEXT NOT NULL,
family TEXT NOT NULL,
category TEXT NOT NULL,
care_summary TEXT NOT NULL DEFAULT '',
image_url TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_plants_category ON plants(category);
CREATE INDEX IF NOT EXISTS idx_plants_common_name ON plants(common_name);
`;
/** SQL to create the diseases table */
const DISEASES_TABLE_SQL = `
CREATE TABLE IF NOT EXISTS diseases (
id TEXT PRIMARY KEY,
plant_id TEXT NOT NULL REFERENCES plants(id),
name TEXT NOT NULL,
scientific_name TEXT NOT NULL DEFAULT '',
causal_agent_type TEXT NOT NULL CHECK (causal_agent_type IN ('fungal','bacterial','viral','environmental')),
description TEXT NOT NULL DEFAULT '',
symptoms TEXT NOT NULL DEFAULT '[]',
causes TEXT NOT NULL DEFAULT '[]',
treatment TEXT NOT NULL DEFAULT '[]',
prevention TEXT NOT NULL DEFAULT '[]',
lookalike_ids TEXT NOT NULL DEFAULT '[]',
severity TEXT NOT NULL CHECK (severity IN ('low','moderate','high','critical')),
source_url TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_diseases_plant_id ON diseases(plant_id);
CREATE INDEX IF NOT EXISTS idx_diseases_causal_agent ON diseases(causal_agent_type);
CREATE INDEX IF NOT EXISTS idx_diseases_severity ON diseases(severity);
-- Full-text search virtual table for diseases
CREATE VIRTUAL TABLE IF NOT EXISTS diseases_fts USING fts5(
name,
scientific_name,
description,
symptoms_text,
content='diseases',
content_rowid='rowid'
);
`;
/** SQL to create the scrape_log table for tracking source freshness */
const SCRAPE_LOG_SQL = `
CREATE TABLE IF NOT EXISTS scrape_sources (
id TEXT PRIMARY KEY,
source_type TEXT NOT NULL CHECK (source_type IN ('wikipedia','university_extension','cabi','other')),
source_url TEXT NOT NULL,
last_scraped_at TEXT,
entries_count INTEGER DEFAULT 0,
status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending','success','error')),
error_message TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
`;
/** Run all schema migrations */
export async function runSchema() {
const db = getDb();
console.log("[DB] Running schema migrations...");
await db.execute(PLANTS_TABLE_SQL);
console.log("[DB] ✓ plants table");
await db.execute(DISEASES_TABLE_SQL);
console.log("[DB] ✓ diseases table");
await db.execute(SCRAPE_LOG_SQL);
console.log("[DB] ✓ scrape_sources table");
console.log("[DB] Schema up to date.");
}
// ─── Row ↔ Type mappers ─────────────────────────────────────────────────────
/** Convert a database row to a Plant object */
export function rowToPlant(row: Record<string, unknown>): Plant {
return {
id: row.id as string,
commonName: row.common_name as string,
scientificName: row.scientific_name as string,
family: row.family as string,
category: row.category as Plant["category"],
careSummary: row.care_summary as string,
imageUrl: row.image_url as string,
};
}
/** Convert a database row to a Disease object */
export function rowToDisease(row: Record<string, unknown>): Disease {
return {
id: row.id as string,
plantId: row.plant_id as string,
name: row.name as string,
scientificName: row.scientific_name as string,
causalAgentType: row.causal_agent_type as CausalAgentType,
description: row.description as string,
symptoms: JSON.parse(row.symptoms as string) as string[],
causes: JSON.parse(row.causes as string) as string[],
treatment: JSON.parse(row.treatment as string) as string[],
prevention: JSON.parse(row.prevention as string) as string[],
lookalikeDiseaseIds: JSON.parse(row.lookalike_ids as string) as string[],
severity: row.severity as Severity,
prevalence: (row.prevalence as Prevalence) ?? "uncommon",
};
}
/** Convert a Plant object to database column values */
export function plantToRow(plant: Plant): Record<string, InValue> {
return {
id: plant.id,
common_name: plant.commonName,
scientific_name: plant.scientificName,
family: plant.family,
category: plant.category,
care_summary: plant.careSummary,
image_url: plant.imageUrl,
};
}
/** Convert a Disease object to database column values */
export function diseaseToRow(disease: Disease & { sourceUrl?: string }): Record<string, InValue> {
return {
id: disease.id,
plant_id: disease.plantId,
name: disease.name,
scientific_name: disease.scientificName,
causal_agent_type: disease.causalAgentType,
description: disease.description,
symptoms: JSON.stringify(disease.symptoms),
causes: JSON.stringify(disease.causes),
treatment: JSON.stringify(disease.treatment),
prevention: JSON.stringify(disease.prevention),
lookalike_ids: JSON.stringify(disease.lookalikeDiseaseIds),
severity: disease.severity,
source_url: disease.sourceUrl ?? "",
};
}
// ─── Query helpers ───────────────────────────────────────────────────────────
/** Insert or replace a plant */
export async function upsertPlant(plant: Plant) {
const db = getDb();
const row = plantToRow(plant);
const keys = Object.keys(row);
const columns = keys.join(", ");
const placeholders = keys.map(() => "?").join(", ");
const values = keys.map((k) => row[k]);
await db.execute(`INSERT OR REPLACE INTO plants (${columns}) VALUES (${placeholders})`, values);
}
/** Insert or replace a disease */
export async function upsertDisease(disease: Disease & { sourceUrl?: string }) {
const db = getDb();
const row = diseaseToRow(disease);
const keys = Object.keys(row);
const columns = keys.join(", ");
const placeholders = keys.map(() => "?").join(", ");
const values = keys.map((k) => row[k]);
await db.execute(`INSERT OR REPLACE INTO diseases (${columns}) VALUES (${placeholders})`, values);
}
/** Bulk insert plants in a transaction */
export async function bulkUpsertPlants(plants: Plant[]) {
const db = getDb();
const tx = await db.transaction("write");
try {
for (const plant of plants) {
const row = plantToRow(plant);
const keys = Object.keys(row);
const columns = keys.join(", ");
const placeholders = keys.map(() => "?").join(", ");
const values = keys.map((k) => row[k]);
await tx.execute({
sql: `INSERT OR REPLACE INTO plants (${columns}) VALUES (${placeholders})`,
args: values,
});
}
await tx.commit();
console.log(`[DB] Inserted/replaced ${plants.length} plants`);
} catch (err) {
await tx.rollback();
throw err;
}
}
/** Bulk insert diseases in a transaction */
export async function bulkUpsertDiseases(diseases: Array<Disease & { sourceUrl?: string }>) {
const db = getDb();
const tx = await db.transaction("write");
try {
for (const disease of diseases) {
const row = diseaseToRow(disease);
const keys = Object.keys(row);
const columns = keys.join(", ");
const placeholders = keys.map(() => "?").join(", ");
const values = keys.map((k) => row[k]);
await tx.execute({
sql: `INSERT OR REPLACE INTO diseases (${columns}) VALUES (${placeholders})`,
args: values,
});
}
await tx.commit();
console.log(`[DB] Inserted/replaced ${diseases.length} diseases`);
} catch (err) {
await tx.rollback();
throw err;
}
}
/** Get all plants */
export async function getAllPlants(): Promise<Plant[]> {
const db = getDb();
const result = await db.execute("SELECT * FROM plants ORDER BY common_name");
return result.rows.map((r) => rowToPlant(r as Record<string, unknown>));
}
/** Get all diseases (optionally filtered by plant_id) */
export async function getDiseases(plantId?: string): Promise<Disease[]> {
const db = getDb();
let sql = "SELECT * FROM diseases";
const params: InValue[] = [];
if (plantId) {
sql += " WHERE plant_id = ?";
params.push(plantId);
}
sql += " ORDER BY name";
const result = await db.execute(sql, params);
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
}
/** Get a single plant by ID */
export async function getPlantById(plantId: string): Promise<Plant | null> {
const db = getDb();
const result = await db.execute("SELECT * FROM plants WHERE id = ?", [plantId]);
if (result.rows.length === 0) return null;
return rowToPlant(result.rows[0] as Record<string, unknown>);
}
/** Get a single disease by ID */
export async function getDiseaseById(diseaseId: string): Promise<Disease | null> {
const db = getDb();
const result = await db.execute("SELECT * FROM diseases WHERE id = ?", [diseaseId]);
if (result.rows.length === 0) return null;
return rowToDisease(result.rows[0] as Record<string, unknown>);
}
/** Search diseases via FTS */
export async function searchDiseasesFts(searchTerm: string): Promise<Disease[]> {
const db = getDb();
try {
const result = await db.execute(
`SELECT d.* FROM diseases d
JOIN diseases_fts fts ON d.rowid = fts.rowid
WHERE diseases_fts MATCH ?
ORDER BY rank
LIMIT 50`,
[searchTerm],
);
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
} catch {
// FTS might not be populated yet; fall back to LIKE search
const likeTerm = `%${searchTerm}%`;
const result = await db.execute(
`SELECT * FROM diseases
WHERE name LIKE ? OR description LIKE ? OR scientific_name LIKE ?
LIMIT 50`,
[likeTerm, likeTerm, likeTerm],
);
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
}
}
/** Get database stats */
export async function getDbStats(): Promise<{
plants: number;
diseases: number;
byType: Record<string, number>;
bySeverity: Record<string, number>;
}> {
const db = getDb();
const plantCount = await db.execute("SELECT COUNT(*) as cnt FROM plants");
const diseaseCount = await db.execute("SELECT COUNT(*) as cnt FROM diseases");
const byType = await db.execute(
"SELECT causal_agent_type, COUNT(*) as cnt FROM diseases GROUP BY causal_agent_type",
);
const bySeverity = await db.execute(
"SELECT severity, COUNT(*) as cnt FROM diseases GROUP BY severity",
);
return {
plants: plantCount.rows[0].cnt as number,
diseases: diseaseCount.rows[0].cnt as number,
byType: Object.fromEntries(byType.rows.map((r) => [r.causal_agent_type, r.cnt as number])),
bySeverity: Object.fromEntries(bySeverity.rows.map((r) => [r.severity, r.cnt as number])),
};
}

69
src/lib/db/index.ts Normal file
View File

@@ -0,0 +1,69 @@
/**
* Drizzle ORM Database Client for Turso/libSQL.
*
* Provides the configured drizzle instance and convenience helpers.
* Reads DATABASE_URL and DATABASE_TOKEN from environment.
*/
import { sql } from "drizzle-orm";
import { drizzle, type LibSQLDatabase } from "drizzle-orm/libsql";
import { createClient } from "@libsql/client";
import * as schema from "./schema";
export type {
PlantRow,
PlantInsert,
DiseaseRow,
DiseaseInsert,
FlaggedContentRow,
FlaggedContentInsert,
} from "./schema";
export { schema };
let _db: LibSQLDatabase<typeof schema> | null = null;
let _client: ReturnType<typeof createClient> | null = null;
/** Get or create the Drizzle database instance (singleton). */
export function getDb(): LibSQLDatabase<typeof schema> {
if (_db) return _db;
const url = process.env.DATABASE_URL;
const token = process.env.DATABASE_TOKEN;
if (!url) {
throw new Error(
"DATABASE_URL is not set. Check your .env.development or .env.production file.",
);
}
if (!token) {
throw new Error(
"DATABASE_TOKEN is not set. Check your .env.development or .env.production file.",
);
}
_client = createClient({ url, authToken: token });
_db = drizzle(_client, { schema });
return _db;
}
/** Check database connectivity. */
export async function checkConnection(): Promise<boolean> {
try {
const db = getDb();
const result = await db.run(sql`SELECT 1 AS ok`);
return result.rowsAffected >= 0;
} catch (err) {
console.error("[DB] Connection failed:", err);
return false;
}
}
/** Close the client connection. */
export function closeDb() {
if (_client) {
_client.close();
_client = null;
_db = null;
}
}

173
src/lib/db/schema.ts Normal file
View File

@@ -0,0 +1,173 @@
/**
* Drizzle ORM Schema for the Plant Disease Knowledge Base.
*
* Uses Turso (libSQL) with SQLite dialect.
* Arrays (symptoms, causes, treatment, prevention, lookalike_ids)
* are stored as JSON text columns and typed via Drizzle's $type().
*/
import { sql } from "drizzle-orm";
import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core";
// ─── Plants Table ────────────────────────────────────────────────────────────
export const plants = sqliteTable(
"plants",
{
id: text("id").primaryKey(),
commonName: text("common_name").notNull(),
scientificName: text("scientific_name").notNull(),
family: text("family").notNull(),
category: text("category").notNull(),
careSummary: text("care_summary").notNull().default(""),
imageUrl: text("image_url").notNull().default(""),
createdAt: text("created_at")
.notNull()
.default(sql`(datetime('now'))`),
updatedAt: text("updated_at")
.notNull()
.default(sql`(datetime('now'))`),
},
(table) => ({
categoryIdx: index("idx_plants_category").on(table.category),
commonNameIdx: index("idx_plants_common_name").on(table.commonName),
}),
);
// ─── Diseases Table ──────────────────────────────────────────────────────────
export const diseases = sqliteTable(
"diseases",
{
id: text("id").primaryKey(),
plantId: text("plant_id")
.notNull()
.references(() => plants.id),
name: text("name").notNull(),
scientificName: text("scientific_name").notNull().default(""),
causalAgentType: text("causal_agent_type", {
enum: ["fungal", "bacterial", "viral", "environmental"],
}).notNull(),
description: text("description").notNull().default(""),
symptoms: text("symptoms", { mode: "json" }).notNull().default([]).$type<string[]>(),
causes: text("causes", { mode: "json" }).notNull().default([]).$type<string[]>(),
treatment: text("treatment", { mode: "json" }).notNull().default([]).$type<string[]>(),
prevention: text("prevention", { mode: "json" }).notNull().default([]).$type<string[]>(),
lookalikeIds: text("lookalike_ids", { mode: "json" }).notNull().default([]).$type<string[]>(),
prevalence: text("prevalence", {
enum: ["common", "uncommon", "rare", "very_rare"],
})
.notNull()
.default("uncommon"),
prevalenceScore: integer("prevalence_score").notNull().default(0),
severity: text("severity", {
enum: ["low", "moderate", "high", "critical"],
}).notNull(),
imageUrl: text("image_url").notNull().default(""),
sourceUrl: text("source_url").notNull().default(""),
createdAt: text("created_at")
.notNull()
.default(sql`(datetime('now'))`),
updatedAt: text("updated_at")
.notNull()
.default(sql`(datetime('now'))`),
},
(table) => ({
plantIdIdx: index("idx_diseases_plant_id").on(table.plantId),
causalAgentIdx: index("idx_diseases_causal_agent").on(table.causalAgentType),
severityIdx: index("idx_diseases_severity").on(table.severity),
prevalenceIdx: index("idx_diseases_prevalence").on(table.prevalence),
}),
);
// ─── Scrape Sources Table ────────────────────────────────────────────────────
export const scrapeSources = sqliteTable("scrape_sources", {
id: text("id").primaryKey(),
sourceType: text("source_type", {
enum: ["wikipedia", "university_extension", "cabi", "other"],
}).notNull(),
sourceUrl: text("source_url").notNull(),
lastScrapedAt: text("last_scraped_at"),
entriesCount: integer("entries_count").default(0),
status: text("status", { enum: ["pending", "success", "error"] })
.notNull()
.default("pending"),
errorMessage: text("error_message"),
createdAt: text("created_at")
.notNull()
.default(sql`(datetime('now'))`),
});
// ─── Plant Views Table ───────────────────────────────────────────────────────
export const plantViews = sqliteTable(
"plant_views",
{
plantId: text("plant_id")
.primaryKey()
.references(() => plants.id),
viewCount: integer("view_count").notNull().default(0),
},
(table) => ({
viewCountIdx: index("idx_plant_views_count").on(table.viewCount),
}),
);
// ─── Flagged Content Table ─────────────────────────────────────────────────
/**
* Stores user-flagged content for manual review.
* content_type: what kind of content is flagged
* content_id: the ID of the plant or disease
* field_name: specific field being flagged (e.g., "image", "symptoms", "causes", "treatment", "prevention")
* flag_count: number of times this item has been flagged
*/
export const flaggedContent = sqliteTable(
"flagged_content",
{
id: text("id").primaryKey(),
contentType: text("content_type", {
enum: [
"plant_image",
"disease_image",
"disease_description",
"disease_symptoms",
"disease_causes",
"disease_treatment",
"disease_prevention",
],
}).notNull(),
contentId: text("content_id").notNull(),
fieldName: text("field_name").notNull(),
notes: text("notes").default(""),
flagCount: integer("flag_count").notNull().default(1),
createdAt: text("created_at")
.notNull()
.default(sql`(datetime('now'))`),
updatedAt: text("updated_at")
.notNull()
.default(sql`(datetime('now'))`),
},
(table) => ({
contentTypeIdx: index("idx_flagged_content_type").on(table.contentType),
contentIdIdx: index("idx_flagged_content_id").on(table.contentId),
}),
);
// ─── Type helpers ────────────────────────────────────────────────────────────
export type FlaggedContentRow = typeof flaggedContent.$inferSelect;
export type FlaggedContentInsert = typeof flaggedContent.$inferInsert;
// ─── Relation Inference ──────────────────────────────────────────────────────
export const plantsRelations = {};
export const diseasesRelations = {};
// ─── Type helpers ────────────────────────────────────────────────────────────
export type PlantRow = typeof plants.$inferSelect;
export type PlantInsert = typeof plants.$inferInsert;
export type DiseaseRow = typeof diseases.$inferSelect;
export type DiseaseInsert = typeof diseases.$inferInsert;

View File

@@ -0,0 +1,47 @@
/**
* Display helpers for the browse UI that bridge the DB types
* to display-friendly values (emoji icons, descriptions).
*/
const CATEGORY_EMOJIS: Record<string, string> = {
vegetable: "🥬",
fruit: "🍎",
herb: "🌿",
flower: "🌸",
houseplant: "🪴",
succulent: "🌵",
tree: "🌳",
};
const FALLBACK_EMOJI = "🌱";
export function getEmojiForCategory(category: string): string {
return CATEGORY_EMOJIS[category] ?? FALLBACK_EMOJI;
}
export function getPlantDescription(
commonName: string,
scientificName: string,
category: string,
family: string,
): string {
return `${commonName} (${scientificName}) is a ${category} in the ${family} family. Preventative care and early identification of diseases are key to keeping your ${commonName.toLowerCase()} healthy.`;
}
export function getDescriptionForCategory(category: string): string {
const descriptions: Record<string, string> = {
vegetable:
"Vegetables are garden favorites grown for their edible parts. They can be affected by various fungal, bacterial, and viral diseases that impact yield and quality.",
fruit:
"Fruit plants produce delicious harvests but require attention to disease management for optimal production.",
herb: "Herbs are aromatic plants used in cooking and medicine. Most are relatively disease-resistant but can be affected in humid conditions.",
flower:
"Ornamental flowers add beauty to gardens. They may be susceptible to various foliar and root diseases.",
houseplant:
"Houseplants bring nature indoors. The most common issues are overwatering, insufficient light, and fungal leaf spots.",
succulent:
"Succulents store water in their leaves and stems. Overwatering is the most common cause of problems.",
tree: "Trees provide shade, fruit, and beauty. They can be affected by cankers, rots, wilts, and foliar diseases.",
};
return descriptions[category] ?? `This plant belongs to the ${category} category.`;
}

View File

@@ -0,0 +1,241 @@
/**
* Unit tests for lib/image-processing.ts
*
* Tests:
* - resizeImage() produces 224×224 output for any input aspect ratio
* - imageToTensor() output length equals 3 * 224 * 224
* - Normalization produces values in expected range
* - validateImageFile rejects invalid types and oversized files
* - tensorToBase64 / base64ToTensor round-trip
*/
import { describe, it, expect, vi, beforeEach } from "vitest";
import {
resizeImage,
imageToTensor,
tensorToBase64,
base64ToTensor,
getTensorShape,
validateImageFile,
MAX_FILE_SIZE,
MIN_DIMENSION,
ALLOWED_MIME_TYPES,
} from "@/lib/image-processing";
// ─── Helpers ─────────────────────────────────────────────────────────────────
/** Create a mock ImageData at given dimensions */
function createMockImageData(
width: number,
height: number,
fillR = 128,
fillG = 64,
fillB = 32,
): ImageData {
const data = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < width * height; i++) {
data[i * 4] = fillR;
data[i * 4 + 1] = fillG;
data[i * 4 + 2] = fillB;
data[i * 4 + 3] = 255;
}
return { width, height, data };
}
/** Create a mock File with given properties */
function createMockFile({
name = "test.jpg",
type = "image/jpeg",
size = 1024,
}: Partial<File> & Pick<File, "name" | "type" | "size"> = {}): File {
// Use ArrayBuffer to control actual file size for large/empty tests
let content: BlobPart;
if (size === 0) {
content = "";
} else if (size > 1024) {
// For large files, create a buffer of the right size
content = new Uint8Array(size);
} else {
content = "dummy";
}
return new File([content], name, { type });
}
// ─── Tests ───────────────────────────────────────────────────────────────────
describe("validateImageFile", () => {
it("accepts valid JPEG", () => {
const file = createMockFile({ name: "photo.jpg", type: "image/jpeg", size: 1024 });
const result = validateImageFile(file);
expect(result.ok).toBe(true);
});
it("accepts valid PNG", () => {
const file = createMockFile({ name: "photo.png", type: "image/png", size: 1024 });
const result = validateImageFile(file);
expect(result.ok).toBe(true);
});
it("accepts valid WebP", () => {
const file = createMockFile({ name: "photo.webp", type: "image/webp", size: 1024 });
const result = validateImageFile(file);
expect(result.ok).toBe(true);
});
it("rejects unsupported MIME type", () => {
const file = createMockFile({ name: "document.txt", type: "text/plain", size: 1024 });
const result = validateImageFile(file);
expect(result.ok).toBe(false);
expect(result.error).toContain("Unsupported file type");
});
it("rejects files larger than 10 MB", () => {
const file = createMockFile({
name: "huge.jpg",
type: "image/jpeg",
size: 11 * 1024 * 1024, // 11 MB
});
const result = validateImageFile(file);
expect(result.ok).toBe(false);
expect(result.error).toContain("too large");
});
it("rejects empty files", () => {
const file = createMockFile({ name: "empty.jpg", type: "image/jpeg", size: 0 });
const result = validateImageFile(file);
expect(result.ok).toBe(false);
expect(result.error).toContain("empty");
});
});
describe("imageToTensor", () => {
it("produces correct tensor length for 224×224", () => {
const imageData = createMockImageData(224, 224);
const tensor = imageToTensor(imageData);
expect(tensor.length).toBe(3 * 224 * 224);
});
it("produces correct tensor length for 299×299", () => {
const imageData = createMockImageData(299, 299);
const tensor = imageToTensor(imageData);
expect(tensor.length).toBe(3 * 299 * 299);
});
it("produces Float32Array", () => {
const imageData = createMockImageData(224, 224);
const tensor = imageToTensor(imageData);
expect(tensor).toBeInstanceOf(Float32Array);
});
it("normalizes pixel values with ImageNet mean/std", () => {
// All pixels set to 128/255 = 0.502
const imageData = createMockImageData(224, 224, 128, 128, 128);
const tensor = imageToTensor(imageData);
// With ImageNet mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]
// R: (0.502 - 0.485) / 0.229 ≈ 0.074
const expectedR = (128 / 255 - 0.485) / 0.229;
expect(tensor[0]).toBeCloseTo(expectedR, 3);
// G: (0.502 - 0.456) / 0.224 ≈ 0.205
const totalPixels = 224 * 224;
const expectedG = (128 / 255 - 0.456) / 0.224;
expect(tensor[totalPixels]).toBeCloseTo(expectedG, 3);
// B: (0.502 - 0.406) / 0.225 ≈ 0.427
const expectedB = (128 / 255 - 0.406) / 0.225;
expect(tensor[2 * totalPixels]).toBeCloseTo(expectedB, 3);
});
it("preserves channel separation (R, G, B in separate channels)", () => {
// R=255, G=0, B=0 (pure red)
const imageData = createMockImageData(224, 224, 255, 0, 0);
const tensor = imageToTensor(imageData);
const totalPixels = 224 * 224;
// After ImageNet normalization:
// R: (1.0 - 0.485) / 0.229 ≈ 2.25
// G: (0.0 - 0.456) / 0.224 ≈ -2.04
// B: (0.0 - 0.406) / 0.225 ≈ -1.80
const rVal = tensor[0];
const gVal = tensor[totalPixels];
const bVal = tensor[2 * totalPixels];
// R is positive (high), G and B are negative (low)
expect(rVal).toBeGreaterThan(2);
expect(gVal).toBeLessThan(-1);
expect(bVal).toBeLessThan(-1);
// R is highest
expect(rVal).toBeGreaterThan(bVal);
expect(rVal).toBeGreaterThan(gVal);
});
});
describe("tensorToBase64 / base64ToTensor", () => {
it("round-trips tensor data correctly", () => {
const imageData = createMockImageData(160, 160, 100, 150, 200);
const original = imageToTensor(imageData);
const base64 = tensorToBase64(original);
const decoded = base64ToTensor(base64);
expect(decoded.tensor.length).toBe(original.length);
expect(decoded.shape).toEqual([3, 160, 160]);
// Check a few values match
for (let i = 0; i < 10; i++) {
expect(decoded.tensor[i]).toBeCloseTo(original[i], 5);
}
});
it("preserves custom shape", () => {
const tensor = new Float32Array(3 * 299 * 299);
const base64 = tensorToBase64(tensor, [3, 299, 299]);
const decoded = base64ToTensor(base64);
expect(decoded.shape).toEqual([3, 299, 299]);
});
});
describe("getTensorShape", () => {
it("returns [1, 3, 160, 160] by default", () => {
const shape = getTensorShape();
expect(shape).toEqual([1, 3, 160, 160]);
});
it("returns NCHW layout", () => {
const shape = getTensorShape();
expect(shape.length).toBe(4);
expect(shape[0]).toBe(1); // batch
expect(shape[1]).toBe(3); // channels
expect(shape[2]).toBe(160); // height (model input size)
expect(shape[3]).toBe(160); // width (model input size)
});
});
describe("resizeImage", () => {
it("is an async function", () => {
expect(typeof resizeImage).toBe("function");
});
it("accepts a File and size parameter", () => {
const file = createMockFile({ name: "test.jpg", type: "image/jpeg", size: 1024 });
// In jsdom, this will use the mock canvas — we verify the function signature
expect(resizeImage).toBeDefined();
});
});
describe("constants", () => {
it("MAX_FILE_SIZE is 10 MB", () => {
expect(MAX_FILE_SIZE).toBe(10 * 1024 * 1024);
});
it("MIN_DIMENSION is 150", () => {
expect(MIN_DIMENSION).toBe(150);
});
it("ALLOWED_MIME_TYPES includes PNG, JPEG, and WebP", () => {
expect(ALLOWED_MIME_TYPES).toContain("image/png");
expect(ALLOWED_MIME_TYPES).toContain("image/jpeg");
expect(ALLOWED_MIME_TYPES).toContain("image/webp");
});
});

244
src/lib/image-processing.ts Normal file
View File

@@ -0,0 +1,244 @@
/**
* Client-side image preprocessing pipeline.
*
* Resizes images to model-expected dimensions (160×160 by default),
* converts RGBA → RGB, normalizes pixel values, and produces flat
* Float32Array tensors ready for ML inference or base64 transmission.
*
* Tensor shape: [1, 3, 160, 160] — NCHW layout matching MobileNetV2.
*
* Configurable via env:
* IMAGE_MODEL_SIZE — target dimension (default 160)
* IMAGE_MEAN_R/G/B — per-channel mean for normalization (default 0.485, 0.456, 0.406 — ImageNet)
* IMAGE_STD_R/G/B — per-channel std for normalization (default 0.229, 0.224, 0.225 — ImageNet)
*/
// ─── Configuration ───────────────────────────────────────────────────────────
const DEFAULT_MODEL_SIZE = 160;
const DEFAULT_MEAN = [0.485, 0.456, 0.406] as const; // ImageNet RGB means
const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds
function getConfig(): {
size: number;
mean: readonly [number, number, number];
std: readonly [number, number, number];
} {
// These env vars are exposed via next.config.ts / .env.local
const size = parseInt(
typeof process !== "undefined" && process.env?.IMAGE_MODEL_SIZE
? process.env.IMAGE_MODEL_SIZE
: String(DEFAULT_MODEL_SIZE),
10,
);
return {
size: isNaN(size) ? DEFAULT_MODEL_SIZE : size,
mean: DEFAULT_MEAN,
std: DEFAULT_STD,
};
}
// ─── Constants ───────────────────────────────────────────────────────────────
/** Maximum file size accepted (10 MB) */
export const MAX_FILE_SIZE = 10 * 1024 * 1024;
/** Minimum image dimensions (150×150) */
export const MIN_DIMENSION = 150;
/** Allowed MIME types */
export const ALLOWED_MIME_TYPES = ["image/png", "image/jpeg", "image/jpg", "image/webp"] as const;
export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number];
/** Maximum number of ephemeral uploads to keep */
export const MAX_UPLOADS = 100;
// ─── Validation ──────────────────────────────────────────────────────────────
/**
* Validate that a file is an acceptable image for upload.
* Returns `{ ok: true }` or `{ ok: false, error: string }`.
*/
export function validateImageFile(file: File): { ok: true } | { ok: false; error: string } {
// MIME type check
if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) {
return {
ok: false,
error: `Unsupported file type "${file.type}". Allowed: PNG, JPG, WebP.`,
};
}
// Size check
if (file.size > MAX_FILE_SIZE) {
const mb = (file.size / (1024 * 1024)).toFixed(1);
return { ok: false, error: `File too large (${mb} MB). Maximum is 10 MB.` };
}
// Zero-size check
if (file.size === 0) {
return { ok: false, error: "File is empty." };
}
return { ok: true };
}
/**
* Check that an image file meets minimum dimension requirements.
* Returns a promise resolving to `{ ok: true }` or `{ ok: false, error: string }`.
*/
export function validateImageDimensions(
file: File,
): Promise<{ ok: true } | { ok: false; error: string }> {
return new Promise((resolve) => {
const img = new Image();
img.onload = () => {
if (img.width < MIN_DIMENSION || img.height < MIN_DIMENSION) {
resolve({
ok: false,
error: `Image too small (${img.width}×${img.height}). Minimum is ${MIN_DIMENSION}×${MIN_DIMENSION}.`,
});
} else {
resolve({ ok: true });
}
};
img.onerror = () => {
resolve({ ok: false, error: "Failed to read image dimensions." });
};
img.src = URL.createObjectURL(file);
});
}
// ─── Resize ──────────────────────────────────────────────────────────────────
/**
* Resize an image file to the target model size using an offscreen canvas.
* Uses bilinear interpolation via canvas drawing.
*
* @param file - Source image file
* @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224.
* @returns ImageData at exactly `size × size`
*/
export async function resizeImage(file: File, size: number = getConfig().size): Promise<ImageData> {
return new Promise((resolve, reject) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement("canvas");
canvas.width = size;
canvas.height = size;
const ctx = canvas.getContext("2d");
if (!ctx) {
reject(new Error("Could not get canvas 2D context."));
return;
}
// Bilinear resize via drawImage
ctx.imageSmoothingEnabled = true;
ctx.imageSmoothingQuality = "high";
ctx.drawImage(img, 0, 0, size, size);
const imageData = ctx.getImageData(0, 0, size, size);
URL.revokeObjectURL(img.src);
resolve(imageData);
};
img.onerror = () => {
URL.revokeObjectURL(img.src);
reject(new Error("Failed to load image for resizing."));
};
img.src = URL.createObjectURL(file);
});
}
// ─── Tensor Conversion ───────────────────────────────────────────────────────
/**
* Convert ImageData (RGBA) to a flat Float32Array tensor in RGB layout.
* Drops the alpha channel, normalizes pixel values to [0, 1].
*
* Output layout: flat array of length 3 × width × height.
* Channel order: RRR...GGG...BBB... (channel-first, like PyTorch NCHW without batch dim).
*
* @param imageData - Source ImageData from resizeImage()
* @returns Float32Array of length 3 × size × size with values in [0, 1]
*/
export function imageToTensor(imageData: ImageData): Float32Array {
const { width, height, data } = imageData;
const totalPixels = width * height;
const config = getConfig();
const { mean, std } = config;
// Allocate channel-first tensor: [3, H, W]
const tensor = new Float32Array(3 * totalPixels);
// Extract R, G, B channels (skip alpha)
const rChannel = tensor.subarray(0, totalPixels);
const gChannel = tensor.subarray(totalPixels, 2 * totalPixels);
const bChannel = tensor.subarray(2 * totalPixels, 3 * totalPixels);
for (let i = 0; i < totalPixels; i++) {
const idx = i * 4; // RGBA stride
rChannel[i] = data[idx] / 255;
gChannel[i] = data[idx + 1] / 255;
bChannel[i] = data[idx + 2] / 255;
}
// Normalize with ImageNet mean/std
for (let c = 0; c < 3; c++) {
const channel = c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
const m = mean[c];
const s = std[c];
for (let i = 0; i < totalPixels; i++) {
channel[i] = (channel[i] - m) / s;
}
}
return tensor;
}
/**
* Get the expected tensor shape for the current model configuration.
* Returns [batch, channels, height, width] = [1, 3, size, size].
*/
export function getTensorShape(): [number, number, number, number] {
const size = getConfig().size;
return [1, 3, size, size];
}
// ─── Base64 Encoding ─────────────────────────────────────────────────────────
/**
* Encode a Float32Array tensor to a base64 string for transmission.
* Wraps the binary data in a simple JSON envelope with shape metadata.
*
* @param tensor - Flat Float32Array from imageToTensor()
* @param shape - Tensor shape [C, H, W], defaults to [3, size, size]
* @returns base64-encoded JSON string
*/
export function tensorToBase64(
tensor: Float32Array,
shape: [number, number, number] = [3, getConfig().size, getConfig().size],
): string {
const envelope = {
shape,
data: Array.from(tensor),
};
const json = JSON.stringify(envelope);
return btoa(json);
}
/**
* Decode a base64 tensor string back to a Float32Array.
*
* @param base64 - Base64 string from tensorToBase64()
* @returns { tensor, shape }
*/
export function base64ToTensor(base64: string): {
tensor: Float32Array;
shape: [number, number, number];
} {
const json = atob(base64);
const envelope = JSON.parse(json);
return {
tensor: new Float32Array(envelope.data),
shape: envelope.shape as [number, number, number],
};
}

0
src/lib/ml/.gitkeep Normal file
View File

View File

@@ -0,0 +1,342 @@
/**
* Unit tests for lib/ml/confidence.ts
*
* Tests:
* - softmax([1, 2, 3]) sums to ~1.0
* - softmaxFloat32 produces same results as softmax
* - calibrateConfidence(0.9) returns label "high"
* - calibrateConfidence(0.6) returns label "medium"
* - calibrateConfidence(0.3) returns label "low"
* - getTopK returns exactly 5 entries sorted descending
* - getTopKFloat32 returns exactly 5 entries sorted descending
* - filterByConfidence removes predictions below threshold
* - Numerically stable softmax handles large logits
* - Degenerate softmax (all -Infinity) returns uniform distribution
*/
import { describe, it, expect } from "vitest";
import {
softmax,
softmaxFloat32,
calibrateConfidence,
getConfidenceLabel,
getTopK,
getTopKFloat32,
filterByConfidence,
DEFAULT_MIN_CONFIDENCE,
} from "@/lib/ml/confidence";
describe("softmax", () => {
it("softmax([1, 2, 3]) sums to ~1.0", () => {
const result = softmax([1, 2, 3]);
const sum = result.reduce((a, b) => a + b, 0);
expect(sum).toBeCloseTo(1.0, 6);
});
it("produces correct probability distribution", () => {
const result = softmax([1, 2, 3]);
// Higher input → higher probability
expect(result[2]).toBeGreaterThan(result[1]);
expect(result[1]).toBeGreaterThan(result[0]);
// All positive
expect(result.every(v => v > 0)).toBe(true);
});
it("handles equal logits uniformly", () => {
const result = softmax([1, 1, 1]);
expect(result[0]).toBeCloseTo(1 / 3, 6);
expect(result[1]).toBeCloseTo(1 / 3, 6);
expect(result[2]).toBeCloseTo(1 / 3, 6);
});
it("handles single element", () => {
const result = softmax([5]);
expect(result).toEqual([1.0]);
});
it("handles large logits without overflow", () => {
const result = softmax([1000, 1001, 1002]);
const sum = result.reduce((a, b) => a + b, 0);
expect(sum).toBeCloseTo(1.0, 6);
// The largest should dominate
expect(result[2]).toBeGreaterThan(0.5);
});
it("handles negative logits", () => {
const result = softmax([-3, -2, -1]);
const sum = result.reduce((a, b) => a + b, 0);
expect(sum).toBeCloseTo(1.0, 6);
expect(result.every(v => v > 0)).toBe(true);
});
});
describe("softmaxFloat32", () => {
it("produces Float32Array output", () => {
const logits = new Float32Array([1, 2, 3]);
const result = softmaxFloat32(logits);
expect(result).toBeInstanceOf(Float32Array);
});
it("sums to ~1.0", () => {
const logits = new Float32Array([1, 2, 3]);
const result = softmaxFloat32(logits);
const sum = Array.from(result).reduce((a, b) => a + b, 0);
expect(sum).toBeCloseTo(1.0, 5);
});
it("matches softmax for same input", () => {
const input = [1, 2, 3, 4, 5];
const arrayResult = softmax(input);
const float32Result = softmaxFloat32(new Float32Array(input));
for (let i = 0; i < input.length; i++) {
expect(float32Result[i]).toBeCloseTo(arrayResult[i], 5);
}
});
it("handles large arrays (95 classes)", () => {
const logits = new Float32Array(95);
for (let i = 0; i < 95; i++) {
logits[i] = i * 0.1 - 4.75; // centered around 0
}
const result = softmaxFloat32(logits);
const sum = Array.from(result).reduce((a, b) => a + b, 0);
expect(sum).toBeCloseTo(1.0, 5);
expect(result.length).toBe(95);
});
});
describe("calibrateConfidence", () => {
it("calibrateConfidence(0.9) returns label 'high'", () => {
const result = calibrateConfidence(0.9);
expect(result.label).toBe("high");
expect(result.raw).toBe(0.9);
expect(result.adjusted).toBeGreaterThan(0.8);
});
it("calibrateConfidence(0.95) returns label 'high'", () => {
const result = calibrateConfidence(0.95);
expect(result.label).toBe("high");
});
it("calibrateConfidence(0.8) returns label 'high'", () => {
const result = calibrateConfidence(0.8);
expect(result.label).toBe("high");
});
it("calibrateConfidence(0.6) returns label 'medium'", () => {
const result = calibrateConfidence(0.6);
expect(result.label).toBe("medium");
expect(result.adjusted).toBeGreaterThanOrEqual(0.5);
expect(result.adjusted).toBeLessThan(0.8);
});
it("calibrateConfidence(0.55) returns label 'medium'", () => {
const result = calibrateConfidence(0.55);
expect(result.label).toBe("medium");
});
it("calibrateConfidence(0.3) returns label 'low'", () => {
const result = calibrateConfidence(0.3);
expect(result.label).toBe("low");
expect(result.adjusted).toBeLessThan(0.5);
});
it("calibrateConfidence(0.1) returns label 'low'", () => {
const result = calibrateConfidence(0.1);
expect(result.label).toBe("low");
});
it("calibrateConfidence(0.0) returns label 'low'", () => {
const result = calibrateConfidence(0.0);
expect(result.label).toBe("low");
});
it("calibrateConfidence(1.0) returns label 'high'", () => {
const result = calibrateConfidence(1.0);
expect(result.label).toBe("high");
expect(result.adjusted).toBeGreaterThan(0.9);
});
it("adjusted confidence is rounded to 4 decimal places", () => {
const result = calibrateConfidence(0.73);
const decimalPlaces = result.adjusted.toString().split(".")[1]?.length || 0;
expect(decimalPlaces).toBeLessThanOrEqual(4);
});
it("raw confidence is rounded to 4 decimal places", () => {
const result = calibrateConfidence(0.73456789);
expect(result.raw).toBe(0.7346);
});
it("adjusted confidence is monotonically increasing with raw", () => {
const low = calibrateConfidence(0.3);
const mid = calibrateConfidence(0.6);
const high = calibrateConfidence(0.9);
expect(high.adjusted).toBeGreaterThan(mid.adjusted);
expect(mid.adjusted).toBeGreaterThan(low.adjusted);
});
});
describe("getConfidenceLabel", () => {
it("returns 'high' for score >= 0.8", () => {
expect(getConfidenceLabel(0.8)).toBe("high");
expect(getConfidenceLabel(0.85)).toBe("high");
expect(getConfidenceLabel(1.0)).toBe("high");
});
it("returns 'medium' for score >= 0.5 and < 0.8", () => {
expect(getConfidenceLabel(0.5)).toBe("medium");
expect(getConfidenceLabel(0.65)).toBe("medium");
expect(getConfidenceLabel(0.79)).toBe("medium");
});
it("returns 'low' for score < 0.5", () => {
expect(getConfidenceLabel(0.0)).toBe("low");
expect(getConfidenceLabel(0.49)).toBe("low");
});
});
describe("getTopK", () => {
it("returns exactly 5 entries by default", () => {
const probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
const result = getTopK(probs);
expect(result).toHaveLength(5);
});
it("returns entries sorted by probability descending", () => {
const probs = [0.1, 0.5, 0.3, 0.9, 0.2, 0.7, 0.4];
const result = getTopK(probs);
for (let i = 0; i < result.length - 1; i++) {
expect(result[i].probability).toBeGreaterThanOrEqual(result[i + 1].probability);
}
});
it("returns correct class indices", () => {
const probs = [0.1, 0.5, 0.3, 0.9, 0.2];
const result = getTopK(probs, 3);
expect(result[0].classIndex).toBe(3); // 0.9
expect(result[1].classIndex).toBe(1); // 0.5
expect(result[2].classIndex).toBe(2); // 0.3
});
it("respects custom k value", () => {
const probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
const result = getTopK(probs, 3);
expect(result).toHaveLength(3);
});
it("returns all entries when k > array length", () => {
const probs = [0.1, 0.2, 0.3];
const result = getTopK(probs, 10);
expect(result).toHaveLength(3);
});
it("handles equal probabilities", () => {
const probs = [0.3, 0.3, 0.3, 0.1, 0.1];
const result = getTopK(probs, 3);
expect(result).toHaveLength(3);
expect(result.every(p => p.probability === 0.3)).toBe(true);
});
});
describe("getTopKFloat32", () => {
it("returns exactly 5 entries by default", () => {
const probs = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]);
const result = getTopKFloat32(probs);
expect(result).toHaveLength(5);
});
it("returns entries sorted by probability descending", () => {
const probs = new Float32Array([0.1, 0.5, 0.3, 0.9, 0.2, 0.7, 0.4]);
const result = getTopKFloat32(probs);
for (let i = 0; i < result.length - 1; i++) {
expect(result[i].probability).toBeGreaterThanOrEqual(result[i + 1].probability);
}
});
it("returns correct class indices", () => {
const probs = new Float32Array([0.1, 0.5, 0.3, 0.9, 0.2]);
const result = getTopKFloat32(probs, 3);
expect(result[0].classIndex).toBe(3);
expect(result[1].classIndex).toBe(1);
expect(result[2].classIndex).toBe(2);
});
it("handles large arrays (95 classes)", () => {
const probs = new Float32Array(95);
// Set a few high values
probs[0] = 0.4;
probs[5] = 0.3;
probs[10] = 0.2;
probs[20] = 0.05;
probs[30] = 0.03;
// Rest are small
for (let i = 0; i < 95; i++) {
if (probs[i] === 0) probs[i] = 0.001;
}
const result = getTopKFloat32(probs, 5);
expect(result).toHaveLength(5);
expect(result[0].classIndex).toBe(0);
expect(result[0].probability).toBeCloseTo(0.4, 5);
});
});
describe("filterByConfidence", () => {
it("removes predictions below default threshold", () => {
const predictions = [
{ classIndex: 0, probability: 0.5 },
{ classIndex: 1, probability: 0.3 },
{ classIndex: 2, probability: 0.1 },
{ classIndex: 3, probability: 0.05 },
];
const result = filterByConfidence(predictions);
expect(result).toHaveLength(2);
expect(result[0].classIndex).toBe(0);
expect(result[1].classIndex).toBe(1);
});
it("uses custom threshold", () => {
const predictions = [
{ classIndex: 0, probability: 0.5 },
{ classIndex: 1, probability: 0.3 },
{ classIndex: 2, probability: 0.1 },
];
const result = filterByConfidence(predictions, 0.25);
expect(result).toHaveLength(2);
});
it("returns empty array when all below threshold", () => {
const predictions = [
{ classIndex: 0, probability: 0.1 },
{ classIndex: 1, probability: 0.05 },
];
const result = filterByConfidence(predictions, 0.2);
expect(result).toEqual([]);
});
it("returns all predictions when all above threshold", () => {
const predictions = [
{ classIndex: 0, probability: 0.5 },
{ classIndex: 1, probability: 0.3 },
];
const result = filterByConfidence(predictions, 0.1);
expect(result).toHaveLength(2);
});
it("keeps predictions at exactly the threshold", () => {
const predictions = [
{ classIndex: 0, probability: 0.15 },
{ classIndex: 1, probability: 0.14 },
];
const result = filterByConfidence(predictions, 0.15);
expect(result).toHaveLength(1);
expect(result[0].classIndex).toBe(0);
});
});
describe("DEFAULT_MIN_CONFIDENCE", () => {
it("is 0.15", () => {
expect(DEFAULT_MIN_CONFIDENCE).toBe(0.15);
});
});

204
src/lib/ml/confidence.ts Normal file
View File

@@ -0,0 +1,204 @@
/**
* Confidence calibration and threshold logic for ML predictions.
*
* Provides softmax conversion, confidence calibration, and threshold-based
* filtering of predictions.
*/
import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";
// ─── Configuration ───────────────────────────────────────────────────────────
/** Minimum confidence threshold — predictions below this are filtered out */
export const DEFAULT_MIN_CONFIDENCE = 0.15;
/** Confidence label thresholds */
const CONFIDENCE_THRESHOLDS = {
HIGH: 0.8,
MEDIUM: 0.5,
} as const;
// ─── Softmax ─────────────────────────────────────────────────────────────────
/**
* Apply softmax to a vector of logits, converting them to probabilities.
*
* Uses numerically stable softmax: subtracts max before exp() to avoid overflow.
*
* @param logits - Array of raw model output values
* @returns Array of probabilities that sum to ~1.0
*/
export function softmax(logits: number[]): number[] {
const maxLogit = Math.max(...logits);
const expValues = logits.map((l) => Math.exp(l - maxLogit));
const sumExp = expValues.reduce((a, b) => a + b, 0);
if (sumExp === 0) {
// Degenerate case: all logits are -Infinity
const uniform = 1 / logits.length;
return logits.map(() => uniform);
}
return expValues.map((e) => e / sumExp);
}
/**
* Apply softmax to a Float32Array of logits.
*
* @param logits - Float32Array of raw model output values
* @returns Float32Array of probabilities that sum to ~1.0
*/
export function softmaxFloat32(logits: Float32Array): Float32Array {
const maxLogit = -Infinity;
let actualMax = maxLogit;
for (let i = 0; i < logits.length; i++) {
if (logits[i] > actualMax) actualMax = logits[i];
}
const expValues = new Float32Array(logits.length);
let sumExp = 0;
for (let i = 0; i < logits.length; i++) {
expValues[i] = Math.exp(logits[i] - actualMax);
sumExp += expValues[i];
}
if (sumExp === 0) {
const uniform = 1 / logits.length;
return new Float32Array(logits.length).fill(uniform);
}
for (let i = 0; i < expValues.length; i++) {
expValues[i] /= sumExp;
}
return expValues;
}
// ─── Confidence Calibration ──────────────────────────────────────────────────
/**
* Calibrate a raw probability into an adjusted confidence score with a label.
*
* Applies a mild calibration that slightly adjusts raw softmax probabilities
* to account for model overconfidence. Uses a linear calibration:
* adjusted = rawProb * calibrationFactor
* where calibrationFactor ≈ 1.0 (default 1.02) to slightly boost
* well-separated predictions while keeping the value in [0, 1].
*
* The calibrated value is clamped to [0, 1] and labeled using thresholds:
* high ≥ 0.8
* medium ≥ 0.5
* low < 0.5
*
* @param rawProb - Raw softmax probability (01)
* @param calibrationFactor - Linear calibration factor (default 1.02)
* @returns { adjusted, label }
*/
export function calibrateConfidence(
rawProb: number,
calibrationFactor = 1.02,
): ConfidenceResult {
const adjusted = Math.min(1, Math.max(0, rawProb * calibrationFactor));
const label = getConfidenceLabel(adjusted);
return {
raw: roundToDecimals(rawProb, 4),
adjusted: roundToDecimals(adjusted, 4),
label,
};
}
/**
* Get the confidence label for a given score.
*
* Thresholds:
* high ≥ 0.8
* medium ≥ 0.5
* low < 0.5
*
* @param score - Confidence score (01)
* @returns Confidence label
*/
export function getConfidenceLabel(score: number): ConfidenceLabel {
if (score >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
if (score >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
return "low";
}
/**
* Apply sigmoid function: 1 / (1 + exp(-x))
*/
function sigmoid(x: number): number {
return 1 / (1 + Math.exp(-x));
}
/**
* Round a number to a given number of decimal places.
*/
function roundToDecimals(value: number, decimals: number): number {
const factor = Math.pow(10, decimals);
return Math.round(value * factor) / factor;
}
// ─── Top-K Extraction ────────────────────────────────────────────────────────
/**
* Extract the top-K predictions from a probability array.
*
* @param probabilities - Array of probabilities (from softmax)
* @param k - Number of top predictions to return (default 5)
* @returns Array of { classIndex, probability } sorted by probability descending
*/
export function getTopK(
probabilities: number[],
k = 5,
): RawPrediction[] {
// Create indexed pairs
const indexed = probabilities.map((prob, index) => ({
classIndex: index,
probability: prob,
}));
// Sort by probability descending
indexed.sort((a, b) => b.probability - a.probability);
// Take top K
return indexed.slice(0, k);
}
/**
* Extract top-K predictions from a Float32Array of probabilities.
*
* @param probabilities - Float32Array of probabilities
* @param k - Number of top predictions (default 5)
* @returns Array of { classIndex, probability } sorted descending
*/
export function getTopKFloat32(
probabilities: Float32Array,
k = 5,
): RawPrediction[] {
const indexed: Array<{ classIndex: number; probability: number }> = [];
for (let i = 0; i < probabilities.length; i++) {
indexed.push({ classIndex: i, probability: probabilities[i] });
}
indexed.sort((a, b) => b.probability - a.probability);
return indexed.slice(0, k);
}
// ─── Filtering ───────────────────────────────────────────────────────────────
/**
* Filter predictions by minimum confidence threshold.
*
* @param predictions - Raw predictions from getTopK()
* @param minConfidence - Minimum probability threshold (default 0.15)
* @returns Filtered predictions array
*/
export function filterByConfidence(
predictions: RawPrediction[],
minConfidence = DEFAULT_MIN_CONFIDENCE,
): RawPrediction[] {
return predictions.filter((p) => p.probability >= minConfidence);
}

View File

@@ -0,0 +1,244 @@
/**
* Unit tests for lib/ml/inference.ts
*
* Tests:
* - validateInput rejects non-Float32Array
* - validateInput rejects wrong-length arrays
* - validateInput rejects NaN/Infinity values
* - validateInput accepts correct tensor
* - createZeroTensor produces correct shape
* - createRandomTensor produces correct shape with finite values
* - runInference returns InferenceResult with predictions array
* - runInference returns exactly top-K predictions
* - runInference predictions are sorted descending
* - runInference includes inferenceTimeMs
* - runInference completes under 3 seconds
* - runBatchInference processes multiple images
*/
import { describe, it, expect, beforeEach } from "vitest";
import {
runInference,
validateInput,
createZeroTensor,
createRandomTensor,
runBatchInference,
INPUT_SIZE,
INPUT_SHAPE,
DEFAULT_TOP_K,
} from "@/lib/ml/inference";
import { resetModelCache } from "@/lib/ml/model-loader";
describe("validateInput", () => {
it("rejects non-Float32Array", () => {
expect(() => validateInput([1, 2, 3] as any)).toThrow("Expected Float32Array input");
});
it("rejects wrong-length arrays", () => {
const tensor = new Float32Array(100);
expect(() => validateInput(tensor)).toThrow(`Expected tensor of length ${INPUT_SIZE}`);
});
it("rejects NaN values", () => {
const tensor = new Float32Array(INPUT_SIZE);
tensor[50] = NaN;
expect(() => validateInput(tensor)).toThrow("non-finite value");
});
it("rejects Infinity values", () => {
const tensor = new Float32Array(INPUT_SIZE);
tensor[50] = Infinity;
expect(() => validateInput(tensor)).toThrow("non-finite value");
});
it("rejects -Infinity values", () => {
const tensor = new Float32Array(INPUT_SIZE);
tensor[50] = -Infinity;
expect(() => validateInput(tensor)).toThrow("non-finite value");
});
it("accepts correct tensor", () => {
const tensor = createZeroTensor();
expect(() => validateInput(tensor)).not.toThrow();
});
it("accepts tensor with negative values", () => {
const tensor = new Float32Array(INPUT_SIZE);
for (let i = 0; i < INPUT_SIZE; i++) {
tensor[i] = -2;
}
expect(() => validateInput(tensor)).not.toThrow();
});
it("accepts tensor with values near zero", () => {
const tensor = new Float32Array(INPUT_SIZE);
for (let i = 0; i < INPUT_SIZE; i++) {
tensor[i] = 0.0001;
}
expect(() => validateInput(tensor)).not.toThrow();
});
});
describe("createZeroTensor", () => {
it("produces Float32Array", () => {
const tensor = createZeroTensor();
expect(tensor).toBeInstanceOf(Float32Array);
});
it("has correct length", () => {
const tensor = createZeroTensor();
expect(tensor.length).toBe(INPUT_SIZE);
});
it("has correct shape dimensions", () => {
const expectedLength = INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3];
expect(INPUT_SIZE).toBe(expectedLength);
});
it("all values are zero", () => {
const tensor = createZeroTensor();
expect(tensor.every((v) => v === 0)).toBe(true);
});
});
describe("createRandomTensor", () => {
it("produces Float32Array", () => {
const tensor = createRandomTensor();
expect(tensor).toBeInstanceOf(Float32Array);
});
it("has correct length", () => {
const tensor = createRandomTensor();
expect(tensor.length).toBe(INPUT_SIZE);
});
it("all values are finite", () => {
const tensor = createRandomTensor();
expect(tensor.every((v) => Number.isFinite(v))).toBe(true);
});
it("produces varied values", () => {
const tensor = createRandomTensor();
const uniqueValues = new Set(tensor.map((v) => v.toFixed(4)));
expect(uniqueValues.size).toBeGreaterThan(100);
});
it("passes validateInput", () => {
const tensor = createRandomTensor();
expect(() => validateInput(tensor)).not.toThrow();
});
});
describe("INPUT_SHAPE and INPUT_SIZE", () => {
it("INPUT_SHAPE is [1, 3, 160, 160]", () => {
expect(INPUT_SHAPE).toEqual([1, 3, 160, 160]);
});
it("INPUT_SIZE equals 3 * 160 * 160", () => {
expect(INPUT_SIZE).toBe(3 * 160 * 160);
});
it("DEFAULT_TOP_K is 5", () => {
expect(DEFAULT_TOP_K).toBe(5);
});
});
describe("runInference", () => {
beforeEach(() => {
resetModelCache();
});
it("returns InferenceResult with predictions array", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor);
expect(result.predictions).toBeDefined();
expect(Array.isArray(result.predictions)).toBe(true);
}, 10000);
it("returns exactly top-K predictions by default", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor);
expect(result.predictions.length).toBe(DEFAULT_TOP_K);
}, 10000);
it("returns custom top-K predictions", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor, 3);
expect(result.predictions.length).toBe(3);
}, 10000);
it("predictions are sorted by probability descending", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor);
for (let i = 0; i < result.predictions.length - 1; i++) {
expect(result.predictions[i].probability).toBeGreaterThanOrEqual(
result.predictions[i + 1].probability,
);
}
}, 10000);
it("includes inferenceTimeMs", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor);
expect(result.inferenceTimeMs).toBeDefined();
expect(typeof result.inferenceTimeMs).toBe("number");
expect(result.inferenceTimeMs).toBeGreaterThan(0);
}, 10000);
it("completes under 3 seconds", async () => {
const tensor = createRandomTensor();
const start = performance.now();
const result = await runInference(tensor);
const elapsed = performance.now() - start;
expect(elapsed).toBeLessThan(3000);
expect(result.inferenceTimeMs).toBeLessThan(3000);
}, 10000);
it("each prediction has classIndex and probability", async () => {
const tensor = createRandomTensor();
const result = await runInference(tensor);
for (const pred of result.predictions) {
expect(pred.classIndex).toBeDefined();
expect(typeof pred.classIndex).toBe("number");
expect(pred.probability).toBeDefined();
expect(typeof pred.probability).toBe("number");
expect(pred.probability).toBeGreaterThanOrEqual(0);
expect(pred.probability).toBeLessThanOrEqual(1);
}
}, 10000);
it("throws on invalid input", async () => {
const badTensor = new Float32Array(100);
await expect(runInference(badTensor)).rejects.toThrow();
});
});
describe("runBatchInference", () => {
beforeEach(() => {
resetModelCache();
});
it("processes multiple images", async () => {
const tensors = [createRandomTensor(), createRandomTensor(), createRandomTensor()];
const results = await runBatchInference(tensors);
expect(results).toHaveLength(3);
for (const result of results) {
expect(result.predictions.length).toBe(DEFAULT_TOP_K);
expect(result.inferenceTimeMs).toBeGreaterThan(0);
}
}, 30000);
it("each result is independent", async () => {
const tensors = [createRandomTensor(), createRandomTensor()];
const results = await runBatchInference(tensors);
// Results should differ (different random inputs → different predictions)
expect(results[0].predictions[0].classIndex).toBeDefined();
expect(results[1].predictions[0].classIndex).toBeDefined();
}, 15000);
it("accepts custom top-K", async () => {
const tensors = [createRandomTensor()];
const results = await runBatchInference(tensors, 3);
expect(results[0].predictions.length).toBe(3);
}, 15000);
});

137
src/lib/ml/inference.ts Normal file
View File

@@ -0,0 +1,137 @@
/**
* ML inference pipeline for plant disease classification.
*
* Accepts a preprocessed image tensor, runs it through the model,
* applies softmax, extracts top-K predictions, and returns results
* with timing metadata.
*/
import type { InferenceResult, RawPrediction } from "@/lib/types";
import { getModel } from "./model-loader";
import { softmaxFloat32, getTopKFloat32 } from "./confidence";
// ─── Configuration ───────────────────────────────────────────────────────────
/** Number of top predictions to return */
export const DEFAULT_TOP_K = 5;
/** Input tensor shape: [batch=1, channels=3, height=160, width=160] */
export const INPUT_SHAPE: [number, number, number, number] = [1, 3, 160, 160];
/** Expected input tensor length */
export const INPUT_SIZE = INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3]; // 3 * 160 * 160 = 76800
// ─── Main Inference ──────────────────────────────────────────────────────────
/**
* Run the full inference pipeline on a preprocessed image tensor.
*
* @param imageTensor - Normalized Float32Array of shape [1, 3, 160, 160] (NCHW)
* @param topK - Number of top predictions to return (default 5)
* @returns InferenceResult with top-K predictions and timing
*/
export async function runInference(
imageTensor: Float32Array,
topK = DEFAULT_TOP_K,
): Promise<InferenceResult> {
const startTime = performance.now();
// Validate input
validateInput(imageTensor);
// Get model (lazy loads on first call)
const model = await getModel();
// Run model forward pass
const { logits, inferenceTimeMs } = await model.predict(imageTensor);
// Apply softmax to convert logits to probabilities
const probabilities = softmaxFloat32(logits);
// Extract top-K predictions
const predictions = getTopKFloat32(probabilities, topK);
const totalTime = performance.now() - startTime;
return {
predictions,
inferenceTimeMs: Math.round(totalTime),
};
}
// ─── Input Validation ────────────────────────────────────────────────────────
/**
* Validate that the input tensor has the expected shape and type.
*
* @param tensor - Input tensor to validate
* @throws Error if tensor is invalid
*/
export function validateInput(tensor: Float32Array): void {
if (!(tensor instanceof Float32Array)) {
throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
}
if (tensor.length !== INPUT_SIZE) {
throw new Error(
`Expected tensor of length ${INPUT_SIZE} (shape ${INPUT_SHAPE.join("×")}), ` +
`got ${tensor.length}`,
);
}
// Check for NaN/Infinity values
for (let i = 0; i < tensor.length; i++) {
if (!Number.isFinite(tensor[i])) {
throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
}
}
}
// ─── Batch Inference ─────────────────────────────────────────────────────────
/**
* Run inference on multiple images.
*
* Currently runs sequentially. For true batching, the model itself would need
* to support batch input.
*
* @param tensors - Array of preprocessed image tensors
* @param topK - Number of top predictions per image
* @returns Array of inference results
*/
export async function runBatchInference(
tensors: Float32Array[],
topK = DEFAULT_TOP_K,
): Promise<InferenceResult[]> {
const results: InferenceResult[] = [];
for (const tensor of tensors) {
results.push(await runInference(tensor, topK));
}
return results;
}
// ─── Utility ─────────────────────────────────────────────────────────────────
/**
* Create a zero-filled input tensor for testing.
*
* @returns Float32Array of shape [1, 3, 224, 224]
*/
export function createZeroTensor(): Float32Array {
return new Float32Array(INPUT_SIZE);
}
/**
* Create a random input tensor for testing.
*
* @returns Float32Array of shape [1, 3, 224, 224] with random values
*/
export function createRandomTensor(): Float32Array {
const tensor = new Float32Array(INPUT_SIZE);
for (let i = 0; i < tensor.length; i++) {
tensor[i] = (Math.random() * 2 - 1) * 2; // Range roughly -2 to 2
}
return tensor;
}

199
src/lib/ml/labels.test.ts Normal file
View File

@@ -0,0 +1,199 @@
/**
* Unit tests for lib/ml/labels.ts
*
* The model has 38 PlantVillage classes. Some map to the app's
* knowledge base disease IDs, others map to "unknown".
*
* Known mappings:
* - indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37 → "healthy"
* - index 20 (Potato___Early_blight) → "early-blight"
* - index 21 (Potato___Late_blight) → "late-blight"
* - index 25 (Squash___Powdery_mildew) → "squash-powdery-mildew"
* - index 26 (Strawberry___Leaf_scorch) → "strawberry-leaf-scorch"
* - index 28 (Tomato___Bacterial_spot) → "bacterial-leaf-spot-tomato"
* - index 29 (Tomato___Early_blight) → "early-blight" (duplicate)
* - index 30 (Tomato___Late_blight) → "late-blight" (duplicate)
* - index 32 (Tomato___Septoria_leaf_spot) → "septoria-leaf-spot"
* - index 37 (Tomato___healthy) → "healthy"
* - all others → "unknown"
*/
import { describe, it, expect } from "vitest";
import {
INDEX_TO_DISEASE_ID,
DISEASE_ID_TO_INDEX,
getDiseaseIdForIndex,
getIndexForDiseaseId,
isRealDisease,
getAllDiseaseIds,
NUM_CLASSES,
getPlantVillageClassName,
} from "@/lib/ml/labels";
describe("Constants", () => {
it("NUM_CLASSES is 38 (PlantVillage)", () => {
expect(NUM_CLASSES).toBe(38);
});
it("all 38 indices are mapped", () => {
const keys = Object.keys(INDEX_TO_DISEASE_ID).map(Number);
expect(keys.length).toBe(38);
for (let i = 0; i < 38; i++) {
expect(keys).toContain(i);
}
});
});
describe("INDEX_TO_DISEASE_ID — healthy indices", () => {
const healthyIndices = [3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37];
for (const idx of healthyIndices) {
it(`index ${idx} maps to "healthy"`, () => {
expect(INDEX_TO_DISEASE_ID[idx]).toBe("healthy");
});
}
});
describe("INDEX_TO_DISEASE_ID — known disease mappings", () => {
const cases: Array<{ index: number; expected: string; name: string }> = [
{ index: 20, expected: "early-blight", name: "Potato___Early_blight" },
{ index: 21, expected: "late-blight", name: "Potato___Late_blight" },
{ index: 25, expected: "squash-powdery-mildew", name: "Squash___Powdery_mildew" },
{ index: 26, expected: "strawberry-leaf-scorch", name: "Strawberry___Leaf_scorch" },
{ index: 28, expected: "bacterial-leaf-spot-tomato", name: "Tomato___Bacterial_spot" },
{ index: 29, expected: "early-blight", name: "Tomato___Early_blight" },
{ index: 30, expected: "late-blight", name: "Tomato___Late_blight" },
{ index: 32, expected: "septoria-leaf-spot", name: "Tomato___Septoria_leaf_spot" },
];
for (const { index, expected, name } of cases) {
it(`index ${index} (${name}) maps to "${expected}"`, () => {
expect(INDEX_TO_DISEASE_ID[index]).toBe(expected);
});
}
});
describe("INDEX_TO_DISEASE_ID — unknown (unmapped) indices", () => {
const unknownIndices = [0, 1, 2, 5, 7, 8, 9, 11, 12, 13, 15, 16, 18, 31, 33, 34, 35, 36];
for (const idx of unknownIndices) {
it(`index ${idx} maps to "unknown"`, () => {
expect(INDEX_TO_DISEASE_ID[idx]).toBe("unknown");
});
}
});
describe("DISEASE_ID_TO_INDEX", () => {
it("maps 'early-blight' to first occurrence (index 20)", () => {
expect(DISEASE_ID_TO_INDEX["early-blight"]).toBe(20);
});
it("maps 'late-blight' to first occurrence (index 21)", () => {
expect(DISEASE_ID_TO_INDEX["late-blight"]).toBe(21);
});
it("maps 'septoria-leaf-spot' to index 32", () => {
expect(DISEASE_ID_TO_INDEX["septoria-leaf-spot"]).toBe(32);
});
it("maps 'healthy' to index 3 (first healthy index)", () => {
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(3);
});
});
describe("Bidirectional mapping", () => {
it("every index round-trips correctly", () => {
for (let i = 0; i < NUM_CLASSES; i++) {
const id = INDEX_TO_DISEASE_ID[i];
const idx = DISEASE_ID_TO_INDEX[id];
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
}
});
});
describe("getDiseaseIdForIndex", () => {
it("returns 'unknown' for out-of-range positive index", () => {
expect(getDiseaseIdForIndex(100)).toBe("unknown");
});
it("returns 'unknown' for negative index", () => {
expect(getDiseaseIdForIndex(-1)).toBe("unknown");
});
it("returns correct ID for valid index", () => {
expect(getDiseaseIdForIndex(20)).toBe("early-blight");
});
});
describe("getIndexForDiseaseId", () => {
it("returns -1 for unknown disease ID", () => {
expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1);
});
it("returns -1 for empty string", () => {
expect(getIndexForDiseaseId("")).toBe(-1);
});
it("is case-insensitive", () => {
expect(getIndexForDiseaseId("EARLY-BLIGHT")).toBe(20);
});
});
describe("isRealDisease", () => {
it("returns false for 'healthy'", () => {
expect(isRealDisease("healthy")).toBe(false);
});
it("returns false for 'unknown'", () => {
expect(isRealDisease("unknown")).toBe(false);
});
it("returns true for known disease IDs", () => {
expect(isRealDisease("early-blight")).toBe(true);
expect(isRealDisease("septoria-leaf-spot")).toBe(true);
});
it("returns true for arbitrary non-special strings", () => {
expect(isRealDisease("some-disease")).toBe(true);
});
});
describe("getPlantVillageClassName", () => {
it("returns correct class name for tomato healthy", () => {
expect(getPlantVillageClassName(37)).toBe("Tomato___healthy");
});
it("returns correct class name for potato early blight", () => {
expect(getPlantVillageClassName(20)).toBe("Potato___Early_blight");
});
it("returns 'unknown' for out-of-range index", () => {
expect(getPlantVillageClassName(100)).toBe("unknown");
});
});
describe("getAllDiseaseIds", () => {
it("returns only mapped disease IDs", () => {
const ids = getAllDiseaseIds();
expect(ids).toContain("early-blight");
expect(ids).toContain("late-blight");
expect(ids).toContain("squash-powdery-mildew");
expect(ids).toContain("strawberry-leaf-scorch");
expect(ids).toContain("bacterial-leaf-spot-tomato");
expect(ids).toContain("septoria-leaf-spot");
});
it("excludes 'healthy'", () => {
expect(getAllDiseaseIds()).not.toContain("healthy");
});
it("excludes 'unknown'", () => {
expect(getAllDiseaseIds()).not.toContain("unknown");
});
it("has no duplicates", () => {
const ids = getAllDiseaseIds();
const uniqueIds = new Set(ids);
expect(uniqueIds.size).toBe(ids.length);
});
});

237
src/lib/ml/labels.ts Normal file
View File

@@ -0,0 +1,237 @@
/**
* Class label mapping for the plant disease classifier model.
*
* This model is a MobileNetV2 trained on the PlantVillage dataset
* with 38 classes (14 crops × diseases/healthy).
*
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 38
*
* Index layout (from labels_pv_original.json):
* 0 → Apple___Apple_scab
* 1 → Apple___Black_rot
* 2 → Apple___Cedar_apple_rust
* 3 → Apple___healthy
* 4 → Blueberry___healthy
* 5 → Cherry_(including_sour)___Powdery_mildew
* 6 → Cherry_(including_sour)___healthy
* 7 → Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
* 8 → Corn_(maize)___Common_rust_
* 9 → Corn_(maize)___Northern_Leaf_Blight
* 10 → Corn_(maize)___healthy
* 11 → Grape___Black_rot
* 12 → Grape___Esca_(Black_Measles)
* 13 → Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
* 14 → Grape___healthy
* 15 → Orange___Haunglongbing_(Citrus_greening)
* 16 → Peach___Bacterial_spot
* 17 → Peach___healthy
* 18 → Pepper,_bell___Bacterial_spot
* 19 → Pepper,_bell___healthy
* 20 → Potato___Early_blight
* 21 → Potato___Late_blight
* 22 → Potato___healthy
* 23 → Raspberry___healthy
* 24 → Soybean___healthy
* 25 → Squash___Powdery_mildew
* 26 → Strawberry___Leaf_scorch
* 27 → Strawberry___healthy
* 28 → Tomato___Bacterial_spot
* 29 → Tomato___Early_blight
* 30 → Tomato___Late_blight
* 31 → Tomato___Leaf_Mold
* 32 → Tomato___Septoria_leaf_spot
* 33 → Tomato___Spider_mites Two-spotted_spider_mite
* 34 → Tomato___Target_Spot
* 35 → Tomato___Tomato_Yellow_Leaf_Curl_Virus
* 36 → Tomato___Tomato_mosaic_virus
* 37 → Tomato___healthy
*
* Some PlantVillage classes overlap with this app's knowledge base.
* Exact class name → disease ID mappings:
* Potato___Early_blight → "early-blight"
* Potato___Late_blight → "late-blight"
* Squash___Powdery_mildew → "squash-powdery-mildew"
* Strawberry___Leaf_scorch → "strawberry-leaf-scorch"
* Tomato___Bacterial_spot → "bacterial-leaf-spot-tomato"
* Tomato___Early_blight → "early-blight"
* Tomato___Late_blight → "late-blight"
* Tomato___Septoria_leaf_spot → "septoria-leaf-spot"
* All other classes map to "unknown" and are filtered out during enrichment.
*
* After fine-tuning to the app's 93 disease classes, this file will be
* rewritten to match the new model's output layer.
*/
// ─── PlantVillage class names (in model output order) ────────────────────
const PLANTVILLAGE_CLASSES: string[] = [
"Apple___Apple_scab",
"Apple___Black_rot",
"Apple___Cedar_apple_rust",
"Apple___healthy",
"Blueberry___healthy",
"Cherry_(including_sour)___Powdery_mildew",
"Cherry_(including_sour)___healthy",
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
"Corn_(maize)___Common_rust_",
"Corn_(maize)___Northern_Leaf_Blight",
"Corn_(maize)___healthy",
"Grape___Black_rot",
"Grape___Esca_(Black_Measles)",
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
"Grape___healthy",
"Orange___Haunglongbing_(Citrus_greening)",
"Peach___Bacterial_spot",
"Peach___healthy",
"Pepper,_bell___Bacterial_spot",
"Pepper,_bell___healthy",
"Potato___Early_blight",
"Potato___Late_blight",
"Potato___healthy",
"Raspberry___healthy",
"Soybean___healthy",
"Squash___Powdery_mildew",
"Strawberry___Leaf_scorch",
"Strawberry___healthy",
"Tomato___Bacterial_spot",
"Tomato___Early_blight",
"Tomato___Late_blight",
"Tomato___Leaf_Mold",
"Tomato___Septoria_leaf_spot",
"Tomato___Spider_mites Two-spotted_spider_mite",
"Tomato___Target_Spot",
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
"Tomato___Tomato_mosaic_virus",
"Tomato___healthy",
] as const;
// ─── PlantVillage → App disease ID mapping ──────────────────────────────
/**
* Maps PlantVillage class names (in the form "Plant___Disease") to
* this app's disease IDs. Unmapped classes resolve to "unknown".
*/
function plantVillageNameToDiseaseId(pvName: string): string {
const parts = pvName.split("___");
if (parts.length !== 2) {
return "unknown";
}
const disease = parts[1];
// Detect "healthy" variants
if (disease === "healthy") {
return "healthy";
}
// Map exact PlantVillage class names to our disease IDs.
// Only map classes where we're confident the correspondence holds.
const exactMap: Record<string, string> = {
Squash___Powdery_mildew: "squash-powdery-mildew",
Strawberry___Leaf_scorch: "strawberry-leaf-scorch",
Potato___Early_blight: "early-blight",
Potato___Late_blight: "late-blight",
Tomato___Bacterial_spot: "bacterial-leaf-spot-tomato",
Tomato___Early_blight: "early-blight",
Tomato___Late_blight: "late-blight",
Tomato___Septoria_leaf_spot: "septoria-leaf-spot",
};
return exactMap[pvName] ?? "unknown";
}
// ─── Constants ──────────────────────────────────────────────────────────
/** Total number of model output classes */
export const NUM_CLASSES = PLANTVILLAGE_CLASSES.length; // 38
/** Index for the "healthy" class — multiple PV indices map to this */
export const HEALTHY_INDEX = 0; // First PV healthy class, others also map to this string
/** First disease index (unused in PV mapping, kept for compatibility) */
export const FIRST_DISEASE_INDEX = 0;
/** Index for the "unknown" catch-all — PV classes we can't map */
export const UNKNOWN_INDEX = NUM_CLASSES - 1; // 37 (Tomato___healthy maps to "healthy", not unknown)
// ─── Index → Disease ID mapping ─────────────────────────────────────────
/**
* Map from model output index to app disease ID string.
* Built dynamically from PlantVillage class names.
*/
export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze(
(() => {
const map: Record<number, string> = {};
for (let i = 0; i < NUM_CLASSES; i++) {
map[i] = plantVillageNameToDiseaseId(PLANTVILLAGE_CLASSES[i]);
}
return map;
})(),
);
// ─── Disease ID → Index mapping ─────────────────────────────────────────
/**
* Map from disease ID string to model output index.
* For duplicates (e.g., both potato and tomato "Early_blight" → "early-blight"),
* returns the first matching index.
*/
export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze(
(() => {
const map: Record<string, number> = {};
for (let i = 0; i < NUM_CLASSES; i++) {
const id = INDEX_TO_DISEASE_ID[i];
// First occurrence wins (potato before tomato for early/late blight)
if (map[id] === undefined) {
map[id] = i;
}
}
return map;
})(),
);
// ─── Lookup helpers ─────────────────────────────────────────────────────
/**
* Get the disease ID for a given model output index.
* Returns "unknown" for out-of-range indices.
*/
export function getDiseaseIdForIndex(index: number): string {
return INDEX_TO_DISEASE_ID[index] ?? "unknown";
}
/**
* Get the model output index for a given disease ID.
* Returns -1 if not found.
*/
export function getIndexForDiseaseId(diseaseId: string): number {
return DISEASE_ID_TO_INDEX[diseaseId.toLowerCase()] ?? -1;
}
/**
* Check if a disease ID is a real disease (not "healthy" or "unknown").
*/
export function isRealDisease(diseaseId: string): boolean {
return diseaseId !== "healthy" && diseaseId !== "unknown";
}
/**
* Get the PlantVillage display name for a given model output index.
*/
export function getPlantVillageClassName(index: number): string {
return PLANTVILLAGE_CLASSES[index] ?? "unknown";
}
/**
* Get all known disease IDs (excluding "healthy" and "unknown").
*/
export function getAllDiseaseIds(): string[] {
const ids = new Set<string>();
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
if (id !== "healthy" && id !== "unknown") {
ids.add(id);
}
}
return Array.from(ids);
}

395
src/lib/ml/model-loader.ts Normal file
View File

@@ -0,0 +1,395 @@
/**
* Singleton model loader for the plant disease classifier.
*
* Lazy-loads the TF.js or ONNX model on first call and caches it in memory
* via globalThis for subsequent requests. Supports graceful fallback to
* mock mode when no model file is present.
*
* Model files expected at: public/models/plant-disease-classifier/model.json
*/
import fs from "fs/promises";
import fsSync from "fs";
import path from "path";
// ─── Types ───────────────────────────────────────────────────────────────────
/** Model runtime backend */
export type ModelBackend = "tfjs" | "onnx" | "mock";
/** Model loading status */
export interface ModelStatus {
/** Whether a real model is loaded */
loaded: boolean;
/** Backend being used */
backend: ModelBackend;
/** Model identifier string */
modelId: string;
/** Number of output classes */
numClasses: number;
/** Error message if loading failed */
error?: string;
}
/** Result from running the model on input data */
export interface ModelOutput {
/** Raw logits or probabilities from the model */
logits: Float32Array;
/** Inference time in milliseconds */
inferenceTimeMs: number;
}
/** Model interface abstracted over TF.js / ONNX / mock */
export interface PlantDiseaseModel {
/** Run inference on a preprocessed image tensor */
predict(tensor: Float32Array): Promise<ModelOutput>;
/** Get model metadata */
getStatus(): ModelStatus;
}
// ─── Constants ───────────────────────────────────────────────────────────────
/** Path to model files relative to project root */
const MODEL_DIR = path.join(process.cwd(), "public", "models", "plant-disease-classifier");
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");
/** Model identifier */
export const MODEL_ID = "plant-classifier-v1";
/** Maximum model load time (ms) */
const MODEL_LOAD_TIMEOUT = 30_000;
// ─── Global cache ────────────────────────────────────────────────────────────
declare global {
var __plantDiseaseModel__: PlantDiseaseModel | undefined;
var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
}
// ─── Model Loader ────────────────────────────────────────────────────────────
/**
* Get the cached model instance, loading it lazily on first call.
* Uses globalThis to persist across serverless invocations (within same container).
*
* @returns Promise resolving to the model (real or mock)
*/
export async function getModel(): Promise<PlantDiseaseModel> {
// Return cached model if available
if (globalThis.__plantDiseaseModel__) {
return globalThis.__plantDiseaseModel__;
}
// If already loading, wait for the existing promise
if (globalThis.__plantDiseaseModelLoading__) {
return globalThis.__plantDiseaseModelLoading__;
}
// Start loading
const loadingPromise = loadModel();
globalThis.__plantDiseaseModelLoading__ = loadingPromise;
try {
const model = await Promise.race([
loadingPromise,
new Promise<never>((_, reject) =>
setTimeout(
() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
MODEL_LOAD_TIMEOUT,
),
),
]);
globalThis.__plantDiseaseModel__ = model;
return model;
} finally {
globalThis.__plantDiseaseModelLoading__ = undefined;
}
}
/**
* Load the model, attempting TF.js first, then ONNX, then falling back to mock.
*/
async function loadModel(): Promise<PlantDiseaseModel> {
// Check if model files exist
const modelExists = await checkModelFiles();
if (!modelExists) {
console.warn(
`[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
`Place TF.js model (model.json + weight shards) in public/models/plant-disease-classifier/`,
);
return createMockModel();
}
// Try TF.js first
try {
const tfModel = await tryLoadTFJS();
if (tfModel) {
console.info(`[model-loader] Loaded TF.js model: ${MODEL_ID}`);
return tfModel;
}
} catch (err) {
console.warn(
`[model-loader] TF.js load failed (${err instanceof Error ? err.message : "unknown"}). Trying ONNX...`,
);
}
// Try ONNX Runtime
try {
const onnxModel = await tryLoadONNX();
if (onnxModel) {
console.info(`[model-loader] Loaded ONNX model: ${MODEL_ID}`);
return onnxModel;
}
} catch (err) {
console.warn(
`[model-loader] ONNX load failed (${err instanceof Error ? err.message : "unknown"}). Falling back to mock.`,
);
}
// Fall back to mock
console.warn(`[model-loader] All backends failed. Using mock model.`);
return createMockModel();
}
/**
* Check if model files exist on disk.
*/
async function checkModelFiles(): Promise<boolean> {
try {
await fs.access(MODEL_JSON_PATH);
return true;
} catch {
return false;
}
}
// ─── TensorFlow.js Backend ───────────────────────────────────────────────────
/**
* Try to load the model using TensorFlow.js.
* Attempts @tensorflow/tfjs-node first (server), falls back to @tensorflow/tfjs.
*/
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let tf: any;
// Monkey-patch: add util.isNullOrUndefined for Node.js 26 compatibility.
// @tensorflow/tfjs-node references this function which was removed in Node 15+.
// eslint-disable-next-line @typescript-eslint/no-require-imports
const nodeUtil = require("util");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
return x === null || x === undefined;
};
}
// Try tfjs-node first (server-side, uses native bindings).
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
// as required dependencies — they are truly optional.
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const tfjsNode = await import("@tensorflow/tfjs-node" + "");
tf = tfjsNode;
} catch {
// Fall back to browser tfjs
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
tf = await import("@tensorflow/tfjs" + "");
} catch {
return null; // Neither tfjs package available
}
}
// Load the model from file path
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
return {
async predict(tensor: Float32Array): Promise<ModelOutput> {
const startTime = performance.now();
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
// Reshape NCHW flat array [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3]
const inputTensor = tf
.tensor3d(Array.from(tensor), [3, 160, 160])
.transpose([1, 2, 0])
.expandDims(0);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const outputTensor = (await model.predict(inputTensor)) as any;
const logits = new Float32Array(await outputTensor.data());
inputTensor.dispose();
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
outputTensor.dispose();
return {
logits,
inferenceTimeMs: performance.now() - startTime,
};
},
getStatus(): ModelStatus {
return {
loaded: true,
backend: "tfjs",
modelId: MODEL_ID,
numClasses: 38, // Original PlantVillage model
};
},
};
}
// ─── ONNX Runtime Backend ────────────────────────────────────────────────────
/**
* Try to load the model using ONNX Runtime.
*/
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let ort: any;
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
ort = await import("onnxruntime-node" + "");
} catch {
return null;
}
// Look for .onnx file in model directory
const onnxPath = path.join(MODEL_DIR, "model.onnx");
const onnxExists = fsSync.existsSync(onnxPath);
if (!onnxExists) {
return null;
}
const session = await ort.InferenceSession.create(onnxPath);
return {
async predict(tensor: Float32Array): Promise<ModelOutput> {
const startTime = performance.now();
// ONNX expects NCHW format: [1, 3, 160, 160]
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
const feeds = { [session.inputNames[0]]: inputTensor };
const results = await session.run(feeds);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const outputValues = Object.values(results) as any[];
const logits = new Float32Array(outputValues[0].data);
inputTensor.dispose();
return {
logits,
inferenceTimeMs: performance.now() - startTime,
};
},
getStatus(): ModelStatus {
return {
loaded: true,
backend: "onnx",
modelId: MODEL_ID,
numClasses: 38,
};
},
};
}
// ─── Mock Model ──────────────────────────────────────────────────────────────
/**
* Create a deterministic mock model for development/demo mode.
*
* Generates reproducible predictions based on input tensor hash.
* This allows the UI to work without a real model file.
*/
function createMockModel(): PlantDiseaseModel {
return {
async predict(tensor: Float32Array): Promise<ModelOutput> {
// Simulate inference time (50-200ms)
const simulatedTime = 50 + Math.random() * 150;
await sleep(simulatedTime);
// Generate deterministic logits from input hash
const logits = generateMockLogits(tensor);
return {
logits,
inferenceTimeMs: simulatedTime,
};
},
getStatus(): ModelStatus {
return {
loaded: false,
backend: "mock",
modelId: MODEL_ID,
numClasses: 38,
error: "Model files not found. Running in demo mode with mock predictions.",
};
},
};
}
/**
* Generate deterministic mock logits from input tensor.
* Uses a simple hash of the first few tensor values to create
* reproducible but varied predictions.
*/
function generateMockLogits(tensor: Float32Array): Float32Array {
const numClasses = 38;
const logits = new Float32Array(numClasses);
// Simple hash of input for deterministic output
let hash = 0;
const sampleSize = Math.min(100, tensor.length);
for (let i = 0; i < sampleSize; i++) {
hash = ((hash << 5) - hash + Math.floor(tensor[i] * 1000)) | 0;
}
// Generate logits using hash as seed
// Class 0 (healthy) gets a moderate score
logits[0] = (Math.abs(hash % 10) / 10) * 2;
// Give some disease classes higher scores
// This creates a realistic-looking distribution
for (let i = 1; i < numClasses - 1; i++) {
const seed = ((hash * (i + 1) * 7) % 1000) / 1000;
logits[i] = seed * 4 - 1; // Range roughly -1 to 3
}
// Make the top prediction more confident
const topIndex = Math.abs(hash % (numClasses - 2)) + 1;
logits[topIndex] = 3.5;
// Second highest
const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
logits[secondIndex] = 2.5;
logits[numClasses - 1] = -2; // "unknown" gets low score
return logits;
}
/**
* Sleep for a given number of milliseconds.
*/
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
// ─── Reset (for testing) ─────────────────────────────────────────────────────
/**
* Reset the model cache. Useful for testing.
*/
export function resetModelCache(): void {
globalThis.__plantDiseaseModel__ = undefined;
globalThis.__plantDiseaseModelLoading__ = undefined;
}

View File

@@ -0,0 +1,55 @@
import { describe, it, expect, vi } from "vitest";
import { mimeTypeToExtension } from "./image-processing-server";
// Mock sharp dynamically
const mockSharp = vi.fn(() => ({
resize: vi.fn().mockReturnThis(),
jpeg: vi.fn().mockReturnThis(),
toBuffer: vi.fn().mockResolvedValue(Buffer.from("resized-image-data")),
}));
vi.doMock("sharp", () => ({
default: mockSharp,
}));
describe("mimeTypeToExtension", () => {
it("maps image/png to png", () => {
expect(mimeTypeToExtension("image/png")).toBe("png");
});
it("maps image/jpeg to jpg", () => {
expect(mimeTypeToExtension("image/jpeg")).toBe("jpg");
});
it("maps image/jpg to jpg", () => {
expect(mimeTypeToExtension("image/jpg")).toBe("jpg");
});
it("maps image/webp to webp", () => {
expect(mimeTypeToExtension("image/webp")).toBe("webp");
});
it("returns jpg for unknown mime types", () => {
expect(mimeTypeToExtension("image/bmp")).toBe("jpg");
expect(mimeTypeToExtension("unknown/type")).toBe("jpg");
});
});
describe("resizeImageServer", () => {
it("resizes image to specified dimensions", async () => {
const { resizeImageServer } = await import("./image-processing-server");
const buffer = Buffer.from("test-image-data");
const result = await resizeImageServer(buffer, 224, 224);
expect(result).toBeInstanceOf(Buffer);
expect(mockSharp).toHaveBeenCalled();
});
it("returns buffer for valid input", async () => {
const { resizeImageServer } = await import("./image-processing-server");
const buffer = Buffer.from("test-image-data");
const result = await resizeImageServer(buffer, 224, 224);
expect(result).toBeInstanceOf(Buffer);
});
});

View File

@@ -0,0 +1,51 @@
/**
* Server-only image processing helpers.
*
* These functions use Node.js native modules (sharp) and must NOT be
* imported by client components. They are used exclusively by API
* route handlers (server-side).
*/
// ─── Resize ──────────────────────────────────────────────────────────────────
/**
* Server-side image resize using sharp (if available) or a fallback.
* This is used by the upload API route.
*
* @param buffer - Raw image buffer
* @param size - Target dimension
* @returns Promise<Buffer> resized image as JPEG
*/
export async function resizeImageServer(
buffer: Buffer,
size: number,
): Promise<Buffer> {
try {
const sharpModule = await import("sharp");
return sharpModule.default(buffer)
.resize(size, size)
.jpeg({ quality: 95 })
.toBuffer();
} catch {
// Fallback: return original buffer if sharp is not available
// In production, sharp should be installed
throw new Error(
"sharp is required for server-side image resizing. Install with: npm install sharp",
);
}
}
// ─── MIME Type Helpers ───────────────────────────────────────────────────────
/**
* Generate a file extension from a MIME type.
*/
export function mimeTypeToExtension(mimeType: string): string {
const map: Record<string, string> = {
"image/png": "png",
"image/jpeg": "jpg",
"image/jpg": "jpg",
"image/webp": "webp",
};
return map[mimeType] || "jpg";
}

191
src/lib/types.ts Normal file
View File

@@ -0,0 +1,191 @@
/**
* Shared TypeScript interfaces for the Plant Disease Knowledge Base.
* Used by seed data, API helpers, and API routes.
*/
/** Types of causal agents that cause plant diseases */
export type CausalAgentType = "fungal" | "bacterial" | "viral" | "environmental";
/** Severity levels for plant diseases */
export type Severity = "low" | "moderate" | "high" | "critical";
/** How common/prevalent a disease is in the field */
export type Prevalence = "common" | "uncommon" | "rare" | "very_rare";
/** Plant category for grouping and filtering */
export type PlantCategory =
| "vegetable"
| "herb"
| "houseplant"
| "flower"
| "fruit"
| "succulent"
| "tree";
/**
* A plant entry in the knowledge base.
* Each plant has 0+ associated diseases (linked via plantId in Disease).
*/
export interface Plant {
/** Unique identifier (slug), e.g., "tomato", "monstera" */
id: string;
/** Common name, e.g., "Tomato" */
commonName: string;
/** Scientific (botanical) name, e.g., "Solanum lycopersicum" */
scientificName: string;
/** Plant family, e.g., "Solanaceae" */
family: string;
/** Category for filtering */
category: PlantCategory;
/** Brief care summary (light, water, humidity, temperature) */
careSummary: string;
/** URL to a representative image of the healthy plant */
imageUrl: string;
}
/**
* A disease entry in the knowledge base.
* Links to a plant via plantId and can reference lookalike diseases.
*/
export interface Disease {
/** Unique identifier (slug), e.g., "early-blight" */
id: string;
/** ID of the affected plant */
plantId: string;
/** Common disease name, e.g., "Early Blight" */
name: string;
/** Scientific name of the pathogen (if applicable) */
scientificName: string;
/** Type of causal agent */
causalAgentType: CausalAgentType;
/** Detailed description of the disease */
description: string;
/** Observable symptoms (≥3) */
symptoms: string[];
/** Root causes or contributing factors (≥2) */
causes: string[];
/** Step-by-step treatment instructions (≥3) */
treatment: string[];
/** Prevention tips (≥2) */
prevention: string[];
/** IDs of diseases that look similar and may be confused with this one */
lookalikeDiseaseIds: string[];
/** Overall severity of the disease */
severity: Severity;
/** How common/prevalent this disease is */
prevalence: Prevalence;
/** URL to a representative image showing disease symptoms */
imageUrl?: string;
}
/** Query parameters for listing/searching plants */
export interface PlantListParams {
search?: string;
category?: PlantCategory;
}
/** Query parameters for listing/searching diseases */
export interface DiseaseListParams {
plantId?: string;
search?: string;
causalAgentType?: CausalAgentType;
severity?: Severity;
}
/** Response wrapper for a single plant with its diseases */
export interface PlantWithDiseases {
plant: Plant;
diseases: Disease[];
}
/** Response wrapper for a single disease with its plant */
export interface DiseaseWithPlant {
disease: Disease;
plant: Plant;
}
/** Standard error response shape */
export interface ApiError {
error: string;
message: string;
status: number;
}
/** Paginated list response (future-proof, currently returns all) */
export interface PaginatedResponse<T> {
items: T[];
total: number;
}
// ─── ML / Inference types ────────────────────────────────────────────────────
/** Confidence label based on calibrated score */
export type ConfidenceLabel = "high" | "medium" | "low";
/** Raw prediction from model inference (before calibration) */
export interface RawPrediction {
/** Model output class index */
classIndex: number;
/** Raw softmax probability */
probability: number;
}
/** Calibrated confidence score with label */
export interface ConfidenceResult {
/** Raw softmax probability from model */
raw: number;
/** Calibrated/adjusted confidence score */
adjusted: number;
/** Human-readable confidence label */
label: ConfidenceLabel;
}
/** A single prediction in the identify API response */
export interface PredictionResult {
/** Disease ID matching knowledge base */
diseaseId: string;
/** Full disease object from knowledge base */
disease: Disease;
/** Calibrated confidence */
confidence: ConfidenceResult;
/** IDs of lookalike diseases that could be confused with this one */
lookalikes: string[];
/** Full disease objects for lookalikes, pre-resolved server-side */
lookalikeDiseases: Disease[];
/** The plant this disease affects (included for client convenience) */
plant: Plant | null;
}
/** Metadata about the inference run */
export interface InferenceMetadata {
/** Model identifier/version */
model: string;
/** Inference time in milliseconds */
inferenceTimeMs: number;
/** Image ID that was analyzed */
imageId: string;
}
/** Response from POST /api/identify */
export interface IdentifyResponse {
/** Ranked predictions */
predictions: PredictionResult[];
/** Inference metadata */
metadata: InferenceMetadata;
/** True when running in demo/mock mode (no real model loaded) */
demo_mode?: boolean;
}
/** Request body for POST /api/identify */
export interface IdentifyRequest {
/** Image ID from a previous /api/upload call */
imageId: string;
}
/** Result from runInference() */
export interface InferenceResult {
/** Top-K raw predictions sorted by probability descending */
predictions: RawPrediction[];
/** Inference time in milliseconds */
inferenceTimeMs: number;
}