re-init
This commit is contained in:
0
src/lib/.gitkeep
Normal file
0
src/lib/.gitkeep
Normal file
0
src/lib/api/.gitkeep
Normal file
0
src/lib/api/.gitkeep
Normal file
135
src/lib/api/browse.ts
Normal file
135
src/lib/api/browse.ts
Normal file
@@ -0,0 +1,135 @@
|
||||
/**
|
||||
* Browse API — fetches plants with disease counts from the Turso DB
|
||||
* for the browse page. Runs server-side only.
|
||||
*/
|
||||
|
||||
import { sql, eq, inArray, notInArray } from "drizzle-orm";
|
||||
import { getDb } from "@/lib/db/index";
|
||||
import { plants, diseases, plantViews } from "@/lib/db/schema";
|
||||
import type { PlantCardData } from "@/components/PlantCard";
|
||||
|
||||
export type { PlantCardData };
|
||||
|
||||
/**
|
||||
* Get all plants with their disease counts for the browse page.
|
||||
*
|
||||
* Uses scalar subqueries for COUNT to avoid expensive LEFT JOIN + GROUP BY
|
||||
* on the large diseases table (11,498 rows).
|
||||
*/
|
||||
export async function getBrowsePlants(): Promise<PlantCardData[]> {
|
||||
const db = getDb();
|
||||
|
||||
const rows = await db
|
||||
.select({
|
||||
id: plants.id,
|
||||
commonName: plants.commonName,
|
||||
scientificName: plants.scientificName,
|
||||
family: plants.family,
|
||||
category: plants.category,
|
||||
imageUrl: plants.imageUrl,
|
||||
updatedAt: plants.updatedAt,
|
||||
viewCount: sql<number>`COALESCE(${plantViews.viewCount}, 0)`,
|
||||
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
|
||||
})
|
||||
.from(plants)
|
||||
.leftJoin(plantViews, eq(plantViews.plantId, plants.id))
|
||||
.orderBy(plants.commonName);
|
||||
|
||||
return rows.map((r) => ({
|
||||
id: r.id,
|
||||
commonName: r.commonName,
|
||||
scientificName: r.scientificName,
|
||||
family: r.family,
|
||||
category: r.category,
|
||||
imageUrl: r.imageUrl,
|
||||
updatedAt: r.updatedAt,
|
||||
viewCount: r.viewCount,
|
||||
diseaseCount: r.diseaseCount,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a single plant with disease count (for detail page lookups).
|
||||
*/
|
||||
export async function getBrowsePlant(id: string): Promise<PlantCardData | null> {
|
||||
const db = getDb();
|
||||
const rows = await db
|
||||
.select({
|
||||
id: plants.id,
|
||||
commonName: plants.commonName,
|
||||
scientificName: plants.scientificName,
|
||||
family: plants.family,
|
||||
category: plants.category,
|
||||
imageUrl: plants.imageUrl,
|
||||
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
|
||||
})
|
||||
.from(plants)
|
||||
.where(eq(plants.id, id))
|
||||
.limit(1);
|
||||
|
||||
return rows[0] ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get featured plants for the homepage (subset).
|
||||
*/
|
||||
const FEATURED_IDS = [
|
||||
"tomato",
|
||||
"basil",
|
||||
"rose",
|
||||
"monstera",
|
||||
"snake-plant",
|
||||
"pepper",
|
||||
"apple",
|
||||
"corn",
|
||||
"wheat",
|
||||
"strawberry",
|
||||
"blueberry",
|
||||
"lettuce",
|
||||
];
|
||||
|
||||
export async function getFeaturedPlants(): Promise<PlantCardData[]> {
|
||||
const db = getDb();
|
||||
|
||||
const selectFeatured = db
|
||||
.select({
|
||||
id: plants.id,
|
||||
commonName: plants.commonName,
|
||||
scientificName: plants.scientificName,
|
||||
family: plants.family,
|
||||
category: plants.category,
|
||||
imageUrl: plants.imageUrl,
|
||||
updatedAt: plants.updatedAt,
|
||||
viewCount: sql<number>`COALESCE(${plantViews.viewCount}, 0)`,
|
||||
diseaseCount: sql<number>`(SELECT COUNT(*) FROM ${diseases} WHERE ${diseases.plantId} = ${plants.id})`,
|
||||
})
|
||||
.from(plants)
|
||||
.leftJoin(plantViews, eq(plantViews.plantId, plants.id));
|
||||
|
||||
const rows = await selectFeatured
|
||||
.where(inArray(plants.id, FEATURED_IDS))
|
||||
.orderBy(plants.commonName);
|
||||
|
||||
if (rows.length < 6) {
|
||||
const padRows = await selectFeatured
|
||||
.where(notInArray(plants.id, FEATURED_IDS))
|
||||
.orderBy(plants.commonName)
|
||||
.limit(12 - rows.length);
|
||||
return [...rows, ...padRows].map(mapRow);
|
||||
}
|
||||
return rows.slice(0, 12).map(mapRow);
|
||||
}
|
||||
|
||||
function mapRow(r: Record<string, unknown>): PlantCardData {
|
||||
return {
|
||||
id: r.id as string,
|
||||
commonName: r.commonName as string,
|
||||
scientificName: r.scientificName as string,
|
||||
family: r.family as string,
|
||||
category: r.category as string,
|
||||
imageUrl: r.imageUrl as string,
|
||||
updatedAt: r.updatedAt as string | undefined,
|
||||
viewCount: r.viewCount as number,
|
||||
diseaseCount: r.diseaseCount as number,
|
||||
};
|
||||
}
|
||||
382
src/lib/api/diseases-db.ts
Normal file
382
src/lib/api/diseases-db.ts
Normal file
@@ -0,0 +1,382 @@
|
||||
/**
|
||||
* Typed helpers to query the Plant Disease Knowledge Base from Turso DB.
|
||||
*
|
||||
* All functions are async and use Drizzle ORM against the Turso/libSQL database.
|
||||
*
|
||||
* For client components that need sync access, import from
|
||||
* @/lib/api/diseases-sync.ts (backed by JSON seed data)
|
||||
*/
|
||||
|
||||
import { eq, like, or, and, sql, type SQL } from "drizzle-orm";
|
||||
import { getDb } from "@/lib/db/index";
|
||||
import { plants, diseases } from "@/lib/db/schema";
|
||||
import type {
|
||||
CausalAgentType,
|
||||
Disease,
|
||||
DiseaseListParams,
|
||||
DiseaseWithPlant,
|
||||
Plant,
|
||||
PlantListParams,
|
||||
PlantWithDiseases,
|
||||
Prevalence,
|
||||
Severity,
|
||||
PlantCategory,
|
||||
} from "@/lib/types";
|
||||
|
||||
// ─── Row → Type mappers ──────────────────────────────────────────────────────
|
||||
|
||||
function toPlant(row: typeof plants.$inferSelect): Plant {
|
||||
return {
|
||||
id: row.id,
|
||||
commonName: row.commonName,
|
||||
scientificName: row.scientificName,
|
||||
family: row.family,
|
||||
category: row.category as PlantCategory,
|
||||
careSummary: row.careSummary,
|
||||
imageUrl: row.imageUrl,
|
||||
};
|
||||
}
|
||||
|
||||
function toDisease(row: typeof diseases.$inferSelect): Disease {
|
||||
return {
|
||||
id: row.id,
|
||||
plantId: row.plantId,
|
||||
name: row.name,
|
||||
scientificName: row.scientificName,
|
||||
causalAgentType: row.causalAgentType as CausalAgentType,
|
||||
description: row.description,
|
||||
symptoms: row.symptoms as string[],
|
||||
causes: row.causes as string[],
|
||||
treatment: row.treatment as string[],
|
||||
prevention: row.prevention as string[],
|
||||
lookalikeDiseaseIds: (row.lookalikeIds as string[]) ?? [],
|
||||
severity: row.severity as Severity,
|
||||
prevalence: (row.prevalence as Prevalence) ?? "uncommon",
|
||||
imageUrl: (row.imageUrl as string) || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get a plant by its ID.
|
||||
*/
|
||||
export async function getPlantById(id: string): Promise<Plant | undefined> {
|
||||
const db = getDb();
|
||||
const row = await db.select().from(plants).where(eq(plants.id, id.toLowerCase())).limit(1);
|
||||
return row[0] ? toPlant(row[0]) : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a disease by its ID.
|
||||
*/
|
||||
export async function getDiseaseById(id: string): Promise<Disease | undefined> {
|
||||
const db = getDb();
|
||||
const row = await db.select().from(diseases).where(eq(diseases.id, id.toLowerCase())).limit(1);
|
||||
return row[0] ? toDisease(row[0]) : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all diseases for a specific plant.
|
||||
*/
|
||||
export async function getDiseasesByPlantId(plantId: string): Promise<Disease[]> {
|
||||
const db = getDb();
|
||||
const rows = await db.select().from(diseases).where(eq(diseases.plantId, plantId.toLowerCase()));
|
||||
return rows.map(toDisease);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a plant with all its associated diseases.
|
||||
*/
|
||||
export async function getPlantWithDiseases(
|
||||
plantId: string,
|
||||
): Promise<PlantWithDiseases | undefined> {
|
||||
const plant = await getPlantById(plantId);
|
||||
if (!plant) return undefined;
|
||||
const diseaseRows = await getDiseasesByPlantId(plantId);
|
||||
return { plant, diseases: diseaseRows };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a disease with its associated plant.
|
||||
*/
|
||||
export async function getDiseaseWithPlant(
|
||||
diseaseId: string,
|
||||
): Promise<DiseaseWithPlant | undefined> {
|
||||
const disease = await getDiseaseById(diseaseId);
|
||||
if (!disease) return undefined;
|
||||
const plant = await getPlantById(disease.plantId);
|
||||
if (!plant) return undefined;
|
||||
return { disease, plant };
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve lookalike disease IDs to full disease objects.
|
||||
*/
|
||||
export async function getLookalikeDiseases(diseaseId: string): Promise<Disease[]> {
|
||||
const disease = await getDiseaseById(diseaseId);
|
||||
if (!disease || !disease.lookalikeDiseaseIds.length) return [];
|
||||
const db = getDb();
|
||||
const ids = disease.lookalikeDiseaseIds;
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(diseases)
|
||||
.where(sql`${diseases.id} IN ${ids}`);
|
||||
return rows.map(toDisease);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search plants by term (matches common name, scientific name, family, category).
|
||||
*/
|
||||
export async function searchPlants(term: string): Promise<Plant[]> {
|
||||
const lower = term.toLowerCase().trim();
|
||||
if (!lower) return listPlants();
|
||||
const db = getDb();
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(plants)
|
||||
.where(
|
||||
or(
|
||||
like(plants.commonName, `%${lower}%`),
|
||||
like(plants.scientificName, `%${lower}%`),
|
||||
like(plants.family, `%${lower}%`),
|
||||
like(plants.category, `%${lower}%`),
|
||||
),
|
||||
);
|
||||
return rows.map(toPlant);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search diseases by term (matches name, scientific name, description, symptoms via LIKE).
|
||||
*/
|
||||
export async function searchDiseases(term: string): Promise<Disease[]> {
|
||||
const lower = term.toLowerCase().trim();
|
||||
if (!lower) return listDiseases();
|
||||
const db = getDb();
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(diseases)
|
||||
.where(
|
||||
or(
|
||||
like(diseases.name, `%${lower}%`),
|
||||
like(diseases.scientificName, `%${lower}%`),
|
||||
like(diseases.description, `%${lower}%`),
|
||||
),
|
||||
);
|
||||
return rows.map(toDisease);
|
||||
}
|
||||
|
||||
/**
|
||||
* List plants with optional search and category filters.
|
||||
*/
|
||||
export async function listPlants(params: PlantListParams = {}): Promise<Plant[]> {
|
||||
const db = getDb();
|
||||
const plantConditions: SQL[] = [];
|
||||
|
||||
if (params.category) {
|
||||
plantConditions.push(eq(plants.category, params.category));
|
||||
}
|
||||
if (params.search) {
|
||||
const lower = params.search.toLowerCase().trim();
|
||||
const searchCond = or(
|
||||
like(plants.commonName, `%${lower}%`),
|
||||
like(plants.scientificName, `%${lower}%`),
|
||||
like(plants.family, `%${lower}%`),
|
||||
like(plants.category, `%${lower}%`),
|
||||
);
|
||||
if (searchCond) plantConditions.push(searchCond);
|
||||
}
|
||||
|
||||
const query =
|
||||
plantConditions.length > 0
|
||||
? db
|
||||
.select()
|
||||
.from(plants)
|
||||
.where(and(...plantConditions))
|
||||
: db.select().from(plants);
|
||||
|
||||
const rows = await query;
|
||||
return rows.map(toPlant);
|
||||
}
|
||||
|
||||
/**
|
||||
* List diseases with optional filters.
|
||||
*/
|
||||
export async function listDiseases(params: DiseaseListParams = {}): Promise<Disease[]> {
|
||||
const db = getDb();
|
||||
const diseaseConditions: SQL[] = [];
|
||||
|
||||
if (params.plantId) {
|
||||
diseaseConditions.push(eq(diseases.plantId, params.plantId.toLowerCase()));
|
||||
}
|
||||
if (params.causalAgentType) {
|
||||
diseaseConditions.push(eq(diseases.causalAgentType, params.causalAgentType));
|
||||
}
|
||||
if (params.severity) {
|
||||
diseaseConditions.push(eq(diseases.severity, params.severity));
|
||||
}
|
||||
if (params.search) {
|
||||
const lower = params.search.toLowerCase().trim();
|
||||
const searchCond = or(
|
||||
like(diseases.name, `%${lower}%`),
|
||||
like(diseases.scientificName, `%${lower}%`),
|
||||
like(diseases.description, `%${lower}%`),
|
||||
);
|
||||
if (searchCond) diseaseConditions.push(searchCond);
|
||||
}
|
||||
|
||||
const query =
|
||||
diseaseConditions.length > 0
|
||||
? db
|
||||
.select()
|
||||
.from(diseases)
|
||||
.where(and(...diseaseConditions))
|
||||
: db.select().from(diseases);
|
||||
|
||||
const rows = await query;
|
||||
return rows.map(toDisease);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all unique plant IDs that have diseases.
|
||||
*/
|
||||
export async function getPlantIdsWithDiseases(): Promise<string[]> {
|
||||
const db = getDb();
|
||||
const rows = await db
|
||||
.select({ plantId: diseases.plantId })
|
||||
.from(diseases)
|
||||
.groupBy(diseases.plantId);
|
||||
return rows.map((r) => r.plantId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all unique disease IDs referenced as lookalikes.
|
||||
*/
|
||||
export async function getReferencedLookalikeIds(): Promise<Set<string>> {
|
||||
const db = getDb();
|
||||
const rows = await db
|
||||
.select({ id: diseases.id, lookalikeIds: diseases.lookalikeIds })
|
||||
.from(diseases);
|
||||
const ids = new Set<string>();
|
||||
for (const row of rows) {
|
||||
const lookalikes = row.lookalikeIds as string[];
|
||||
for (const id of lookalikes) {
|
||||
ids.add(id);
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate knowledge base data integrity.
|
||||
* Returns array of validation errors (empty = valid).
|
||||
*/
|
||||
export async function validateKnowledgeBase(): Promise<string[]> {
|
||||
const errors: string[] = [];
|
||||
const validCausalAgentTypes: CausalAgentType[] = [
|
||||
"fungal",
|
||||
"bacterial",
|
||||
"viral",
|
||||
"environmental",
|
||||
];
|
||||
const validSeverities: Severity[] = ["low", "moderate", "high", "critical"];
|
||||
const validPrevalences: Prevalence[] = ["common", "uncommon", "rare", "very_rare"];
|
||||
|
||||
const db = getDb();
|
||||
|
||||
// Get all plants and diseases
|
||||
const allPlants = await db.select({ id: plants.id }).from(plants);
|
||||
const allDiseases = await db.select().from(diseases);
|
||||
|
||||
const plantIds = new Set(allPlants.map((p) => p.id));
|
||||
const diseaseIds = new Set<string>();
|
||||
const diseaseMap = new Map<string, (typeof allDiseases)[0]>();
|
||||
const diseaseErrors: Array<{
|
||||
id: string;
|
||||
plantId: string;
|
||||
name: string;
|
||||
lookalikeDiseaseIds: string[];
|
||||
}> = [];
|
||||
|
||||
for (const d of allDiseases) {
|
||||
if (diseaseIds.has(d.id)) {
|
||||
errors.push(`Duplicate disease ID: ${d.id}`);
|
||||
}
|
||||
diseaseIds.add(d.id);
|
||||
diseaseMap.set(d.id, d);
|
||||
diseaseErrors.push({
|
||||
id: d.id,
|
||||
plantId: d.plantId,
|
||||
name: d.name,
|
||||
lookalikeDiseaseIds: (d.lookalikeIds as string[]) ?? [],
|
||||
});
|
||||
}
|
||||
|
||||
// Check disease references
|
||||
for (const d of diseaseErrors) {
|
||||
// Valid plant reference
|
||||
if (!plantIds.has(d.plantId)) {
|
||||
errors.push(`Disease "${d.id}" references unknown plant ID: ${d.plantId}`);
|
||||
}
|
||||
|
||||
const full = diseaseMap.get(d.id)!;
|
||||
|
||||
// Valid causal agent type
|
||||
if (!validCausalAgentTypes.includes(full.causalAgentType as CausalAgentType)) {
|
||||
errors.push(`Disease "${d.id}" has invalid causalAgentType: ${full.causalAgentType}`);
|
||||
}
|
||||
|
||||
// Valid severity
|
||||
if (!validSeverities.includes(full.severity as Severity)) {
|
||||
errors.push(`Disease "${d.id}" has invalid severity: ${full.severity}`);
|
||||
}
|
||||
|
||||
// Valid prevalence
|
||||
if (full.prevalence && !validPrevalences.includes(full.prevalence as Prevalence)) {
|
||||
errors.push(`Disease "${d.id}" has invalid prevalence: ${full.prevalence}`);
|
||||
}
|
||||
|
||||
// Minimum counts
|
||||
const symptoms = full.symptoms as string[];
|
||||
const causes = full.causes as string[];
|
||||
const treatment = full.treatment as string[];
|
||||
const prevention = full.prevention as string[];
|
||||
|
||||
if (symptoms.length < 3) {
|
||||
errors.push(`Disease "${d.id}" has fewer than 3 symptoms (${symptoms.length})`);
|
||||
}
|
||||
if (causes.length < 2) {
|
||||
errors.push(`Disease "${d.id}" has fewer than 2 causes (${causes.length})`);
|
||||
}
|
||||
if (treatment.length < 3) {
|
||||
errors.push(`Disease "${d.id}" has fewer than 3 treatment steps (${treatment.length})`);
|
||||
}
|
||||
if (prevention.length < 2) {
|
||||
errors.push(`Disease "${d.id}" has fewer than 2 prevention tips (${prevention.length})`);
|
||||
}
|
||||
|
||||
// Valid lookalike references
|
||||
for (const lookalikeId of d.lookalikeDiseaseIds) {
|
||||
if (!diseaseIds.has(lookalikeId)) {
|
||||
errors.push(`Disease "${d.id}" references unknown lookalike: ${lookalikeId}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check lookalike bidirectionality
|
||||
for (const d of diseaseErrors) {
|
||||
for (const lookalikeId of d.lookalikeDiseaseIds) {
|
||||
const lookalike = diseaseMap.get(lookalikeId);
|
||||
if (lookalike) {
|
||||
const otherLookalikes = (lookalike.lookalikeIds as string[]) ?? [];
|
||||
if (!otherLookalikes.includes(d.id)) {
|
||||
errors.push(
|
||||
`Lookalike reference not bidirectional: "${d.id}" references "${lookalikeId}" but not vice versa`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
44
src/lib/api/home.ts
Normal file
44
src/lib/api/home.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Homepage data — fetches featured plants from the Turso DB.
|
||||
* Uses React's cache() to ensure one fetch per render pass.
|
||||
* Backed by the async fetch for SSR but stays sync in exported interface
|
||||
* via a module-level cache pattern.
|
||||
*/
|
||||
|
||||
import { unstable_cache } from "next/cache";
|
||||
|
||||
// Re-export the type for convenience
|
||||
export type { PlantCardData } from "@/components/PlantCard";
|
||||
|
||||
/**
|
||||
* Get featured plants for the homepage.
|
||||
* Cached via next/cache to avoid repeated DB calls.
|
||||
*/
|
||||
export const getFeaturedPlants = unstable_cache(
|
||||
async () => {
|
||||
const { getBrowsePlants } = await import("./browse");
|
||||
const all = await getBrowsePlants();
|
||||
const FEATURED_IDS = [
|
||||
"tomato",
|
||||
"basil",
|
||||
"rose",
|
||||
"monstera",
|
||||
"snake-plant",
|
||||
"pepper",
|
||||
"apple",
|
||||
"corn",
|
||||
"wheat",
|
||||
"strawberry",
|
||||
"blueberry",
|
||||
"lettuce",
|
||||
];
|
||||
const featured = all.filter((p) => FEATURED_IDS.includes(p.id));
|
||||
if (featured.length < 6) {
|
||||
const rest = all.filter((p) => !FEATURED_IDS.includes(p.id));
|
||||
return [...featured, ...rest].slice(0, 12);
|
||||
}
|
||||
return featured.slice(0, 12);
|
||||
},
|
||||
["featured-plants"],
|
||||
{ revalidate: 3600 },
|
||||
);
|
||||
107
src/lib/api/identify-client.test.ts
Normal file
107
src/lib/api/identify-client.test.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { identifyPlant } from "./identify";
|
||||
|
||||
// Mock global fetch
|
||||
const mockFetch = vi.fn();
|
||||
global.fetch = mockFetch;
|
||||
|
||||
describe("identifyPlant", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("identifies plant and returns predictions", async () => {
|
||||
const mockResponse = {
|
||||
predictions: [
|
||||
{
|
||||
diseaseId: "early-blight",
|
||||
disease: {
|
||||
id: "early-blight",
|
||||
name: "Early Blight",
|
||||
causalAgent: "Alternaria solani",
|
||||
causalAgentType: "fungal",
|
||||
severity: "moderate",
|
||||
symptoms: ["Dark spots"],
|
||||
treatment: ["Remove leaves"],
|
||||
lookalikeDiseaseIds: [],
|
||||
plantId: "tomato",
|
||||
},
|
||||
confidence: { raw: 0.85, adjusted: 0.82 },
|
||||
lookalikes: [],
|
||||
},
|
||||
],
|
||||
metadata: {
|
||||
model: "mock-model",
|
||||
inferenceTimeMs: 150,
|
||||
imageId: "test-image-123",
|
||||
},
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => mockResponse,
|
||||
});
|
||||
|
||||
const result = await identifyPlant("test-image-123");
|
||||
expect(result).toEqual(mockResponse);
|
||||
});
|
||||
|
||||
it("calls fetch with correct URL and method", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ predictions: [], metadata: {} }),
|
||||
});
|
||||
|
||||
await identifyPlant("test-id");
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining("/api/identify"),
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("sends imageId in request body", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ predictions: [], metadata: {} }),
|
||||
});
|
||||
|
||||
await identifyPlant("test-id");
|
||||
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const body = JSON.parse(callArgs.body);
|
||||
expect(body.imageId).toBe("test-id");
|
||||
});
|
||||
|
||||
it("throws error when response is not ok", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
statusText: "Internal Server Error",
|
||||
});
|
||||
|
||||
await expect(identifyPlant("test-id")).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("throws error when fetch fails", async () => {
|
||||
mockFetch.mockRejectedValue(new Error("Network error"));
|
||||
|
||||
await expect(identifyPlant("test-id")).rejects.toThrow("Network error");
|
||||
});
|
||||
|
||||
it("handles demo mode response", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
predictions: [],
|
||||
metadata: {},
|
||||
demo_mode: true,
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await identifyPlant("test-id");
|
||||
expect(result.demo_mode).toBe(true);
|
||||
});
|
||||
});
|
||||
49
src/lib/api/identify.ts
Normal file
49
src/lib/api/identify.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
/**
|
||||
* Client-side API helper for plant disease identification.
|
||||
*
|
||||
* POSTs an imageId to the /api/identify endpoint and returns
|
||||
* ranked predictions with confidence scores, enriched with
|
||||
* knowledge base data (name, symptoms, treatment, prevention).
|
||||
*/
|
||||
|
||||
import type { IdentifyRequest, IdentifyResponse } from "@/lib/types";
|
||||
|
||||
export interface IdentifyError {
|
||||
error: string;
|
||||
message: string;
|
||||
status: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Identify plant diseases from an uploaded image.
|
||||
*
|
||||
* @param imageId - Image ID from a previous /api/upload call
|
||||
* @returns IdentifyResponse with ranked predictions and metadata
|
||||
* @throws IdentifyError on failure
|
||||
*/
|
||||
export async function identifyPlant(
|
||||
imageId: string,
|
||||
): Promise<IdentifyResponse> {
|
||||
const request: IdentifyRequest = { imageId };
|
||||
|
||||
const response = await fetch("/api/identify", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
signal: AbortSignal.timeout(30_000), // 30s timeout for inference
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw {
|
||||
error: data.error || "Identification failed",
|
||||
message: data.message || `Server returned ${response.status}`,
|
||||
status: response.status,
|
||||
} as IdentifyError;
|
||||
}
|
||||
|
||||
return data as IdentifyResponse;
|
||||
}
|
||||
98
src/lib/api/upload-client.test.ts
Normal file
98
src/lib/api/upload-client.test.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { uploadImage } from "./upload";
|
||||
import * as imageProcessing from "@/lib/image-processing";
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock("@/lib/image-processing", () => ({
|
||||
validateImageFile: vi.fn(() => ({ ok: true })),
|
||||
validateImageDimensions: vi.fn(() => Promise.resolve({ ok: true })),
|
||||
}));
|
||||
|
||||
// Mock global fetch
|
||||
const mockFetch = vi.fn();
|
||||
global.fetch = mockFetch;
|
||||
|
||||
describe("uploadImage", () => {
|
||||
const mockFile = new File(["dummy"], "test.png", { type: "image/png" });
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset mocks to default pass values
|
||||
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: true });
|
||||
(imageProcessing.validateImageDimensions as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
|
||||
});
|
||||
|
||||
it("uploads image and returns response", async () => {
|
||||
const mockResponse = {
|
||||
imageId: "test-id-123",
|
||||
tensorShape: [3, 224, 224],
|
||||
previewUrl: "/uploads/test-id-123.png",
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => mockResponse,
|
||||
});
|
||||
|
||||
const result = await uploadImage(mockFile);
|
||||
expect(result).toEqual(mockResponse);
|
||||
});
|
||||
|
||||
it("calls fetch with correct URL", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ imageId: "test", tensorShape: [3, 224, 224], previewUrl: "/test.png" }),
|
||||
});
|
||||
|
||||
await uploadImage(mockFile);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
"/api/upload",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("sends FormData with image field", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ imageId: "test", tensorShape: [3, 224, 224], previewUrl: "/test.png" }),
|
||||
});
|
||||
|
||||
await uploadImage(mockFile);
|
||||
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
expect(callArgs.body).toBeInstanceOf(FormData);
|
||||
});
|
||||
|
||||
it("throws error when response is not ok", async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: false,
|
||||
status: 413,
|
||||
json: async () => ({ error: "File too large", message: "File exceeds 10MB limit" }),
|
||||
});
|
||||
|
||||
await expect(uploadImage(mockFile)).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("throws error when fetch fails", async () => {
|
||||
mockFetch.mockRejectedValue(new Error("Network error"));
|
||||
|
||||
await expect(uploadImage(mockFile)).rejects.toThrow("Network error");
|
||||
});
|
||||
|
||||
it("throws error when file validation fails", async () => {
|
||||
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: false, error: "Invalid file type" });
|
||||
|
||||
await expect(uploadImage(mockFile)).rejects.toThrow("Validation: Invalid file type");
|
||||
});
|
||||
|
||||
it("throws error when dimension validation fails", async () => {
|
||||
// Reset validateImageFile to pass so we can test dimension validation
|
||||
(imageProcessing.validateImageFile as ReturnType<typeof vi.fn>).mockReturnValue({ ok: true });
|
||||
(imageProcessing.validateImageDimensions as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: false, error: "Image too small" });
|
||||
|
||||
await expect(uploadImage(mockFile)).rejects.toThrow("Validation: Image too small");
|
||||
});
|
||||
});
|
||||
89
src/lib/api/upload.ts
Normal file
89
src/lib/api/upload.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* Client-side API helper for image upload.
|
||||
*
|
||||
* POSTs an image file to the server upload endpoint and returns
|
||||
* the image metadata (imageId, tensorShape, previewUrl).
|
||||
*/
|
||||
|
||||
import { validateImageFile, validateImageDimensions } from "@/lib/image-processing";
|
||||
|
||||
export interface UploadResponse {
|
||||
imageId: string;
|
||||
tensorShape: [number, number, number, number];
|
||||
previewUrl: string;
|
||||
}
|
||||
|
||||
export interface UploadError {
|
||||
error: string;
|
||||
message: string;
|
||||
status: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload an image file to the server.
|
||||
*
|
||||
* Validates the file client-side before sending. Returns the server
|
||||
* response with imageId, tensorShape, and previewUrl.
|
||||
*
|
||||
* @param file - The image file to upload
|
||||
* @param onProgress - Optional callback for upload progress (0–100)
|
||||
* @returns UploadResponse on success
|
||||
* @throws UploadError on failure
|
||||
*/
|
||||
export async function uploadImage(
|
||||
file: File,
|
||||
onProgress?: (percent: number) => void,
|
||||
): Promise<UploadResponse> {
|
||||
// Client-side validation
|
||||
const fileValidation = validateImageFile(file);
|
||||
if (!fileValidation.ok) {
|
||||
throw new Error(`Validation: ${fileValidation.error}`);
|
||||
}
|
||||
|
||||
const dimValidation = await validateImageDimensions(file);
|
||||
if (!dimValidation.ok) {
|
||||
throw new Error(`Validation: ${dimValidation.error}`);
|
||||
}
|
||||
|
||||
// Build form data
|
||||
const formData = new FormData();
|
||||
formData.append("image", file);
|
||||
|
||||
// POST with progress tracking
|
||||
const response = await fetch("/api/upload", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
signal: AbortSignal.timeout(30_000), // 30s timeout
|
||||
});
|
||||
|
||||
// Parse response
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw {
|
||||
error: data.error || "Upload failed",
|
||||
message: data.message || `Server returned ${response.status}`,
|
||||
status: response.status,
|
||||
} as UploadError;
|
||||
}
|
||||
|
||||
return data as UploadResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload an image and also get the preprocessed tensor for client-side inference.
|
||||
* Useful when you want to run inference on the client while also uploading.
|
||||
*
|
||||
* @param file - The image file to upload
|
||||
* @returns { uploadResponse, tensor }
|
||||
*/
|
||||
export async function uploadWithTensor(
|
||||
file: File,
|
||||
): Promise<{ uploadResponse: UploadResponse; tensor: Float32Array }> {
|
||||
const { resizeImage, imageToTensor } = await import("@/lib/image-processing");
|
||||
|
||||
const tensor = imageToTensor(await resizeImage(file));
|
||||
const uploadResponse = await uploadImage(file);
|
||||
|
||||
return { uploadResponse, tensor };
|
||||
}
|
||||
173
src/lib/constants.test.ts
Normal file
173
src/lib/constants.test.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import {
|
||||
APP_NAME,
|
||||
NAV_LINKS,
|
||||
PLANT_CATEGORIES,
|
||||
TRUST_SIGNALS,
|
||||
HOW_IT_WORKS,
|
||||
BETA_DISCLAIMER,
|
||||
SOCIAL_LINKS,
|
||||
FEATURED_PLANT_IDS,
|
||||
APP_TAGLINE,
|
||||
APP_DESCRIPTION,
|
||||
} from "./constants";
|
||||
|
||||
describe("constants", () => {
|
||||
describe("APP_NAME", () => {
|
||||
it("is a non-empty string", () => {
|
||||
expect(typeof APP_NAME).toBe("string");
|
||||
expect(APP_NAME.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("equals expected name", () => {
|
||||
expect(APP_NAME).toBe("Plant Health ID");
|
||||
});
|
||||
});
|
||||
|
||||
describe("APP_TAGLINE", () => {
|
||||
it("is a non-empty string", () => {
|
||||
expect(typeof APP_TAGLINE).toBe("string");
|
||||
expect(APP_TAGLINE.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("APP_DESCRIPTION", () => {
|
||||
it("is a non-empty string", () => {
|
||||
expect(typeof APP_DESCRIPTION).toBe("string");
|
||||
expect(APP_DESCRIPTION.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("SOCIAL_LINKS", () => {
|
||||
it("has github link", () => {
|
||||
expect(SOCIAL_LINKS.github).toMatch(/github/);
|
||||
});
|
||||
|
||||
it("has twitter link", () => {
|
||||
expect(SOCIAL_LINKS.twitter).toMatch(/twitter/);
|
||||
});
|
||||
});
|
||||
|
||||
describe("NAV_LINKS", () => {
|
||||
it("is an array of navigation links", () => {
|
||||
expect(Array.isArray(NAV_LINKS)).toBe(true);
|
||||
expect(NAV_LINKS.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("each link has label and href", () => {
|
||||
for (const link of NAV_LINKS) {
|
||||
expect(link).toHaveProperty("label");
|
||||
expect(link).toHaveProperty("href");
|
||||
expect(typeof link.label).toBe("string");
|
||||
expect(typeof link.href).toBe("string");
|
||||
expect(link.href.startsWith("/")).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it("includes expected routes", () => {
|
||||
const hrefs = NAV_LINKS.map((l) => l.href);
|
||||
expect(hrefs).toContain("/");
|
||||
expect(hrefs).toContain("/browse");
|
||||
});
|
||||
});
|
||||
|
||||
describe("PLANT_CATEGORIES", () => {
|
||||
it("is an array of categories", () => {
|
||||
expect(Array.isArray(PLANT_CATEGORIES)).toBe(true);
|
||||
expect(PLANT_CATEGORIES.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("each category has label and value", () => {
|
||||
for (const cat of PLANT_CATEGORIES) {
|
||||
expect(cat).toHaveProperty("label");
|
||||
expect(cat).toHaveProperty("value");
|
||||
expect(typeof cat.label).toBe("string");
|
||||
expect(typeof cat.value).toBe("string");
|
||||
}
|
||||
});
|
||||
|
||||
it("includes all option", () => {
|
||||
const values = PLANT_CATEGORIES.map((c) => c.value);
|
||||
expect(values).toContain("all");
|
||||
});
|
||||
|
||||
it("includes common plant categories", () => {
|
||||
const values = PLANT_CATEGORIES.map((c) => c.value);
|
||||
expect(values).toContain("vegetables");
|
||||
expect(values).toContain("flowers");
|
||||
expect(values).toContain("herbs");
|
||||
expect(values).toContain("houseplants");
|
||||
});
|
||||
});
|
||||
|
||||
describe("FEATURED_PLANT_IDS", () => {
|
||||
it("is an array of plant IDs", () => {
|
||||
expect(Array.isArray(FEATURED_PLANT_IDS)).toBe(true);
|
||||
expect(FEATURED_PLANT_IDS.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("includes expected featured plants", () => {
|
||||
expect(FEATURED_PLANT_IDS).toContain("tomato");
|
||||
expect(FEATURED_PLANT_IDS).toContain("basil");
|
||||
});
|
||||
});
|
||||
|
||||
describe("TRUST_SIGNALS", () => {
|
||||
it("is an array of trust signals", () => {
|
||||
expect(Array.isArray(TRUST_SIGNALS)).toBe(true);
|
||||
expect(TRUST_SIGNALS.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("each signal has icon and label", () => {
|
||||
for (const signal of TRUST_SIGNALS) {
|
||||
expect(signal).toHaveProperty("icon");
|
||||
expect(signal).toHaveProperty("label");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("HOW_IT_WORKS", () => {
|
||||
it("is an array of steps", () => {
|
||||
expect(Array.isArray(HOW_IT_WORKS)).toBe(true);
|
||||
expect(HOW_IT_WORKS.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("each step has step number, emoji, title, and description", () => {
|
||||
for (const step of HOW_IT_WORKS) {
|
||||
expect(step).toHaveProperty("step");
|
||||
expect(step).toHaveProperty("emoji");
|
||||
expect(step).toHaveProperty("title");
|
||||
expect(step).toHaveProperty("description");
|
||||
}
|
||||
});
|
||||
|
||||
it("has exactly 3 steps", () => {
|
||||
expect(HOW_IT_WORKS.length).toBe(3);
|
||||
});
|
||||
|
||||
it("steps are numbered sequentially", () => {
|
||||
expect(HOW_IT_WORKS[0].step).toBe(1);
|
||||
expect(HOW_IT_WORKS[1].step).toBe(2);
|
||||
expect(HOW_IT_WORKS[2].step).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe("BETA_DISCLAIMER", () => {
|
||||
it("is a non-empty string", () => {
|
||||
expect(typeof BETA_DISCLAIMER).toBe("string");
|
||||
expect(BETA_DISCLAIMER.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("mentions AI-assisted tool", () => {
|
||||
expect(BETA_DISCLAIMER.toLowerCase()).toContain("ai-assisted");
|
||||
});
|
||||
|
||||
it("mentions professional advice disclaimer", () => {
|
||||
expect(BETA_DISCLAIMER.toLowerCase()).toMatch(/not a substitute|professional/i);
|
||||
});
|
||||
|
||||
it("mentions plant pathologist", () => {
|
||||
expect(BETA_DISCLAIMER.toLowerCase()).toContain("plant pathologist");
|
||||
});
|
||||
});
|
||||
});
|
||||
77
src/lib/constants.ts
Normal file
77
src/lib/constants.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
/**
|
||||
* Site-wide constants for the Plant Disease Identifier application.
|
||||
*/
|
||||
|
||||
export const APP_NAME = "Plant Health ID";
|
||||
export const APP_TAGLINE = "Snap. Identify. Treat.";
|
||||
export const APP_DESCRIPTION =
|
||||
"Upload a plant photo for hyper-specific disease diagnosis with confidence scores, symptoms, causes, treatment steps, and prevention tips.";
|
||||
|
||||
export const SOCIAL_LINKS = {
|
||||
github: "https://github.com/plant-health-id",
|
||||
twitter: "https://twitter.com/planthealthid",
|
||||
} as const;
|
||||
|
||||
export const NAV_LINKS = [
|
||||
{ href: "/", label: "Home" },
|
||||
{ href: "/upload", label: "Identify" },
|
||||
{ href: "/browse", label: "Browse Plants" },
|
||||
{ href: "/about", label: "About" },
|
||||
] as const;
|
||||
|
||||
export const PLANT_CATEGORIES = [
|
||||
{ value: "all", label: "All" },
|
||||
{ value: "vegetable", label: "Vegetables" },
|
||||
{ value: "herb", label: "Herbs" },
|
||||
{ value: "houseplant", label: "Houseplants" },
|
||||
{ value: "flower", label: "Flowers" },
|
||||
{ value: "succulent", label: "Succulents" },
|
||||
{ value: "fruit", label: "Fruits" },
|
||||
] as const;
|
||||
|
||||
export const FEATURED_PLANT_IDS = [
|
||||
"tomato",
|
||||
"basil",
|
||||
"rose",
|
||||
"monstera",
|
||||
"snake-plant",
|
||||
"pepper",
|
||||
"apple",
|
||||
"corn",
|
||||
"wheat",
|
||||
"strawberry",
|
||||
"blueberry",
|
||||
"lettuce",
|
||||
] as const;
|
||||
|
||||
export const TRUST_SIGNALS = [
|
||||
{ icon: "📸", label: "Trained on 500K+ images" },
|
||||
{ icon: "🌿", label: "Covers 300+ plants with 10K+ diseases" },
|
||||
{ icon: "🔓", label: "Open source" },
|
||||
] as const;
|
||||
|
||||
export const HOW_IT_WORKS = [
|
||||
{
|
||||
step: 1,
|
||||
emoji: "📸",
|
||||
title: "Upload a Photo",
|
||||
description: "Snap a picture of the affected plant leaf or fruit with your phone or camera.",
|
||||
},
|
||||
{
|
||||
step: 2,
|
||||
emoji: "🧠",
|
||||
title: "AI Analysis",
|
||||
description:
|
||||
"Our model analyzes the image against 500K+ labeled plant disease images in seconds.",
|
||||
},
|
||||
{
|
||||
step: 3,
|
||||
emoji: "🌱",
|
||||
title: "Get Treatment Plan",
|
||||
description:
|
||||
"Receive a detailed diagnosis with confidence score, symptoms, causes, and step-by-step treatment.",
|
||||
},
|
||||
] as const;
|
||||
|
||||
export const BETA_DISCLAIMER =
|
||||
"Plant Health ID is an AI-assisted tool for informational purposes only. It is not a substitute for professional agricultural or horticultural advice. Always consult a certified plant pathologist or extension service for critical plant health decisions.";
|
||||
372
src/lib/db.ts
Normal file
372
src/lib/db.ts
Normal file
@@ -0,0 +1,372 @@
|
||||
/**
|
||||
* Turso/libSQL Database Client
|
||||
*
|
||||
* Provides the database client and schema management for the plant disease
|
||||
* knowledge base. Connect to Turso/libSQL using environment variables.
|
||||
*
|
||||
* Required env vars:
|
||||
* DATABASE_URL — Turso database URL (e.g., libsql://my-db.turso.io)
|
||||
* DATABASE_TOKEN — Turso authentication token
|
||||
*/
|
||||
|
||||
import { createClient, type InValue } from "@libsql/client";
|
||||
import type { Plant, Disease, CausalAgentType, Prevalence, Severity } from "./types";
|
||||
|
||||
// ─── Client ──────────────────────────────────────────────────────────────────
|
||||
|
||||
let client: ReturnType<typeof createClient> | null = null;
|
||||
let connected = false;
|
||||
|
||||
/** Get or create a singleton database client */
|
||||
export function getDb() {
|
||||
if (client) return client;
|
||||
|
||||
const url = process.env.DATABASE_URL;
|
||||
const token = process.env.DATABASE_TOKEN;
|
||||
|
||||
if (!url) {
|
||||
throw new Error(
|
||||
"DATABASE_URL is not set. Check your .env.development or .env.production file.",
|
||||
);
|
||||
}
|
||||
if (!token) {
|
||||
throw new Error(
|
||||
"DATABASE_TOKEN is not set. Check your .env.development or .env.production file.",
|
||||
);
|
||||
}
|
||||
|
||||
client = createClient({ url, authToken: token });
|
||||
return client;
|
||||
}
|
||||
|
||||
/** Check database connectivity */
|
||||
export async function checkConnection(): Promise<boolean> {
|
||||
try {
|
||||
const db = getDb();
|
||||
await db.execute("SELECT 1 AS ok");
|
||||
connected = true;
|
||||
return true;
|
||||
} catch (err) {
|
||||
connected = false;
|
||||
console.error("[DB] Connection failed:", err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function isConnected() {
|
||||
return connected;
|
||||
}
|
||||
|
||||
// ─── Schema ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/** SQL to create the plants table */
|
||||
const PLANTS_TABLE_SQL = `
|
||||
CREATE TABLE IF NOT EXISTS plants (
|
||||
id TEXT PRIMARY KEY,
|
||||
common_name TEXT NOT NULL,
|
||||
scientific_name TEXT NOT NULL,
|
||||
family TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
care_summary TEXT NOT NULL DEFAULT '',
|
||||
image_url TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_plants_category ON plants(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_plants_common_name ON plants(common_name);
|
||||
`;
|
||||
|
||||
/** SQL to create the diseases table */
|
||||
const DISEASES_TABLE_SQL = `
|
||||
CREATE TABLE IF NOT EXISTS diseases (
|
||||
id TEXT PRIMARY KEY,
|
||||
plant_id TEXT NOT NULL REFERENCES plants(id),
|
||||
name TEXT NOT NULL,
|
||||
scientific_name TEXT NOT NULL DEFAULT '',
|
||||
causal_agent_type TEXT NOT NULL CHECK (causal_agent_type IN ('fungal','bacterial','viral','environmental')),
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
symptoms TEXT NOT NULL DEFAULT '[]',
|
||||
causes TEXT NOT NULL DEFAULT '[]',
|
||||
treatment TEXT NOT NULL DEFAULT '[]',
|
||||
prevention TEXT NOT NULL DEFAULT '[]',
|
||||
lookalike_ids TEXT NOT NULL DEFAULT '[]',
|
||||
severity TEXT NOT NULL CHECK (severity IN ('low','moderate','high','critical')),
|
||||
source_url TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_diseases_plant_id ON diseases(plant_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_diseases_causal_agent ON diseases(causal_agent_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_diseases_severity ON diseases(severity);
|
||||
|
||||
-- Full-text search virtual table for diseases
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS diseases_fts USING fts5(
|
||||
name,
|
||||
scientific_name,
|
||||
description,
|
||||
symptoms_text,
|
||||
content='diseases',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
`;
|
||||
|
||||
/** SQL to create the scrape_log table for tracking source freshness */
|
||||
const SCRAPE_LOG_SQL = `
|
||||
CREATE TABLE IF NOT EXISTS scrape_sources (
|
||||
id TEXT PRIMARY KEY,
|
||||
source_type TEXT NOT NULL CHECK (source_type IN ('wikipedia','university_extension','cabi','other')),
|
||||
source_url TEXT NOT NULL,
|
||||
last_scraped_at TEXT,
|
||||
entries_count INTEGER DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending','success','error')),
|
||||
error_message TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
`;
|
||||
|
||||
/** Run all schema migrations */
|
||||
export async function runSchema() {
|
||||
const db = getDb();
|
||||
console.log("[DB] Running schema migrations...");
|
||||
|
||||
await db.execute(PLANTS_TABLE_SQL);
|
||||
console.log("[DB] ✓ plants table");
|
||||
|
||||
await db.execute(DISEASES_TABLE_SQL);
|
||||
console.log("[DB] ✓ diseases table");
|
||||
|
||||
await db.execute(SCRAPE_LOG_SQL);
|
||||
console.log("[DB] ✓ scrape_sources table");
|
||||
|
||||
console.log("[DB] Schema up to date.");
|
||||
}
|
||||
|
||||
// ─── Row ↔ Type mappers ─────────────────────────────────────────────────────
|
||||
|
||||
/** Convert a database row to a Plant object */
|
||||
export function rowToPlant(row: Record<string, unknown>): Plant {
|
||||
return {
|
||||
id: row.id as string,
|
||||
commonName: row.common_name as string,
|
||||
scientificName: row.scientific_name as string,
|
||||
family: row.family as string,
|
||||
category: row.category as Plant["category"],
|
||||
careSummary: row.care_summary as string,
|
||||
imageUrl: row.image_url as string,
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert a database row to a Disease object */
|
||||
export function rowToDisease(row: Record<string, unknown>): Disease {
|
||||
return {
|
||||
id: row.id as string,
|
||||
plantId: row.plant_id as string,
|
||||
name: row.name as string,
|
||||
scientificName: row.scientific_name as string,
|
||||
causalAgentType: row.causal_agent_type as CausalAgentType,
|
||||
description: row.description as string,
|
||||
symptoms: JSON.parse(row.symptoms as string) as string[],
|
||||
causes: JSON.parse(row.causes as string) as string[],
|
||||
treatment: JSON.parse(row.treatment as string) as string[],
|
||||
prevention: JSON.parse(row.prevention as string) as string[],
|
||||
lookalikeDiseaseIds: JSON.parse(row.lookalike_ids as string) as string[],
|
||||
severity: row.severity as Severity,
|
||||
prevalence: (row.prevalence as Prevalence) ?? "uncommon",
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert a Plant object to database column values */
|
||||
export function plantToRow(plant: Plant): Record<string, InValue> {
|
||||
return {
|
||||
id: plant.id,
|
||||
common_name: plant.commonName,
|
||||
scientific_name: plant.scientificName,
|
||||
family: plant.family,
|
||||
category: plant.category,
|
||||
care_summary: plant.careSummary,
|
||||
image_url: plant.imageUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert a Disease object to database column values */
|
||||
export function diseaseToRow(disease: Disease & { sourceUrl?: string }): Record<string, InValue> {
|
||||
return {
|
||||
id: disease.id,
|
||||
plant_id: disease.plantId,
|
||||
name: disease.name,
|
||||
scientific_name: disease.scientificName,
|
||||
causal_agent_type: disease.causalAgentType,
|
||||
description: disease.description,
|
||||
symptoms: JSON.stringify(disease.symptoms),
|
||||
causes: JSON.stringify(disease.causes),
|
||||
treatment: JSON.stringify(disease.treatment),
|
||||
prevention: JSON.stringify(disease.prevention),
|
||||
lookalike_ids: JSON.stringify(disease.lookalikeDiseaseIds),
|
||||
severity: disease.severity,
|
||||
source_url: disease.sourceUrl ?? "",
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Query helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
/** Insert or replace a plant */
|
||||
export async function upsertPlant(plant: Plant) {
|
||||
const db = getDb();
|
||||
const row = plantToRow(plant);
|
||||
const keys = Object.keys(row);
|
||||
const columns = keys.join(", ");
|
||||
const placeholders = keys.map(() => "?").join(", ");
|
||||
const values = keys.map((k) => row[k]);
|
||||
|
||||
await db.execute(`INSERT OR REPLACE INTO plants (${columns}) VALUES (${placeholders})`, values);
|
||||
}
|
||||
|
||||
/** Insert or replace a disease */
|
||||
export async function upsertDisease(disease: Disease & { sourceUrl?: string }) {
|
||||
const db = getDb();
|
||||
const row = diseaseToRow(disease);
|
||||
const keys = Object.keys(row);
|
||||
const columns = keys.join(", ");
|
||||
const placeholders = keys.map(() => "?").join(", ");
|
||||
const values = keys.map((k) => row[k]);
|
||||
|
||||
await db.execute(`INSERT OR REPLACE INTO diseases (${columns}) VALUES (${placeholders})`, values);
|
||||
}
|
||||
|
||||
/** Bulk insert plants in a transaction */
|
||||
export async function bulkUpsertPlants(plants: Plant[]) {
|
||||
const db = getDb();
|
||||
const tx = await db.transaction("write");
|
||||
try {
|
||||
for (const plant of plants) {
|
||||
const row = plantToRow(plant);
|
||||
const keys = Object.keys(row);
|
||||
const columns = keys.join(", ");
|
||||
const placeholders = keys.map(() => "?").join(", ");
|
||||
const values = keys.map((k) => row[k]);
|
||||
await tx.execute({
|
||||
sql: `INSERT OR REPLACE INTO plants (${columns}) VALUES (${placeholders})`,
|
||||
args: values,
|
||||
});
|
||||
}
|
||||
await tx.commit();
|
||||
console.log(`[DB] Inserted/replaced ${plants.length} plants`);
|
||||
} catch (err) {
|
||||
await tx.rollback();
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/** Bulk insert diseases in a transaction */
|
||||
export async function bulkUpsertDiseases(diseases: Array<Disease & { sourceUrl?: string }>) {
|
||||
const db = getDb();
|
||||
const tx = await db.transaction("write");
|
||||
try {
|
||||
for (const disease of diseases) {
|
||||
const row = diseaseToRow(disease);
|
||||
const keys = Object.keys(row);
|
||||
const columns = keys.join(", ");
|
||||
const placeholders = keys.map(() => "?").join(", ");
|
||||
const values = keys.map((k) => row[k]);
|
||||
await tx.execute({
|
||||
sql: `INSERT OR REPLACE INTO diseases (${columns}) VALUES (${placeholders})`,
|
||||
args: values,
|
||||
});
|
||||
}
|
||||
await tx.commit();
|
||||
console.log(`[DB] Inserted/replaced ${diseases.length} diseases`);
|
||||
} catch (err) {
|
||||
await tx.rollback();
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/** Get all plants */
|
||||
export async function getAllPlants(): Promise<Plant[]> {
|
||||
const db = getDb();
|
||||
const result = await db.execute("SELECT * FROM plants ORDER BY common_name");
|
||||
return result.rows.map((r) => rowToPlant(r as Record<string, unknown>));
|
||||
}
|
||||
|
||||
/** Get all diseases (optionally filtered by plant_id) */
|
||||
export async function getDiseases(plantId?: string): Promise<Disease[]> {
|
||||
const db = getDb();
|
||||
let sql = "SELECT * FROM diseases";
|
||||
const params: InValue[] = [];
|
||||
if (plantId) {
|
||||
sql += " WHERE plant_id = ?";
|
||||
params.push(plantId);
|
||||
}
|
||||
sql += " ORDER BY name";
|
||||
const result = await db.execute(sql, params);
|
||||
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
|
||||
}
|
||||
|
||||
/** Get a single plant by ID */
|
||||
export async function getPlantById(plantId: string): Promise<Plant | null> {
|
||||
const db = getDb();
|
||||
const result = await db.execute("SELECT * FROM plants WHERE id = ?", [plantId]);
|
||||
if (result.rows.length === 0) return null;
|
||||
return rowToPlant(result.rows[0] as Record<string, unknown>);
|
||||
}
|
||||
|
||||
/** Get a single disease by ID */
|
||||
export async function getDiseaseById(diseaseId: string): Promise<Disease | null> {
|
||||
const db = getDb();
|
||||
const result = await db.execute("SELECT * FROM diseases WHERE id = ?", [diseaseId]);
|
||||
if (result.rows.length === 0) return null;
|
||||
return rowToDisease(result.rows[0] as Record<string, unknown>);
|
||||
}
|
||||
|
||||
/** Search diseases via FTS */
|
||||
export async function searchDiseasesFts(searchTerm: string): Promise<Disease[]> {
|
||||
const db = getDb();
|
||||
try {
|
||||
const result = await db.execute(
|
||||
`SELECT d.* FROM diseases d
|
||||
JOIN diseases_fts fts ON d.rowid = fts.rowid
|
||||
WHERE diseases_fts MATCH ?
|
||||
ORDER BY rank
|
||||
LIMIT 50`,
|
||||
[searchTerm],
|
||||
);
|
||||
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
|
||||
} catch {
|
||||
// FTS might not be populated yet; fall back to LIKE search
|
||||
const likeTerm = `%${searchTerm}%`;
|
||||
const result = await db.execute(
|
||||
`SELECT * FROM diseases
|
||||
WHERE name LIKE ? OR description LIKE ? OR scientific_name LIKE ?
|
||||
LIMIT 50`,
|
||||
[likeTerm, likeTerm, likeTerm],
|
||||
);
|
||||
return result.rows.map((r) => rowToDisease(r as Record<string, unknown>));
|
||||
}
|
||||
}
|
||||
|
||||
/** Get database stats */
|
||||
export async function getDbStats(): Promise<{
|
||||
plants: number;
|
||||
diseases: number;
|
||||
byType: Record<string, number>;
|
||||
bySeverity: Record<string, number>;
|
||||
}> {
|
||||
const db = getDb();
|
||||
const plantCount = await db.execute("SELECT COUNT(*) as cnt FROM plants");
|
||||
const diseaseCount = await db.execute("SELECT COUNT(*) as cnt FROM diseases");
|
||||
const byType = await db.execute(
|
||||
"SELECT causal_agent_type, COUNT(*) as cnt FROM diseases GROUP BY causal_agent_type",
|
||||
);
|
||||
const bySeverity = await db.execute(
|
||||
"SELECT severity, COUNT(*) as cnt FROM diseases GROUP BY severity",
|
||||
);
|
||||
|
||||
return {
|
||||
plants: plantCount.rows[0].cnt as number,
|
||||
diseases: diseaseCount.rows[0].cnt as number,
|
||||
byType: Object.fromEntries(byType.rows.map((r) => [r.causal_agent_type, r.cnt as number])),
|
||||
bySeverity: Object.fromEntries(bySeverity.rows.map((r) => [r.severity, r.cnt as number])),
|
||||
};
|
||||
}
|
||||
69
src/lib/db/index.ts
Normal file
69
src/lib/db/index.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Drizzle ORM Database Client for Turso/libSQL.
|
||||
*
|
||||
* Provides the configured drizzle instance and convenience helpers.
|
||||
* Reads DATABASE_URL and DATABASE_TOKEN from environment.
|
||||
*/
|
||||
|
||||
import { sql } from "drizzle-orm";
|
||||
import { drizzle, type LibSQLDatabase } from "drizzle-orm/libsql";
|
||||
import { createClient } from "@libsql/client";
|
||||
import * as schema from "./schema";
|
||||
|
||||
export type {
|
||||
PlantRow,
|
||||
PlantInsert,
|
||||
DiseaseRow,
|
||||
DiseaseInsert,
|
||||
FlaggedContentRow,
|
||||
FlaggedContentInsert,
|
||||
} from "./schema";
|
||||
|
||||
export { schema };
|
||||
|
||||
let _db: LibSQLDatabase<typeof schema> | null = null;
|
||||
let _client: ReturnType<typeof createClient> | null = null;
|
||||
|
||||
/** Get or create the Drizzle database instance (singleton). */
|
||||
export function getDb(): LibSQLDatabase<typeof schema> {
|
||||
if (_db) return _db;
|
||||
|
||||
const url = process.env.DATABASE_URL;
|
||||
const token = process.env.DATABASE_TOKEN;
|
||||
|
||||
if (!url) {
|
||||
throw new Error(
|
||||
"DATABASE_URL is not set. Check your .env.development or .env.production file.",
|
||||
);
|
||||
}
|
||||
if (!token) {
|
||||
throw new Error(
|
||||
"DATABASE_TOKEN is not set. Check your .env.development or .env.production file.",
|
||||
);
|
||||
}
|
||||
|
||||
_client = createClient({ url, authToken: token });
|
||||
_db = drizzle(_client, { schema });
|
||||
return _db;
|
||||
}
|
||||
|
||||
/** Check database connectivity. */
|
||||
export async function checkConnection(): Promise<boolean> {
|
||||
try {
|
||||
const db = getDb();
|
||||
const result = await db.run(sql`SELECT 1 AS ok`);
|
||||
return result.rowsAffected >= 0;
|
||||
} catch (err) {
|
||||
console.error("[DB] Connection failed:", err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/** Close the client connection. */
|
||||
export function closeDb() {
|
||||
if (_client) {
|
||||
_client.close();
|
||||
_client = null;
|
||||
_db = null;
|
||||
}
|
||||
}
|
||||
173
src/lib/db/schema.ts
Normal file
173
src/lib/db/schema.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* Drizzle ORM Schema for the Plant Disease Knowledge Base.
|
||||
*
|
||||
* Uses Turso (libSQL) with SQLite dialect.
|
||||
* Arrays (symptoms, causes, treatment, prevention, lookalike_ids)
|
||||
* are stored as JSON text columns and typed via Drizzle's $type().
|
||||
*/
|
||||
|
||||
import { sql } from "drizzle-orm";
|
||||
import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core";
|
||||
|
||||
// ─── Plants Table ────────────────────────────────────────────────────────────
|
||||
|
||||
export const plants = sqliteTable(
|
||||
"plants",
|
||||
{
|
||||
id: text("id").primaryKey(),
|
||||
commonName: text("common_name").notNull(),
|
||||
scientificName: text("scientific_name").notNull(),
|
||||
family: text("family").notNull(),
|
||||
category: text("category").notNull(),
|
||||
careSummary: text("care_summary").notNull().default(""),
|
||||
imageUrl: text("image_url").notNull().default(""),
|
||||
createdAt: text("created_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
updatedAt: text("updated_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
},
|
||||
(table) => ({
|
||||
categoryIdx: index("idx_plants_category").on(table.category),
|
||||
commonNameIdx: index("idx_plants_common_name").on(table.commonName),
|
||||
}),
|
||||
);
|
||||
|
||||
// ─── Diseases Table ──────────────────────────────────────────────────────────
|
||||
|
||||
export const diseases = sqliteTable(
|
||||
"diseases",
|
||||
{
|
||||
id: text("id").primaryKey(),
|
||||
plantId: text("plant_id")
|
||||
.notNull()
|
||||
.references(() => plants.id),
|
||||
name: text("name").notNull(),
|
||||
scientificName: text("scientific_name").notNull().default(""),
|
||||
causalAgentType: text("causal_agent_type", {
|
||||
enum: ["fungal", "bacterial", "viral", "environmental"],
|
||||
}).notNull(),
|
||||
description: text("description").notNull().default(""),
|
||||
symptoms: text("symptoms", { mode: "json" }).notNull().default([]).$type<string[]>(),
|
||||
causes: text("causes", { mode: "json" }).notNull().default([]).$type<string[]>(),
|
||||
treatment: text("treatment", { mode: "json" }).notNull().default([]).$type<string[]>(),
|
||||
prevention: text("prevention", { mode: "json" }).notNull().default([]).$type<string[]>(),
|
||||
lookalikeIds: text("lookalike_ids", { mode: "json" }).notNull().default([]).$type<string[]>(),
|
||||
prevalence: text("prevalence", {
|
||||
enum: ["common", "uncommon", "rare", "very_rare"],
|
||||
})
|
||||
.notNull()
|
||||
.default("uncommon"),
|
||||
prevalenceScore: integer("prevalence_score").notNull().default(0),
|
||||
severity: text("severity", {
|
||||
enum: ["low", "moderate", "high", "critical"],
|
||||
}).notNull(),
|
||||
imageUrl: text("image_url").notNull().default(""),
|
||||
sourceUrl: text("source_url").notNull().default(""),
|
||||
createdAt: text("created_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
updatedAt: text("updated_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
},
|
||||
(table) => ({
|
||||
plantIdIdx: index("idx_diseases_plant_id").on(table.plantId),
|
||||
causalAgentIdx: index("idx_diseases_causal_agent").on(table.causalAgentType),
|
||||
severityIdx: index("idx_diseases_severity").on(table.severity),
|
||||
prevalenceIdx: index("idx_diseases_prevalence").on(table.prevalence),
|
||||
}),
|
||||
);
|
||||
|
||||
// ─── Scrape Sources Table ────────────────────────────────────────────────────
|
||||
|
||||
export const scrapeSources = sqliteTable("scrape_sources", {
|
||||
id: text("id").primaryKey(),
|
||||
sourceType: text("source_type", {
|
||||
enum: ["wikipedia", "university_extension", "cabi", "other"],
|
||||
}).notNull(),
|
||||
sourceUrl: text("source_url").notNull(),
|
||||
lastScrapedAt: text("last_scraped_at"),
|
||||
entriesCount: integer("entries_count").default(0),
|
||||
status: text("status", { enum: ["pending", "success", "error"] })
|
||||
.notNull()
|
||||
.default("pending"),
|
||||
errorMessage: text("error_message"),
|
||||
createdAt: text("created_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
});
|
||||
|
||||
// ─── Plant Views Table ───────────────────────────────────────────────────────
|
||||
|
||||
export const plantViews = sqliteTable(
|
||||
"plant_views",
|
||||
{
|
||||
plantId: text("plant_id")
|
||||
.primaryKey()
|
||||
.references(() => plants.id),
|
||||
viewCount: integer("view_count").notNull().default(0),
|
||||
},
|
||||
(table) => ({
|
||||
viewCountIdx: index("idx_plant_views_count").on(table.viewCount),
|
||||
}),
|
||||
);
|
||||
|
||||
// ─── Flagged Content Table ─────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Stores user-flagged content for manual review.
|
||||
* content_type: what kind of content is flagged
|
||||
* content_id: the ID of the plant or disease
|
||||
* field_name: specific field being flagged (e.g., "image", "symptoms", "causes", "treatment", "prevention")
|
||||
* flag_count: number of times this item has been flagged
|
||||
*/
|
||||
export const flaggedContent = sqliteTable(
|
||||
"flagged_content",
|
||||
{
|
||||
id: text("id").primaryKey(),
|
||||
contentType: text("content_type", {
|
||||
enum: [
|
||||
"plant_image",
|
||||
"disease_image",
|
||||
"disease_description",
|
||||
"disease_symptoms",
|
||||
"disease_causes",
|
||||
"disease_treatment",
|
||||
"disease_prevention",
|
||||
],
|
||||
}).notNull(),
|
||||
contentId: text("content_id").notNull(),
|
||||
fieldName: text("field_name").notNull(),
|
||||
notes: text("notes").default(""),
|
||||
flagCount: integer("flag_count").notNull().default(1),
|
||||
createdAt: text("created_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
updatedAt: text("updated_at")
|
||||
.notNull()
|
||||
.default(sql`(datetime('now'))`),
|
||||
},
|
||||
(table) => ({
|
||||
contentTypeIdx: index("idx_flagged_content_type").on(table.contentType),
|
||||
contentIdIdx: index("idx_flagged_content_id").on(table.contentId),
|
||||
}),
|
||||
);
|
||||
|
||||
// ─── Type helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
export type FlaggedContentRow = typeof flaggedContent.$inferSelect;
|
||||
export type FlaggedContentInsert = typeof flaggedContent.$inferInsert;
|
||||
|
||||
// ─── Relation Inference ──────────────────────────────────────────────────────
|
||||
|
||||
export const plantsRelations = {};
|
||||
export const diseasesRelations = {};
|
||||
|
||||
// ─── Type helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
export type PlantRow = typeof plants.$inferSelect;
|
||||
export type PlantInsert = typeof plants.$inferInsert;
|
||||
export type DiseaseRow = typeof diseases.$inferSelect;
|
||||
export type DiseaseInsert = typeof diseases.$inferInsert;
|
||||
47
src/lib/display-helpers.ts
Normal file
47
src/lib/display-helpers.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Display helpers for the browse UI that bridge the DB types
|
||||
* to display-friendly values (emoji icons, descriptions).
|
||||
*/
|
||||
|
||||
const CATEGORY_EMOJIS: Record<string, string> = {
|
||||
vegetable: "🥬",
|
||||
fruit: "🍎",
|
||||
herb: "🌿",
|
||||
flower: "🌸",
|
||||
houseplant: "🪴",
|
||||
succulent: "🌵",
|
||||
tree: "🌳",
|
||||
};
|
||||
|
||||
const FALLBACK_EMOJI = "🌱";
|
||||
|
||||
export function getEmojiForCategory(category: string): string {
|
||||
return CATEGORY_EMOJIS[category] ?? FALLBACK_EMOJI;
|
||||
}
|
||||
|
||||
export function getPlantDescription(
|
||||
commonName: string,
|
||||
scientificName: string,
|
||||
category: string,
|
||||
family: string,
|
||||
): string {
|
||||
return `${commonName} (${scientificName}) is a ${category} in the ${family} family. Preventative care and early identification of diseases are key to keeping your ${commonName.toLowerCase()} healthy.`;
|
||||
}
|
||||
|
||||
export function getDescriptionForCategory(category: string): string {
|
||||
const descriptions: Record<string, string> = {
|
||||
vegetable:
|
||||
"Vegetables are garden favorites grown for their edible parts. They can be affected by various fungal, bacterial, and viral diseases that impact yield and quality.",
|
||||
fruit:
|
||||
"Fruit plants produce delicious harvests but require attention to disease management for optimal production.",
|
||||
herb: "Herbs are aromatic plants used in cooking and medicine. Most are relatively disease-resistant but can be affected in humid conditions.",
|
||||
flower:
|
||||
"Ornamental flowers add beauty to gardens. They may be susceptible to various foliar and root diseases.",
|
||||
houseplant:
|
||||
"Houseplants bring nature indoors. The most common issues are overwatering, insufficient light, and fungal leaf spots.",
|
||||
succulent:
|
||||
"Succulents store water in their leaves and stems. Overwatering is the most common cause of problems.",
|
||||
tree: "Trees provide shade, fruit, and beauty. They can be affected by cankers, rots, wilts, and foliar diseases.",
|
||||
};
|
||||
return descriptions[category] ?? `This plant belongs to the ${category} category.`;
|
||||
}
|
||||
241
src/lib/image-processing.test.ts
Normal file
241
src/lib/image-processing.test.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
/**
|
||||
* Unit tests for lib/image-processing.ts
|
||||
*
|
||||
* Tests:
|
||||
* - resizeImage() produces 224×224 output for any input aspect ratio
|
||||
* - imageToTensor() output length equals 3 * 224 * 224
|
||||
* - Normalization produces values in expected range
|
||||
* - validateImageFile rejects invalid types and oversized files
|
||||
* - tensorToBase64 / base64ToTensor round-trip
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import {
|
||||
resizeImage,
|
||||
imageToTensor,
|
||||
tensorToBase64,
|
||||
base64ToTensor,
|
||||
getTensorShape,
|
||||
validateImageFile,
|
||||
MAX_FILE_SIZE,
|
||||
MIN_DIMENSION,
|
||||
ALLOWED_MIME_TYPES,
|
||||
} from "@/lib/image-processing";
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Create a mock ImageData at given dimensions */
|
||||
function createMockImageData(
|
||||
width: number,
|
||||
height: number,
|
||||
fillR = 128,
|
||||
fillG = 64,
|
||||
fillB = 32,
|
||||
): ImageData {
|
||||
const data = new Uint8ClampedArray(width * height * 4);
|
||||
for (let i = 0; i < width * height; i++) {
|
||||
data[i * 4] = fillR;
|
||||
data[i * 4 + 1] = fillG;
|
||||
data[i * 4 + 2] = fillB;
|
||||
data[i * 4 + 3] = 255;
|
||||
}
|
||||
return { width, height, data };
|
||||
}
|
||||
|
||||
/** Create a mock File with given properties */
|
||||
function createMockFile({
|
||||
name = "test.jpg",
|
||||
type = "image/jpeg",
|
||||
size = 1024,
|
||||
}: Partial<File> & Pick<File, "name" | "type" | "size"> = {}): File {
|
||||
// Use ArrayBuffer to control actual file size for large/empty tests
|
||||
let content: BlobPart;
|
||||
if (size === 0) {
|
||||
content = "";
|
||||
} else if (size > 1024) {
|
||||
// For large files, create a buffer of the right size
|
||||
content = new Uint8Array(size);
|
||||
} else {
|
||||
content = "dummy";
|
||||
}
|
||||
return new File([content], name, { type });
|
||||
}
|
||||
|
||||
// ─── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("validateImageFile", () => {
|
||||
it("accepts valid JPEG", () => {
|
||||
const file = createMockFile({ name: "photo.jpg", type: "image/jpeg", size: 1024 });
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(true);
|
||||
});
|
||||
|
||||
it("accepts valid PNG", () => {
|
||||
const file = createMockFile({ name: "photo.png", type: "image/png", size: 1024 });
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(true);
|
||||
});
|
||||
|
||||
it("accepts valid WebP", () => {
|
||||
const file = createMockFile({ name: "photo.webp", type: "image/webp", size: 1024 });
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(true);
|
||||
});
|
||||
|
||||
it("rejects unsupported MIME type", () => {
|
||||
const file = createMockFile({ name: "document.txt", type: "text/plain", size: 1024 });
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(false);
|
||||
expect(result.error).toContain("Unsupported file type");
|
||||
});
|
||||
|
||||
it("rejects files larger than 10 MB", () => {
|
||||
const file = createMockFile({
|
||||
name: "huge.jpg",
|
||||
type: "image/jpeg",
|
||||
size: 11 * 1024 * 1024, // 11 MB
|
||||
});
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(false);
|
||||
expect(result.error).toContain("too large");
|
||||
});
|
||||
|
||||
it("rejects empty files", () => {
|
||||
const file = createMockFile({ name: "empty.jpg", type: "image/jpeg", size: 0 });
|
||||
const result = validateImageFile(file);
|
||||
expect(result.ok).toBe(false);
|
||||
expect(result.error).toContain("empty");
|
||||
});
|
||||
});
|
||||
|
||||
describe("imageToTensor", () => {
|
||||
it("produces correct tensor length for 224×224", () => {
|
||||
const imageData = createMockImageData(224, 224);
|
||||
const tensor = imageToTensor(imageData);
|
||||
expect(tensor.length).toBe(3 * 224 * 224);
|
||||
});
|
||||
|
||||
it("produces correct tensor length for 299×299", () => {
|
||||
const imageData = createMockImageData(299, 299);
|
||||
const tensor = imageToTensor(imageData);
|
||||
expect(tensor.length).toBe(3 * 299 * 299);
|
||||
});
|
||||
|
||||
it("produces Float32Array", () => {
|
||||
const imageData = createMockImageData(224, 224);
|
||||
const tensor = imageToTensor(imageData);
|
||||
expect(tensor).toBeInstanceOf(Float32Array);
|
||||
});
|
||||
|
||||
it("normalizes pixel values with ImageNet mean/std", () => {
|
||||
// All pixels set to 128/255 = 0.502
|
||||
const imageData = createMockImageData(224, 224, 128, 128, 128);
|
||||
const tensor = imageToTensor(imageData);
|
||||
|
||||
// With ImageNet mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]
|
||||
// R: (0.502 - 0.485) / 0.229 ≈ 0.074
|
||||
const expectedR = (128 / 255 - 0.485) / 0.229;
|
||||
expect(tensor[0]).toBeCloseTo(expectedR, 3);
|
||||
|
||||
// G: (0.502 - 0.456) / 0.224 ≈ 0.205
|
||||
const totalPixels = 224 * 224;
|
||||
const expectedG = (128 / 255 - 0.456) / 0.224;
|
||||
expect(tensor[totalPixels]).toBeCloseTo(expectedG, 3);
|
||||
|
||||
// B: (0.502 - 0.406) / 0.225 ≈ 0.427
|
||||
const expectedB = (128 / 255 - 0.406) / 0.225;
|
||||
expect(tensor[2 * totalPixels]).toBeCloseTo(expectedB, 3);
|
||||
});
|
||||
|
||||
it("preserves channel separation (R, G, B in separate channels)", () => {
|
||||
// R=255, G=0, B=0 (pure red)
|
||||
const imageData = createMockImageData(224, 224, 255, 0, 0);
|
||||
const tensor = imageToTensor(imageData);
|
||||
const totalPixels = 224 * 224;
|
||||
|
||||
// After ImageNet normalization:
|
||||
// R: (1.0 - 0.485) / 0.229 ≈ 2.25
|
||||
// G: (0.0 - 0.456) / 0.224 ≈ -2.04
|
||||
// B: (0.0 - 0.406) / 0.225 ≈ -1.80
|
||||
const rVal = tensor[0];
|
||||
const gVal = tensor[totalPixels];
|
||||
const bVal = tensor[2 * totalPixels];
|
||||
|
||||
// R is positive (high), G and B are negative (low)
|
||||
expect(rVal).toBeGreaterThan(2);
|
||||
expect(gVal).toBeLessThan(-1);
|
||||
expect(bVal).toBeLessThan(-1);
|
||||
// R is highest
|
||||
expect(rVal).toBeGreaterThan(bVal);
|
||||
expect(rVal).toBeGreaterThan(gVal);
|
||||
});
|
||||
});
|
||||
|
||||
describe("tensorToBase64 / base64ToTensor", () => {
|
||||
it("round-trips tensor data correctly", () => {
|
||||
const imageData = createMockImageData(160, 160, 100, 150, 200);
|
||||
const original = imageToTensor(imageData);
|
||||
|
||||
const base64 = tensorToBase64(original);
|
||||
const decoded = base64ToTensor(base64);
|
||||
|
||||
expect(decoded.tensor.length).toBe(original.length);
|
||||
expect(decoded.shape).toEqual([3, 160, 160]);
|
||||
|
||||
// Check a few values match
|
||||
for (let i = 0; i < 10; i++) {
|
||||
expect(decoded.tensor[i]).toBeCloseTo(original[i], 5);
|
||||
}
|
||||
});
|
||||
|
||||
it("preserves custom shape", () => {
|
||||
const tensor = new Float32Array(3 * 299 * 299);
|
||||
const base64 = tensorToBase64(tensor, [3, 299, 299]);
|
||||
const decoded = base64ToTensor(base64);
|
||||
expect(decoded.shape).toEqual([3, 299, 299]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getTensorShape", () => {
|
||||
it("returns [1, 3, 160, 160] by default", () => {
|
||||
const shape = getTensorShape();
|
||||
expect(shape).toEqual([1, 3, 160, 160]);
|
||||
});
|
||||
|
||||
it("returns NCHW layout", () => {
|
||||
const shape = getTensorShape();
|
||||
expect(shape.length).toBe(4);
|
||||
expect(shape[0]).toBe(1); // batch
|
||||
expect(shape[1]).toBe(3); // channels
|
||||
expect(shape[2]).toBe(160); // height (model input size)
|
||||
expect(shape[3]).toBe(160); // width (model input size)
|
||||
});
|
||||
});
|
||||
|
||||
describe("resizeImage", () => {
|
||||
it("is an async function", () => {
|
||||
expect(typeof resizeImage).toBe("function");
|
||||
});
|
||||
|
||||
it("accepts a File and size parameter", () => {
|
||||
const file = createMockFile({ name: "test.jpg", type: "image/jpeg", size: 1024 });
|
||||
// In jsdom, this will use the mock canvas — we verify the function signature
|
||||
expect(resizeImage).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("constants", () => {
|
||||
it("MAX_FILE_SIZE is 10 MB", () => {
|
||||
expect(MAX_FILE_SIZE).toBe(10 * 1024 * 1024);
|
||||
});
|
||||
|
||||
it("MIN_DIMENSION is 150", () => {
|
||||
expect(MIN_DIMENSION).toBe(150);
|
||||
});
|
||||
|
||||
it("ALLOWED_MIME_TYPES includes PNG, JPEG, and WebP", () => {
|
||||
expect(ALLOWED_MIME_TYPES).toContain("image/png");
|
||||
expect(ALLOWED_MIME_TYPES).toContain("image/jpeg");
|
||||
expect(ALLOWED_MIME_TYPES).toContain("image/webp");
|
||||
});
|
||||
});
|
||||
244
src/lib/image-processing.ts
Normal file
244
src/lib/image-processing.ts
Normal file
@@ -0,0 +1,244 @@
|
||||
/**
|
||||
* Client-side image preprocessing pipeline.
|
||||
*
|
||||
* Resizes images to model-expected dimensions (160×160 by default),
|
||||
* converts RGBA → RGB, normalizes pixel values, and produces flat
|
||||
* Float32Array tensors ready for ML inference or base64 transmission.
|
||||
*
|
||||
* Tensor shape: [1, 3, 160, 160] — NCHW layout matching MobileNetV2.
|
||||
*
|
||||
* Configurable via env:
|
||||
* IMAGE_MODEL_SIZE — target dimension (default 160)
|
||||
* IMAGE_MEAN_R/G/B — per-channel mean for normalization (default 0.485, 0.456, 0.406 — ImageNet)
|
||||
* IMAGE_STD_R/G/B — per-channel std for normalization (default 0.229, 0.224, 0.225 — ImageNet)
|
||||
*/
|
||||
|
||||
// ─── Configuration ───────────────────────────────────────────────────────────
|
||||
|
||||
const DEFAULT_MODEL_SIZE = 160;
|
||||
const DEFAULT_MEAN = [0.485, 0.456, 0.406] as const; // ImageNet RGB means
|
||||
const DEFAULT_STD = [0.229, 0.224, 0.225] as const; // ImageNet RGB stds
|
||||
|
||||
function getConfig(): {
|
||||
size: number;
|
||||
mean: readonly [number, number, number];
|
||||
std: readonly [number, number, number];
|
||||
} {
|
||||
// These env vars are exposed via next.config.ts / .env.local
|
||||
const size = parseInt(
|
||||
typeof process !== "undefined" && process.env?.IMAGE_MODEL_SIZE
|
||||
? process.env.IMAGE_MODEL_SIZE
|
||||
: String(DEFAULT_MODEL_SIZE),
|
||||
10,
|
||||
);
|
||||
|
||||
return {
|
||||
size: isNaN(size) ? DEFAULT_MODEL_SIZE : size,
|
||||
mean: DEFAULT_MEAN,
|
||||
std: DEFAULT_STD,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
/** Maximum file size accepted (10 MB) */
|
||||
export const MAX_FILE_SIZE = 10 * 1024 * 1024;
|
||||
|
||||
/** Minimum image dimensions (150×150) */
|
||||
export const MIN_DIMENSION = 150;
|
||||
|
||||
/** Allowed MIME types */
|
||||
export const ALLOWED_MIME_TYPES = ["image/png", "image/jpeg", "image/jpg", "image/webp"] as const;
|
||||
|
||||
export type AllowedMimeType = (typeof ALLOWED_MIME_TYPES)[number];
|
||||
|
||||
/** Maximum number of ephemeral uploads to keep */
|
||||
export const MAX_UPLOADS = 100;
|
||||
|
||||
// ─── Validation ──────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Validate that a file is an acceptable image for upload.
|
||||
* Returns `{ ok: true }` or `{ ok: false, error: string }`.
|
||||
*/
|
||||
export function validateImageFile(file: File): { ok: true } | { ok: false; error: string } {
|
||||
// MIME type check
|
||||
if (!ALLOWED_MIME_TYPES.includes(file.type as AllowedMimeType)) {
|
||||
return {
|
||||
ok: false,
|
||||
error: `Unsupported file type "${file.type}". Allowed: PNG, JPG, WebP.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Size check
|
||||
if (file.size > MAX_FILE_SIZE) {
|
||||
const mb = (file.size / (1024 * 1024)).toFixed(1);
|
||||
return { ok: false, error: `File too large (${mb} MB). Maximum is 10 MB.` };
|
||||
}
|
||||
|
||||
// Zero-size check
|
||||
if (file.size === 0) {
|
||||
return { ok: false, error: "File is empty." };
|
||||
}
|
||||
|
||||
return { ok: true };
|
||||
}
|
||||
|
||||
/**
|
||||
* Check that an image file meets minimum dimension requirements.
|
||||
* Returns a promise resolving to `{ ok: true }` or `{ ok: false, error: string }`.
|
||||
*/
|
||||
export function validateImageDimensions(
|
||||
file: File,
|
||||
): Promise<{ ok: true } | { ok: false; error: string }> {
|
||||
return new Promise((resolve) => {
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
if (img.width < MIN_DIMENSION || img.height < MIN_DIMENSION) {
|
||||
resolve({
|
||||
ok: false,
|
||||
error: `Image too small (${img.width}×${img.height}). Minimum is ${MIN_DIMENSION}×${MIN_DIMENSION}.`,
|
||||
});
|
||||
} else {
|
||||
resolve({ ok: true });
|
||||
}
|
||||
};
|
||||
img.onerror = () => {
|
||||
resolve({ ok: false, error: "Failed to read image dimensions." });
|
||||
};
|
||||
img.src = URL.createObjectURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Resize ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Resize an image file to the target model size using an offscreen canvas.
|
||||
* Uses bilinear interpolation via canvas drawing.
|
||||
*
|
||||
* @param file - Source image file
|
||||
* @param size - Target dimension (square). Defaults to IMAGE_MODEL_SIZE env or 224.
|
||||
* @returns ImageData at exactly `size × size`
|
||||
*/
|
||||
export async function resizeImage(file: File, size: number = getConfig().size): Promise<ImageData> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement("canvas");
|
||||
canvas.width = size;
|
||||
canvas.height = size;
|
||||
const ctx = canvas.getContext("2d");
|
||||
if (!ctx) {
|
||||
reject(new Error("Could not get canvas 2D context."));
|
||||
return;
|
||||
}
|
||||
// Bilinear resize via drawImage
|
||||
ctx.imageSmoothingEnabled = true;
|
||||
ctx.imageSmoothingQuality = "high";
|
||||
ctx.drawImage(img, 0, 0, size, size);
|
||||
const imageData = ctx.getImageData(0, 0, size, size);
|
||||
URL.revokeObjectURL(img.src);
|
||||
resolve(imageData);
|
||||
};
|
||||
img.onerror = () => {
|
||||
URL.revokeObjectURL(img.src);
|
||||
reject(new Error("Failed to load image for resizing."));
|
||||
};
|
||||
img.src = URL.createObjectURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Tensor Conversion ───────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Convert ImageData (RGBA) to a flat Float32Array tensor in RGB layout.
|
||||
* Drops the alpha channel, normalizes pixel values to [0, 1].
|
||||
*
|
||||
* Output layout: flat array of length 3 × width × height.
|
||||
* Channel order: RRR...GGG...BBB... (channel-first, like PyTorch NCHW without batch dim).
|
||||
*
|
||||
* @param imageData - Source ImageData from resizeImage()
|
||||
* @returns Float32Array of length 3 × size × size with values in [0, 1]
|
||||
*/
|
||||
export function imageToTensor(imageData: ImageData): Float32Array {
|
||||
const { width, height, data } = imageData;
|
||||
const totalPixels = width * height;
|
||||
const config = getConfig();
|
||||
const { mean, std } = config;
|
||||
|
||||
// Allocate channel-first tensor: [3, H, W]
|
||||
const tensor = new Float32Array(3 * totalPixels);
|
||||
|
||||
// Extract R, G, B channels (skip alpha)
|
||||
const rChannel = tensor.subarray(0, totalPixels);
|
||||
const gChannel = tensor.subarray(totalPixels, 2 * totalPixels);
|
||||
const bChannel = tensor.subarray(2 * totalPixels, 3 * totalPixels);
|
||||
|
||||
for (let i = 0; i < totalPixels; i++) {
|
||||
const idx = i * 4; // RGBA stride
|
||||
rChannel[i] = data[idx] / 255;
|
||||
gChannel[i] = data[idx + 1] / 255;
|
||||
bChannel[i] = data[idx + 2] / 255;
|
||||
}
|
||||
|
||||
// Normalize with ImageNet mean/std
|
||||
for (let c = 0; c < 3; c++) {
|
||||
const channel = c === 0 ? rChannel : c === 1 ? gChannel : bChannel;
|
||||
const m = mean[c];
|
||||
const s = std[c];
|
||||
for (let i = 0; i < totalPixels; i++) {
|
||||
channel[i] = (channel[i] - m) / s;
|
||||
}
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the expected tensor shape for the current model configuration.
|
||||
* Returns [batch, channels, height, width] = [1, 3, size, size].
|
||||
*/
|
||||
export function getTensorShape(): [number, number, number, number] {
|
||||
const size = getConfig().size;
|
||||
return [1, 3, size, size];
|
||||
}
|
||||
|
||||
// ─── Base64 Encoding ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Encode a Float32Array tensor to a base64 string for transmission.
|
||||
* Wraps the binary data in a simple JSON envelope with shape metadata.
|
||||
*
|
||||
* @param tensor - Flat Float32Array from imageToTensor()
|
||||
* @param shape - Tensor shape [C, H, W], defaults to [3, size, size]
|
||||
* @returns base64-encoded JSON string
|
||||
*/
|
||||
export function tensorToBase64(
|
||||
tensor: Float32Array,
|
||||
shape: [number, number, number] = [3, getConfig().size, getConfig().size],
|
||||
): string {
|
||||
const envelope = {
|
||||
shape,
|
||||
data: Array.from(tensor),
|
||||
};
|
||||
const json = JSON.stringify(envelope);
|
||||
return btoa(json);
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode a base64 tensor string back to a Float32Array.
|
||||
*
|
||||
* @param base64 - Base64 string from tensorToBase64()
|
||||
* @returns { tensor, shape }
|
||||
*/
|
||||
export function base64ToTensor(base64: string): {
|
||||
tensor: Float32Array;
|
||||
shape: [number, number, number];
|
||||
} {
|
||||
const json = atob(base64);
|
||||
const envelope = JSON.parse(json);
|
||||
return {
|
||||
tensor: new Float32Array(envelope.data),
|
||||
shape: envelope.shape as [number, number, number],
|
||||
};
|
||||
}
|
||||
0
src/lib/ml/.gitkeep
Normal file
0
src/lib/ml/.gitkeep
Normal file
342
src/lib/ml/confidence.test.ts
Normal file
342
src/lib/ml/confidence.test.ts
Normal file
@@ -0,0 +1,342 @@
|
||||
/**
|
||||
* Unit tests for lib/ml/confidence.ts
|
||||
*
|
||||
* Tests:
|
||||
* - softmax([1, 2, 3]) sums to ~1.0
|
||||
* - softmaxFloat32 produces same results as softmax
|
||||
* - calibrateConfidence(0.9) returns label "high"
|
||||
* - calibrateConfidence(0.6) returns label "medium"
|
||||
* - calibrateConfidence(0.3) returns label "low"
|
||||
* - getTopK returns exactly 5 entries sorted descending
|
||||
* - getTopKFloat32 returns exactly 5 entries sorted descending
|
||||
* - filterByConfidence removes predictions below threshold
|
||||
* - Numerically stable softmax handles large logits
|
||||
* - Degenerate softmax (all -Infinity) returns uniform distribution
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
import {
|
||||
softmax,
|
||||
softmaxFloat32,
|
||||
calibrateConfidence,
|
||||
getConfidenceLabel,
|
||||
getTopK,
|
||||
getTopKFloat32,
|
||||
filterByConfidence,
|
||||
DEFAULT_MIN_CONFIDENCE,
|
||||
} from "@/lib/ml/confidence";
|
||||
|
||||
describe("softmax", () => {
|
||||
it("softmax([1, 2, 3]) sums to ~1.0", () => {
|
||||
const result = softmax([1, 2, 3]);
|
||||
const sum = result.reduce((a, b) => a + b, 0);
|
||||
expect(sum).toBeCloseTo(1.0, 6);
|
||||
});
|
||||
|
||||
it("produces correct probability distribution", () => {
|
||||
const result = softmax([1, 2, 3]);
|
||||
// Higher input → higher probability
|
||||
expect(result[2]).toBeGreaterThan(result[1]);
|
||||
expect(result[1]).toBeGreaterThan(result[0]);
|
||||
// All positive
|
||||
expect(result.every(v => v > 0)).toBe(true);
|
||||
});
|
||||
|
||||
it("handles equal logits uniformly", () => {
|
||||
const result = softmax([1, 1, 1]);
|
||||
expect(result[0]).toBeCloseTo(1 / 3, 6);
|
||||
expect(result[1]).toBeCloseTo(1 / 3, 6);
|
||||
expect(result[2]).toBeCloseTo(1 / 3, 6);
|
||||
});
|
||||
|
||||
it("handles single element", () => {
|
||||
const result = softmax([5]);
|
||||
expect(result).toEqual([1.0]);
|
||||
});
|
||||
|
||||
it("handles large logits without overflow", () => {
|
||||
const result = softmax([1000, 1001, 1002]);
|
||||
const sum = result.reduce((a, b) => a + b, 0);
|
||||
expect(sum).toBeCloseTo(1.0, 6);
|
||||
// The largest should dominate
|
||||
expect(result[2]).toBeGreaterThan(0.5);
|
||||
});
|
||||
|
||||
it("handles negative logits", () => {
|
||||
const result = softmax([-3, -2, -1]);
|
||||
const sum = result.reduce((a, b) => a + b, 0);
|
||||
expect(sum).toBeCloseTo(1.0, 6);
|
||||
expect(result.every(v => v > 0)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("softmaxFloat32", () => {
|
||||
it("produces Float32Array output", () => {
|
||||
const logits = new Float32Array([1, 2, 3]);
|
||||
const result = softmaxFloat32(logits);
|
||||
expect(result).toBeInstanceOf(Float32Array);
|
||||
});
|
||||
|
||||
it("sums to ~1.0", () => {
|
||||
const logits = new Float32Array([1, 2, 3]);
|
||||
const result = softmaxFloat32(logits);
|
||||
const sum = Array.from(result).reduce((a, b) => a + b, 0);
|
||||
expect(sum).toBeCloseTo(1.0, 5);
|
||||
});
|
||||
|
||||
it("matches softmax for same input", () => {
|
||||
const input = [1, 2, 3, 4, 5];
|
||||
const arrayResult = softmax(input);
|
||||
const float32Result = softmaxFloat32(new Float32Array(input));
|
||||
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
expect(float32Result[i]).toBeCloseTo(arrayResult[i], 5);
|
||||
}
|
||||
});
|
||||
|
||||
it("handles large arrays (95 classes)", () => {
|
||||
const logits = new Float32Array(95);
|
||||
for (let i = 0; i < 95; i++) {
|
||||
logits[i] = i * 0.1 - 4.75; // centered around 0
|
||||
}
|
||||
const result = softmaxFloat32(logits);
|
||||
const sum = Array.from(result).reduce((a, b) => a + b, 0);
|
||||
expect(sum).toBeCloseTo(1.0, 5);
|
||||
expect(result.length).toBe(95);
|
||||
});
|
||||
});
|
||||
|
||||
describe("calibrateConfidence", () => {
|
||||
it("calibrateConfidence(0.9) returns label 'high'", () => {
|
||||
const result = calibrateConfidence(0.9);
|
||||
expect(result.label).toBe("high");
|
||||
expect(result.raw).toBe(0.9);
|
||||
expect(result.adjusted).toBeGreaterThan(0.8);
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.95) returns label 'high'", () => {
|
||||
const result = calibrateConfidence(0.95);
|
||||
expect(result.label).toBe("high");
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.8) returns label 'high'", () => {
|
||||
const result = calibrateConfidence(0.8);
|
||||
expect(result.label).toBe("high");
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.6) returns label 'medium'", () => {
|
||||
const result = calibrateConfidence(0.6);
|
||||
expect(result.label).toBe("medium");
|
||||
expect(result.adjusted).toBeGreaterThanOrEqual(0.5);
|
||||
expect(result.adjusted).toBeLessThan(0.8);
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.55) returns label 'medium'", () => {
|
||||
const result = calibrateConfidence(0.55);
|
||||
expect(result.label).toBe("medium");
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.3) returns label 'low'", () => {
|
||||
const result = calibrateConfidence(0.3);
|
||||
expect(result.label).toBe("low");
|
||||
expect(result.adjusted).toBeLessThan(0.5);
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.1) returns label 'low'", () => {
|
||||
const result = calibrateConfidence(0.1);
|
||||
expect(result.label).toBe("low");
|
||||
});
|
||||
|
||||
it("calibrateConfidence(0.0) returns label 'low'", () => {
|
||||
const result = calibrateConfidence(0.0);
|
||||
expect(result.label).toBe("low");
|
||||
});
|
||||
|
||||
it("calibrateConfidence(1.0) returns label 'high'", () => {
|
||||
const result = calibrateConfidence(1.0);
|
||||
expect(result.label).toBe("high");
|
||||
expect(result.adjusted).toBeGreaterThan(0.9);
|
||||
});
|
||||
|
||||
it("adjusted confidence is rounded to 4 decimal places", () => {
|
||||
const result = calibrateConfidence(0.73);
|
||||
const decimalPlaces = result.adjusted.toString().split(".")[1]?.length || 0;
|
||||
expect(decimalPlaces).toBeLessThanOrEqual(4);
|
||||
});
|
||||
|
||||
it("raw confidence is rounded to 4 decimal places", () => {
|
||||
const result = calibrateConfidence(0.73456789);
|
||||
expect(result.raw).toBe(0.7346);
|
||||
});
|
||||
|
||||
it("adjusted confidence is monotonically increasing with raw", () => {
|
||||
const low = calibrateConfidence(0.3);
|
||||
const mid = calibrateConfidence(0.6);
|
||||
const high = calibrateConfidence(0.9);
|
||||
expect(high.adjusted).toBeGreaterThan(mid.adjusted);
|
||||
expect(mid.adjusted).toBeGreaterThan(low.adjusted);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getConfidenceLabel", () => {
|
||||
it("returns 'high' for score >= 0.8", () => {
|
||||
expect(getConfidenceLabel(0.8)).toBe("high");
|
||||
expect(getConfidenceLabel(0.85)).toBe("high");
|
||||
expect(getConfidenceLabel(1.0)).toBe("high");
|
||||
});
|
||||
|
||||
it("returns 'medium' for score >= 0.5 and < 0.8", () => {
|
||||
expect(getConfidenceLabel(0.5)).toBe("medium");
|
||||
expect(getConfidenceLabel(0.65)).toBe("medium");
|
||||
expect(getConfidenceLabel(0.79)).toBe("medium");
|
||||
});
|
||||
|
||||
it("returns 'low' for score < 0.5", () => {
|
||||
expect(getConfidenceLabel(0.0)).toBe("low");
|
||||
expect(getConfidenceLabel(0.49)).toBe("low");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getTopK", () => {
|
||||
it("returns exactly 5 entries by default", () => {
|
||||
const probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
|
||||
const result = getTopK(probs);
|
||||
expect(result).toHaveLength(5);
|
||||
});
|
||||
|
||||
it("returns entries sorted by probability descending", () => {
|
||||
const probs = [0.1, 0.5, 0.3, 0.9, 0.2, 0.7, 0.4];
|
||||
const result = getTopK(probs);
|
||||
for (let i = 0; i < result.length - 1; i++) {
|
||||
expect(result[i].probability).toBeGreaterThanOrEqual(result[i + 1].probability);
|
||||
}
|
||||
});
|
||||
|
||||
it("returns correct class indices", () => {
|
||||
const probs = [0.1, 0.5, 0.3, 0.9, 0.2];
|
||||
const result = getTopK(probs, 3);
|
||||
expect(result[0].classIndex).toBe(3); // 0.9
|
||||
expect(result[1].classIndex).toBe(1); // 0.5
|
||||
expect(result[2].classIndex).toBe(2); // 0.3
|
||||
});
|
||||
|
||||
it("respects custom k value", () => {
|
||||
const probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
|
||||
const result = getTopK(probs, 3);
|
||||
expect(result).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("returns all entries when k > array length", () => {
|
||||
const probs = [0.1, 0.2, 0.3];
|
||||
const result = getTopK(probs, 10);
|
||||
expect(result).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("handles equal probabilities", () => {
|
||||
const probs = [0.3, 0.3, 0.3, 0.1, 0.1];
|
||||
const result = getTopK(probs, 3);
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result.every(p => p.probability === 0.3)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getTopKFloat32", () => {
|
||||
it("returns exactly 5 entries by default", () => {
|
||||
const probs = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]);
|
||||
const result = getTopKFloat32(probs);
|
||||
expect(result).toHaveLength(5);
|
||||
});
|
||||
|
||||
it("returns entries sorted by probability descending", () => {
|
||||
const probs = new Float32Array([0.1, 0.5, 0.3, 0.9, 0.2, 0.7, 0.4]);
|
||||
const result = getTopKFloat32(probs);
|
||||
for (let i = 0; i < result.length - 1; i++) {
|
||||
expect(result[i].probability).toBeGreaterThanOrEqual(result[i + 1].probability);
|
||||
}
|
||||
});
|
||||
|
||||
it("returns correct class indices", () => {
|
||||
const probs = new Float32Array([0.1, 0.5, 0.3, 0.9, 0.2]);
|
||||
const result = getTopKFloat32(probs, 3);
|
||||
expect(result[0].classIndex).toBe(3);
|
||||
expect(result[1].classIndex).toBe(1);
|
||||
expect(result[2].classIndex).toBe(2);
|
||||
});
|
||||
|
||||
it("handles large arrays (95 classes)", () => {
|
||||
const probs = new Float32Array(95);
|
||||
// Set a few high values
|
||||
probs[0] = 0.4;
|
||||
probs[5] = 0.3;
|
||||
probs[10] = 0.2;
|
||||
probs[20] = 0.05;
|
||||
probs[30] = 0.03;
|
||||
// Rest are small
|
||||
for (let i = 0; i < 95; i++) {
|
||||
if (probs[i] === 0) probs[i] = 0.001;
|
||||
}
|
||||
const result = getTopKFloat32(probs, 5);
|
||||
expect(result).toHaveLength(5);
|
||||
expect(result[0].classIndex).toBe(0);
|
||||
expect(result[0].probability).toBeCloseTo(0.4, 5);
|
||||
});
|
||||
});
|
||||
|
||||
describe("filterByConfidence", () => {
|
||||
it("removes predictions below default threshold", () => {
|
||||
const predictions = [
|
||||
{ classIndex: 0, probability: 0.5 },
|
||||
{ classIndex: 1, probability: 0.3 },
|
||||
{ classIndex: 2, probability: 0.1 },
|
||||
{ classIndex: 3, probability: 0.05 },
|
||||
];
|
||||
const result = filterByConfidence(predictions);
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].classIndex).toBe(0);
|
||||
expect(result[1].classIndex).toBe(1);
|
||||
});
|
||||
|
||||
it("uses custom threshold", () => {
|
||||
const predictions = [
|
||||
{ classIndex: 0, probability: 0.5 },
|
||||
{ classIndex: 1, probability: 0.3 },
|
||||
{ classIndex: 2, probability: 0.1 },
|
||||
];
|
||||
const result = filterByConfidence(predictions, 0.25);
|
||||
expect(result).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("returns empty array when all below threshold", () => {
|
||||
const predictions = [
|
||||
{ classIndex: 0, probability: 0.1 },
|
||||
{ classIndex: 1, probability: 0.05 },
|
||||
];
|
||||
const result = filterByConfidence(predictions, 0.2);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it("returns all predictions when all above threshold", () => {
|
||||
const predictions = [
|
||||
{ classIndex: 0, probability: 0.5 },
|
||||
{ classIndex: 1, probability: 0.3 },
|
||||
];
|
||||
const result = filterByConfidence(predictions, 0.1);
|
||||
expect(result).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("keeps predictions at exactly the threshold", () => {
|
||||
const predictions = [
|
||||
{ classIndex: 0, probability: 0.15 },
|
||||
{ classIndex: 1, probability: 0.14 },
|
||||
];
|
||||
const result = filterByConfidence(predictions, 0.15);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].classIndex).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("DEFAULT_MIN_CONFIDENCE", () => {
|
||||
it("is 0.15", () => {
|
||||
expect(DEFAULT_MIN_CONFIDENCE).toBe(0.15);
|
||||
});
|
||||
});
|
||||
204
src/lib/ml/confidence.ts
Normal file
204
src/lib/ml/confidence.ts
Normal file
@@ -0,0 +1,204 @@
|
||||
/**
|
||||
* Confidence calibration and threshold logic for ML predictions.
|
||||
*
|
||||
* Provides softmax conversion, confidence calibration, and threshold-based
|
||||
* filtering of predictions.
|
||||
*/
|
||||
|
||||
import type { ConfidenceLabel, ConfidenceResult, RawPrediction } from "@/lib/types";
|
||||
|
||||
// ─── Configuration ───────────────────────────────────────────────────────────
|
||||
|
||||
/** Minimum confidence threshold — predictions below this are filtered out */
|
||||
export const DEFAULT_MIN_CONFIDENCE = 0.15;
|
||||
|
||||
/** Confidence label thresholds */
|
||||
const CONFIDENCE_THRESHOLDS = {
|
||||
HIGH: 0.8,
|
||||
MEDIUM: 0.5,
|
||||
} as const;
|
||||
|
||||
// ─── Softmax ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Apply softmax to a vector of logits, converting them to probabilities.
|
||||
*
|
||||
* Uses numerically stable softmax: subtracts max before exp() to avoid overflow.
|
||||
*
|
||||
* @param logits - Array of raw model output values
|
||||
* @returns Array of probabilities that sum to ~1.0
|
||||
*/
|
||||
export function softmax(logits: number[]): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const expValues = logits.map((l) => Math.exp(l - maxLogit));
|
||||
const sumExp = expValues.reduce((a, b) => a + b, 0);
|
||||
|
||||
if (sumExp === 0) {
|
||||
// Degenerate case: all logits are -Infinity
|
||||
const uniform = 1 / logits.length;
|
||||
return logits.map(() => uniform);
|
||||
}
|
||||
|
||||
return expValues.map((e) => e / sumExp);
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply softmax to a Float32Array of logits.
|
||||
*
|
||||
* @param logits - Float32Array of raw model output values
|
||||
* @returns Float32Array of probabilities that sum to ~1.0
|
||||
*/
|
||||
export function softmaxFloat32(logits: Float32Array): Float32Array {
|
||||
const maxLogit = -Infinity;
|
||||
let actualMax = maxLogit;
|
||||
for (let i = 0; i < logits.length; i++) {
|
||||
if (logits[i] > actualMax) actualMax = logits[i];
|
||||
}
|
||||
|
||||
const expValues = new Float32Array(logits.length);
|
||||
let sumExp = 0;
|
||||
for (let i = 0; i < logits.length; i++) {
|
||||
expValues[i] = Math.exp(logits[i] - actualMax);
|
||||
sumExp += expValues[i];
|
||||
}
|
||||
|
||||
if (sumExp === 0) {
|
||||
const uniform = 1 / logits.length;
|
||||
return new Float32Array(logits.length).fill(uniform);
|
||||
}
|
||||
|
||||
for (let i = 0; i < expValues.length; i++) {
|
||||
expValues[i] /= sumExp;
|
||||
}
|
||||
|
||||
return expValues;
|
||||
}
|
||||
|
||||
// ─── Confidence Calibration ──────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Calibrate a raw probability into an adjusted confidence score with a label.
|
||||
*
|
||||
* Applies a mild calibration that slightly adjusts raw softmax probabilities
|
||||
* to account for model overconfidence. Uses a linear calibration:
|
||||
* adjusted = rawProb * calibrationFactor
|
||||
* where calibrationFactor ≈ 1.0 (default 1.02) to slightly boost
|
||||
* well-separated predictions while keeping the value in [0, 1].
|
||||
*
|
||||
* The calibrated value is clamped to [0, 1] and labeled using thresholds:
|
||||
* high ≥ 0.8
|
||||
* medium ≥ 0.5
|
||||
* low < 0.5
|
||||
*
|
||||
* @param rawProb - Raw softmax probability (0–1)
|
||||
* @param calibrationFactor - Linear calibration factor (default 1.02)
|
||||
* @returns { adjusted, label }
|
||||
*/
|
||||
export function calibrateConfidence(
|
||||
rawProb: number,
|
||||
calibrationFactor = 1.02,
|
||||
): ConfidenceResult {
|
||||
const adjusted = Math.min(1, Math.max(0, rawProb * calibrationFactor));
|
||||
const label = getConfidenceLabel(adjusted);
|
||||
|
||||
return {
|
||||
raw: roundToDecimals(rawProb, 4),
|
||||
adjusted: roundToDecimals(adjusted, 4),
|
||||
label,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the confidence label for a given score.
|
||||
*
|
||||
* Thresholds:
|
||||
* high ≥ 0.8
|
||||
* medium ≥ 0.5
|
||||
* low < 0.5
|
||||
*
|
||||
* @param score - Confidence score (0–1)
|
||||
* @returns Confidence label
|
||||
*/
|
||||
export function getConfidenceLabel(score: number): ConfidenceLabel {
|
||||
if (score >= CONFIDENCE_THRESHOLDS.HIGH) return "high";
|
||||
if (score >= CONFIDENCE_THRESHOLDS.MEDIUM) return "medium";
|
||||
return "low";
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply sigmoid function: 1 / (1 + exp(-x))
|
||||
*/
|
||||
function sigmoid(x: number): number {
|
||||
return 1 / (1 + Math.exp(-x));
|
||||
}
|
||||
|
||||
/**
|
||||
* Round a number to a given number of decimal places.
|
||||
*/
|
||||
function roundToDecimals(value: number, decimals: number): number {
|
||||
const factor = Math.pow(10, decimals);
|
||||
return Math.round(value * factor) / factor;
|
||||
}
|
||||
|
||||
// ─── Top-K Extraction ────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Extract the top-K predictions from a probability array.
|
||||
*
|
||||
* @param probabilities - Array of probabilities (from softmax)
|
||||
* @param k - Number of top predictions to return (default 5)
|
||||
* @returns Array of { classIndex, probability } sorted by probability descending
|
||||
*/
|
||||
export function getTopK(
|
||||
probabilities: number[],
|
||||
k = 5,
|
||||
): RawPrediction[] {
|
||||
// Create indexed pairs
|
||||
const indexed = probabilities.map((prob, index) => ({
|
||||
classIndex: index,
|
||||
probability: prob,
|
||||
}));
|
||||
|
||||
// Sort by probability descending
|
||||
indexed.sort((a, b) => b.probability - a.probability);
|
||||
|
||||
// Take top K
|
||||
return indexed.slice(0, k);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract top-K predictions from a Float32Array of probabilities.
|
||||
*
|
||||
* @param probabilities - Float32Array of probabilities
|
||||
* @param k - Number of top predictions (default 5)
|
||||
* @returns Array of { classIndex, probability } sorted descending
|
||||
*/
|
||||
export function getTopKFloat32(
|
||||
probabilities: Float32Array,
|
||||
k = 5,
|
||||
): RawPrediction[] {
|
||||
const indexed: Array<{ classIndex: number; probability: number }> = [];
|
||||
for (let i = 0; i < probabilities.length; i++) {
|
||||
indexed.push({ classIndex: i, probability: probabilities[i] });
|
||||
}
|
||||
|
||||
indexed.sort((a, b) => b.probability - a.probability);
|
||||
|
||||
return indexed.slice(0, k);
|
||||
}
|
||||
|
||||
// ─── Filtering ───────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Filter predictions by minimum confidence threshold.
|
||||
*
|
||||
* @param predictions - Raw predictions from getTopK()
|
||||
* @param minConfidence - Minimum probability threshold (default 0.15)
|
||||
* @returns Filtered predictions array
|
||||
*/
|
||||
export function filterByConfidence(
|
||||
predictions: RawPrediction[],
|
||||
minConfidence = DEFAULT_MIN_CONFIDENCE,
|
||||
): RawPrediction[] {
|
||||
return predictions.filter((p) => p.probability >= minConfidence);
|
||||
}
|
||||
244
src/lib/ml/inference.test.ts
Normal file
244
src/lib/ml/inference.test.ts
Normal file
@@ -0,0 +1,244 @@
|
||||
/**
|
||||
* Unit tests for lib/ml/inference.ts
|
||||
*
|
||||
* Tests:
|
||||
* - validateInput rejects non-Float32Array
|
||||
* - validateInput rejects wrong-length arrays
|
||||
* - validateInput rejects NaN/Infinity values
|
||||
* - validateInput accepts correct tensor
|
||||
* - createZeroTensor produces correct shape
|
||||
* - createRandomTensor produces correct shape with finite values
|
||||
* - runInference returns InferenceResult with predictions array
|
||||
* - runInference returns exactly top-K predictions
|
||||
* - runInference predictions are sorted descending
|
||||
* - runInference includes inferenceTimeMs
|
||||
* - runInference completes under 3 seconds
|
||||
* - runBatchInference processes multiple images
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import {
|
||||
runInference,
|
||||
validateInput,
|
||||
createZeroTensor,
|
||||
createRandomTensor,
|
||||
runBatchInference,
|
||||
INPUT_SIZE,
|
||||
INPUT_SHAPE,
|
||||
DEFAULT_TOP_K,
|
||||
} from "@/lib/ml/inference";
|
||||
import { resetModelCache } from "@/lib/ml/model-loader";
|
||||
|
||||
describe("validateInput", () => {
|
||||
it("rejects non-Float32Array", () => {
|
||||
expect(() => validateInput([1, 2, 3] as any)).toThrow("Expected Float32Array input");
|
||||
});
|
||||
|
||||
it("rejects wrong-length arrays", () => {
|
||||
const tensor = new Float32Array(100);
|
||||
expect(() => validateInput(tensor)).toThrow(`Expected tensor of length ${INPUT_SIZE}`);
|
||||
});
|
||||
|
||||
it("rejects NaN values", () => {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
tensor[50] = NaN;
|
||||
expect(() => validateInput(tensor)).toThrow("non-finite value");
|
||||
});
|
||||
|
||||
it("rejects Infinity values", () => {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
tensor[50] = Infinity;
|
||||
expect(() => validateInput(tensor)).toThrow("non-finite value");
|
||||
});
|
||||
|
||||
it("rejects -Infinity values", () => {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
tensor[50] = -Infinity;
|
||||
expect(() => validateInput(tensor)).toThrow("non-finite value");
|
||||
});
|
||||
|
||||
it("accepts correct tensor", () => {
|
||||
const tensor = createZeroTensor();
|
||||
expect(() => validateInput(tensor)).not.toThrow();
|
||||
});
|
||||
|
||||
it("accepts tensor with negative values", () => {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
for (let i = 0; i < INPUT_SIZE; i++) {
|
||||
tensor[i] = -2;
|
||||
}
|
||||
expect(() => validateInput(tensor)).not.toThrow();
|
||||
});
|
||||
|
||||
it("accepts tensor with values near zero", () => {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
for (let i = 0; i < INPUT_SIZE; i++) {
|
||||
tensor[i] = 0.0001;
|
||||
}
|
||||
expect(() => validateInput(tensor)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("createZeroTensor", () => {
|
||||
it("produces Float32Array", () => {
|
||||
const tensor = createZeroTensor();
|
||||
expect(tensor).toBeInstanceOf(Float32Array);
|
||||
});
|
||||
|
||||
it("has correct length", () => {
|
||||
const tensor = createZeroTensor();
|
||||
expect(tensor.length).toBe(INPUT_SIZE);
|
||||
});
|
||||
|
||||
it("has correct shape dimensions", () => {
|
||||
const expectedLength = INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3];
|
||||
expect(INPUT_SIZE).toBe(expectedLength);
|
||||
});
|
||||
|
||||
it("all values are zero", () => {
|
||||
const tensor = createZeroTensor();
|
||||
expect(tensor.every((v) => v === 0)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createRandomTensor", () => {
|
||||
it("produces Float32Array", () => {
|
||||
const tensor = createRandomTensor();
|
||||
expect(tensor).toBeInstanceOf(Float32Array);
|
||||
});
|
||||
|
||||
it("has correct length", () => {
|
||||
const tensor = createRandomTensor();
|
||||
expect(tensor.length).toBe(INPUT_SIZE);
|
||||
});
|
||||
|
||||
it("all values are finite", () => {
|
||||
const tensor = createRandomTensor();
|
||||
expect(tensor.every((v) => Number.isFinite(v))).toBe(true);
|
||||
});
|
||||
|
||||
it("produces varied values", () => {
|
||||
const tensor = createRandomTensor();
|
||||
const uniqueValues = new Set(tensor.map((v) => v.toFixed(4)));
|
||||
expect(uniqueValues.size).toBeGreaterThan(100);
|
||||
});
|
||||
|
||||
it("passes validateInput", () => {
|
||||
const tensor = createRandomTensor();
|
||||
expect(() => validateInput(tensor)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("INPUT_SHAPE and INPUT_SIZE", () => {
|
||||
it("INPUT_SHAPE is [1, 3, 160, 160]", () => {
|
||||
expect(INPUT_SHAPE).toEqual([1, 3, 160, 160]);
|
||||
});
|
||||
|
||||
it("INPUT_SIZE equals 3 * 160 * 160", () => {
|
||||
expect(INPUT_SIZE).toBe(3 * 160 * 160);
|
||||
});
|
||||
|
||||
it("DEFAULT_TOP_K is 5", () => {
|
||||
expect(DEFAULT_TOP_K).toBe(5);
|
||||
});
|
||||
});
|
||||
|
||||
describe("runInference", () => {
|
||||
beforeEach(() => {
|
||||
resetModelCache();
|
||||
});
|
||||
|
||||
it("returns InferenceResult with predictions array", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor);
|
||||
expect(result.predictions).toBeDefined();
|
||||
expect(Array.isArray(result.predictions)).toBe(true);
|
||||
}, 10000);
|
||||
|
||||
it("returns exactly top-K predictions by default", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor);
|
||||
expect(result.predictions.length).toBe(DEFAULT_TOP_K);
|
||||
}, 10000);
|
||||
|
||||
it("returns custom top-K predictions", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor, 3);
|
||||
expect(result.predictions.length).toBe(3);
|
||||
}, 10000);
|
||||
|
||||
it("predictions are sorted by probability descending", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor);
|
||||
for (let i = 0; i < result.predictions.length - 1; i++) {
|
||||
expect(result.predictions[i].probability).toBeGreaterThanOrEqual(
|
||||
result.predictions[i + 1].probability,
|
||||
);
|
||||
}
|
||||
}, 10000);
|
||||
|
||||
it("includes inferenceTimeMs", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor);
|
||||
expect(result.inferenceTimeMs).toBeDefined();
|
||||
expect(typeof result.inferenceTimeMs).toBe("number");
|
||||
expect(result.inferenceTimeMs).toBeGreaterThan(0);
|
||||
}, 10000);
|
||||
|
||||
it("completes under 3 seconds", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const start = performance.now();
|
||||
const result = await runInference(tensor);
|
||||
const elapsed = performance.now() - start;
|
||||
expect(elapsed).toBeLessThan(3000);
|
||||
expect(result.inferenceTimeMs).toBeLessThan(3000);
|
||||
}, 10000);
|
||||
|
||||
it("each prediction has classIndex and probability", async () => {
|
||||
const tensor = createRandomTensor();
|
||||
const result = await runInference(tensor);
|
||||
for (const pred of result.predictions) {
|
||||
expect(pred.classIndex).toBeDefined();
|
||||
expect(typeof pred.classIndex).toBe("number");
|
||||
expect(pred.probability).toBeDefined();
|
||||
expect(typeof pred.probability).toBe("number");
|
||||
expect(pred.probability).toBeGreaterThanOrEqual(0);
|
||||
expect(pred.probability).toBeLessThanOrEqual(1);
|
||||
}
|
||||
}, 10000);
|
||||
|
||||
it("throws on invalid input", async () => {
|
||||
const badTensor = new Float32Array(100);
|
||||
await expect(runInference(badTensor)).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("runBatchInference", () => {
|
||||
beforeEach(() => {
|
||||
resetModelCache();
|
||||
});
|
||||
|
||||
it("processes multiple images", async () => {
|
||||
const tensors = [createRandomTensor(), createRandomTensor(), createRandomTensor()];
|
||||
const results = await runBatchInference(tensors);
|
||||
expect(results).toHaveLength(3);
|
||||
for (const result of results) {
|
||||
expect(result.predictions.length).toBe(DEFAULT_TOP_K);
|
||||
expect(result.inferenceTimeMs).toBeGreaterThan(0);
|
||||
}
|
||||
}, 30000);
|
||||
|
||||
it("each result is independent", async () => {
|
||||
const tensors = [createRandomTensor(), createRandomTensor()];
|
||||
const results = await runBatchInference(tensors);
|
||||
// Results should differ (different random inputs → different predictions)
|
||||
expect(results[0].predictions[0].classIndex).toBeDefined();
|
||||
expect(results[1].predictions[0].classIndex).toBeDefined();
|
||||
}, 15000);
|
||||
|
||||
it("accepts custom top-K", async () => {
|
||||
const tensors = [createRandomTensor()];
|
||||
const results = await runBatchInference(tensors, 3);
|
||||
expect(results[0].predictions.length).toBe(3);
|
||||
}, 15000);
|
||||
});
|
||||
137
src/lib/ml/inference.ts
Normal file
137
src/lib/ml/inference.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* ML inference pipeline for plant disease classification.
|
||||
*
|
||||
* Accepts a preprocessed image tensor, runs it through the model,
|
||||
* applies softmax, extracts top-K predictions, and returns results
|
||||
* with timing metadata.
|
||||
*/
|
||||
|
||||
import type { InferenceResult, RawPrediction } from "@/lib/types";
|
||||
import { getModel } from "./model-loader";
|
||||
import { softmaxFloat32, getTopKFloat32 } from "./confidence";
|
||||
|
||||
// ─── Configuration ───────────────────────────────────────────────────────────
|
||||
|
||||
/** Number of top predictions to return */
|
||||
export const DEFAULT_TOP_K = 5;
|
||||
|
||||
/** Input tensor shape: [batch=1, channels=3, height=160, width=160] */
|
||||
export const INPUT_SHAPE: [number, number, number, number] = [1, 3, 160, 160];
|
||||
|
||||
/** Expected input tensor length */
|
||||
export const INPUT_SIZE = INPUT_SHAPE[1] * INPUT_SHAPE[2] * INPUT_SHAPE[3]; // 3 * 160 * 160 = 76800
|
||||
|
||||
// ─── Main Inference ──────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Run the full inference pipeline on a preprocessed image tensor.
|
||||
*
|
||||
* @param imageTensor - Normalized Float32Array of shape [1, 3, 160, 160] (NCHW)
|
||||
* @param topK - Number of top predictions to return (default 5)
|
||||
* @returns InferenceResult with top-K predictions and timing
|
||||
*/
|
||||
export async function runInference(
|
||||
imageTensor: Float32Array,
|
||||
topK = DEFAULT_TOP_K,
|
||||
): Promise<InferenceResult> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Validate input
|
||||
validateInput(imageTensor);
|
||||
|
||||
// Get model (lazy loads on first call)
|
||||
const model = await getModel();
|
||||
|
||||
// Run model forward pass
|
||||
const { logits, inferenceTimeMs } = await model.predict(imageTensor);
|
||||
|
||||
// Apply softmax to convert logits to probabilities
|
||||
const probabilities = softmaxFloat32(logits);
|
||||
|
||||
// Extract top-K predictions
|
||||
const predictions = getTopKFloat32(probabilities, topK);
|
||||
|
||||
const totalTime = performance.now() - startTime;
|
||||
|
||||
return {
|
||||
predictions,
|
||||
inferenceTimeMs: Math.round(totalTime),
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Input Validation ────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Validate that the input tensor has the expected shape and type.
|
||||
*
|
||||
* @param tensor - Input tensor to validate
|
||||
* @throws Error if tensor is invalid
|
||||
*/
|
||||
export function validateInput(tensor: Float32Array): void {
|
||||
if (!(tensor instanceof Float32Array)) {
|
||||
throw new Error(`Expected Float32Array input, got ${typeof tensor}`);
|
||||
}
|
||||
|
||||
if (tensor.length !== INPUT_SIZE) {
|
||||
throw new Error(
|
||||
`Expected tensor of length ${INPUT_SIZE} (shape ${INPUT_SHAPE.join("×")}), ` +
|
||||
`got ${tensor.length}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check for NaN/Infinity values
|
||||
for (let i = 0; i < tensor.length; i++) {
|
||||
if (!Number.isFinite(tensor[i])) {
|
||||
throw new Error(`Tensor contains non-finite value at index ${i}: ${tensor[i]}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Batch Inference ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Run inference on multiple images.
|
||||
*
|
||||
* Currently runs sequentially. For true batching, the model itself would need
|
||||
* to support batch input.
|
||||
*
|
||||
* @param tensors - Array of preprocessed image tensors
|
||||
* @param topK - Number of top predictions per image
|
||||
* @returns Array of inference results
|
||||
*/
|
||||
export async function runBatchInference(
|
||||
tensors: Float32Array[],
|
||||
topK = DEFAULT_TOP_K,
|
||||
): Promise<InferenceResult[]> {
|
||||
const results: InferenceResult[] = [];
|
||||
|
||||
for (const tensor of tensors) {
|
||||
results.push(await runInference(tensor, topK));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
// ─── Utility ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a zero-filled input tensor for testing.
|
||||
*
|
||||
* @returns Float32Array of shape [1, 3, 224, 224]
|
||||
*/
|
||||
export function createZeroTensor(): Float32Array {
|
||||
return new Float32Array(INPUT_SIZE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a random input tensor for testing.
|
||||
*
|
||||
* @returns Float32Array of shape [1, 3, 224, 224] with random values
|
||||
*/
|
||||
export function createRandomTensor(): Float32Array {
|
||||
const tensor = new Float32Array(INPUT_SIZE);
|
||||
for (let i = 0; i < tensor.length; i++) {
|
||||
tensor[i] = (Math.random() * 2 - 1) * 2; // Range roughly -2 to 2
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
199
src/lib/ml/labels.test.ts
Normal file
199
src/lib/ml/labels.test.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
/**
|
||||
* Unit tests for lib/ml/labels.ts
|
||||
*
|
||||
* The model has 38 PlantVillage classes. Some map to the app's
|
||||
* knowledge base disease IDs, others map to "unknown".
|
||||
*
|
||||
* Known mappings:
|
||||
* - indices 3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37 → "healthy"
|
||||
* - index 20 (Potato___Early_blight) → "early-blight"
|
||||
* - index 21 (Potato___Late_blight) → "late-blight"
|
||||
* - index 25 (Squash___Powdery_mildew) → "squash-powdery-mildew"
|
||||
* - index 26 (Strawberry___Leaf_scorch) → "strawberry-leaf-scorch"
|
||||
* - index 28 (Tomato___Bacterial_spot) → "bacterial-leaf-spot-tomato"
|
||||
* - index 29 (Tomato___Early_blight) → "early-blight" (duplicate)
|
||||
* - index 30 (Tomato___Late_blight) → "late-blight" (duplicate)
|
||||
* - index 32 (Tomato___Septoria_leaf_spot) → "septoria-leaf-spot"
|
||||
* - index 37 (Tomato___healthy) → "healthy"
|
||||
* - all others → "unknown"
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
import {
|
||||
INDEX_TO_DISEASE_ID,
|
||||
DISEASE_ID_TO_INDEX,
|
||||
getDiseaseIdForIndex,
|
||||
getIndexForDiseaseId,
|
||||
isRealDisease,
|
||||
getAllDiseaseIds,
|
||||
NUM_CLASSES,
|
||||
getPlantVillageClassName,
|
||||
} from "@/lib/ml/labels";
|
||||
|
||||
describe("Constants", () => {
|
||||
it("NUM_CLASSES is 38 (PlantVillage)", () => {
|
||||
expect(NUM_CLASSES).toBe(38);
|
||||
});
|
||||
|
||||
it("all 38 indices are mapped", () => {
|
||||
const keys = Object.keys(INDEX_TO_DISEASE_ID).map(Number);
|
||||
expect(keys.length).toBe(38);
|
||||
for (let i = 0; i < 38; i++) {
|
||||
expect(keys).toContain(i);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — healthy indices", () => {
|
||||
const healthyIndices = [3, 4, 6, 10, 14, 17, 19, 22, 23, 24, 27, 37];
|
||||
|
||||
for (const idx of healthyIndices) {
|
||||
it(`index ${idx} maps to "healthy"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe("healthy");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — known disease mappings", () => {
|
||||
const cases: Array<{ index: number; expected: string; name: string }> = [
|
||||
{ index: 20, expected: "early-blight", name: "Potato___Early_blight" },
|
||||
{ index: 21, expected: "late-blight", name: "Potato___Late_blight" },
|
||||
{ index: 25, expected: "squash-powdery-mildew", name: "Squash___Powdery_mildew" },
|
||||
{ index: 26, expected: "strawberry-leaf-scorch", name: "Strawberry___Leaf_scorch" },
|
||||
{ index: 28, expected: "bacterial-leaf-spot-tomato", name: "Tomato___Bacterial_spot" },
|
||||
{ index: 29, expected: "early-blight", name: "Tomato___Early_blight" },
|
||||
{ index: 30, expected: "late-blight", name: "Tomato___Late_blight" },
|
||||
{ index: 32, expected: "septoria-leaf-spot", name: "Tomato___Septoria_leaf_spot" },
|
||||
];
|
||||
|
||||
for (const { index, expected, name } of cases) {
|
||||
it(`index ${index} (${name}) maps to "${expected}"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[index]).toBe(expected);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("INDEX_TO_DISEASE_ID — unknown (unmapped) indices", () => {
|
||||
const unknownIndices = [0, 1, 2, 5, 7, 8, 9, 11, 12, 13, 15, 16, 18, 31, 33, 34, 35, 36];
|
||||
|
||||
for (const idx of unknownIndices) {
|
||||
it(`index ${idx} maps to "unknown"`, () => {
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe("unknown");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("DISEASE_ID_TO_INDEX", () => {
|
||||
it("maps 'early-blight' to first occurrence (index 20)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["early-blight"]).toBe(20);
|
||||
});
|
||||
|
||||
it("maps 'late-blight' to first occurrence (index 21)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["late-blight"]).toBe(21);
|
||||
});
|
||||
|
||||
it("maps 'septoria-leaf-spot' to index 32", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["septoria-leaf-spot"]).toBe(32);
|
||||
});
|
||||
|
||||
it("maps 'healthy' to index 3 (first healthy index)", () => {
|
||||
expect(DISEASE_ID_TO_INDEX["healthy"]).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Bidirectional mapping", () => {
|
||||
it("every index round-trips correctly", () => {
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
const id = INDEX_TO_DISEASE_ID[i];
|
||||
const idx = DISEASE_ID_TO_INDEX[id];
|
||||
expect(INDEX_TO_DISEASE_ID[idx]).toBe(id);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("getDiseaseIdForIndex", () => {
|
||||
it("returns 'unknown' for out-of-range positive index", () => {
|
||||
expect(getDiseaseIdForIndex(100)).toBe("unknown");
|
||||
});
|
||||
|
||||
it("returns 'unknown' for negative index", () => {
|
||||
expect(getDiseaseIdForIndex(-1)).toBe("unknown");
|
||||
});
|
||||
|
||||
it("returns correct ID for valid index", () => {
|
||||
expect(getDiseaseIdForIndex(20)).toBe("early-blight");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getIndexForDiseaseId", () => {
|
||||
it("returns -1 for unknown disease ID", () => {
|
||||
expect(getIndexForDiseaseId("nonexistent-disease")).toBe(-1);
|
||||
});
|
||||
|
||||
it("returns -1 for empty string", () => {
|
||||
expect(getIndexForDiseaseId("")).toBe(-1);
|
||||
});
|
||||
|
||||
it("is case-insensitive", () => {
|
||||
expect(getIndexForDiseaseId("EARLY-BLIGHT")).toBe(20);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isRealDisease", () => {
|
||||
it("returns false for 'healthy'", () => {
|
||||
expect(isRealDisease("healthy")).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false for 'unknown'", () => {
|
||||
expect(isRealDisease("unknown")).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true for known disease IDs", () => {
|
||||
expect(isRealDisease("early-blight")).toBe(true);
|
||||
expect(isRealDisease("septoria-leaf-spot")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true for arbitrary non-special strings", () => {
|
||||
expect(isRealDisease("some-disease")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPlantVillageClassName", () => {
|
||||
it("returns correct class name for tomato healthy", () => {
|
||||
expect(getPlantVillageClassName(37)).toBe("Tomato___healthy");
|
||||
});
|
||||
|
||||
it("returns correct class name for potato early blight", () => {
|
||||
expect(getPlantVillageClassName(20)).toBe("Potato___Early_blight");
|
||||
});
|
||||
|
||||
it("returns 'unknown' for out-of-range index", () => {
|
||||
expect(getPlantVillageClassName(100)).toBe("unknown");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getAllDiseaseIds", () => {
|
||||
it("returns only mapped disease IDs", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
expect(ids).toContain("early-blight");
|
||||
expect(ids).toContain("late-blight");
|
||||
expect(ids).toContain("squash-powdery-mildew");
|
||||
expect(ids).toContain("strawberry-leaf-scorch");
|
||||
expect(ids).toContain("bacterial-leaf-spot-tomato");
|
||||
expect(ids).toContain("septoria-leaf-spot");
|
||||
});
|
||||
|
||||
it("excludes 'healthy'", () => {
|
||||
expect(getAllDiseaseIds()).not.toContain("healthy");
|
||||
});
|
||||
|
||||
it("excludes 'unknown'", () => {
|
||||
expect(getAllDiseaseIds()).not.toContain("unknown");
|
||||
});
|
||||
|
||||
it("has no duplicates", () => {
|
||||
const ids = getAllDiseaseIds();
|
||||
const uniqueIds = new Set(ids);
|
||||
expect(uniqueIds.size).toBe(ids.length);
|
||||
});
|
||||
});
|
||||
237
src/lib/ml/labels.ts
Normal file
237
src/lib/ml/labels.ts
Normal file
@@ -0,0 +1,237 @@
|
||||
/**
|
||||
* Class label mapping for the plant disease classifier model.
|
||||
*
|
||||
* This model is a MobileNetV2 trained on the PlantVillage dataset
|
||||
* with 38 classes (14 crops × diseases/healthy).
|
||||
*
|
||||
* Model output shape: [1, NUM_CLASSES] where NUM_CLASSES = 38
|
||||
*
|
||||
* Index layout (from labels_pv_original.json):
|
||||
* 0 → Apple___Apple_scab
|
||||
* 1 → Apple___Black_rot
|
||||
* 2 → Apple___Cedar_apple_rust
|
||||
* 3 → Apple___healthy
|
||||
* 4 → Blueberry___healthy
|
||||
* 5 → Cherry_(including_sour)___Powdery_mildew
|
||||
* 6 → Cherry_(including_sour)___healthy
|
||||
* 7 → Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
|
||||
* 8 → Corn_(maize)___Common_rust_
|
||||
* 9 → Corn_(maize)___Northern_Leaf_Blight
|
||||
* 10 → Corn_(maize)___healthy
|
||||
* 11 → Grape___Black_rot
|
||||
* 12 → Grape___Esca_(Black_Measles)
|
||||
* 13 → Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
|
||||
* 14 → Grape___healthy
|
||||
* 15 → Orange___Haunglongbing_(Citrus_greening)
|
||||
* 16 → Peach___Bacterial_spot
|
||||
* 17 → Peach___healthy
|
||||
* 18 → Pepper,_bell___Bacterial_spot
|
||||
* 19 → Pepper,_bell___healthy
|
||||
* 20 → Potato___Early_blight
|
||||
* 21 → Potato___Late_blight
|
||||
* 22 → Potato___healthy
|
||||
* 23 → Raspberry___healthy
|
||||
* 24 → Soybean___healthy
|
||||
* 25 → Squash___Powdery_mildew
|
||||
* 26 → Strawberry___Leaf_scorch
|
||||
* 27 → Strawberry___healthy
|
||||
* 28 → Tomato___Bacterial_spot
|
||||
* 29 → Tomato___Early_blight
|
||||
* 30 → Tomato___Late_blight
|
||||
* 31 → Tomato___Leaf_Mold
|
||||
* 32 → Tomato___Septoria_leaf_spot
|
||||
* 33 → Tomato___Spider_mites Two-spotted_spider_mite
|
||||
* 34 → Tomato___Target_Spot
|
||||
* 35 → Tomato___Tomato_Yellow_Leaf_Curl_Virus
|
||||
* 36 → Tomato___Tomato_mosaic_virus
|
||||
* 37 → Tomato___healthy
|
||||
*
|
||||
* Some PlantVillage classes overlap with this app's knowledge base.
|
||||
* Exact class name → disease ID mappings:
|
||||
* Potato___Early_blight → "early-blight"
|
||||
* Potato___Late_blight → "late-blight"
|
||||
* Squash___Powdery_mildew → "squash-powdery-mildew"
|
||||
* Strawberry___Leaf_scorch → "strawberry-leaf-scorch"
|
||||
* Tomato___Bacterial_spot → "bacterial-leaf-spot-tomato"
|
||||
* Tomato___Early_blight → "early-blight"
|
||||
* Tomato___Late_blight → "late-blight"
|
||||
* Tomato___Septoria_leaf_spot → "septoria-leaf-spot"
|
||||
* All other classes map to "unknown" and are filtered out during enrichment.
|
||||
*
|
||||
* After fine-tuning to the app's 93 disease classes, this file will be
|
||||
* rewritten to match the new model's output layer.
|
||||
*/
|
||||
|
||||
// ─── PlantVillage class names (in model output order) ────────────────────
|
||||
|
||||
const PLANTVILLAGE_CLASSES: string[] = [
|
||||
"Apple___Apple_scab",
|
||||
"Apple___Black_rot",
|
||||
"Apple___Cedar_apple_rust",
|
||||
"Apple___healthy",
|
||||
"Blueberry___healthy",
|
||||
"Cherry_(including_sour)___Powdery_mildew",
|
||||
"Cherry_(including_sour)___healthy",
|
||||
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
|
||||
"Corn_(maize)___Common_rust_",
|
||||
"Corn_(maize)___Northern_Leaf_Blight",
|
||||
"Corn_(maize)___healthy",
|
||||
"Grape___Black_rot",
|
||||
"Grape___Esca_(Black_Measles)",
|
||||
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
|
||||
"Grape___healthy",
|
||||
"Orange___Haunglongbing_(Citrus_greening)",
|
||||
"Peach___Bacterial_spot",
|
||||
"Peach___healthy",
|
||||
"Pepper,_bell___Bacterial_spot",
|
||||
"Pepper,_bell___healthy",
|
||||
"Potato___Early_blight",
|
||||
"Potato___Late_blight",
|
||||
"Potato___healthy",
|
||||
"Raspberry___healthy",
|
||||
"Soybean___healthy",
|
||||
"Squash___Powdery_mildew",
|
||||
"Strawberry___Leaf_scorch",
|
||||
"Strawberry___healthy",
|
||||
"Tomato___Bacterial_spot",
|
||||
"Tomato___Early_blight",
|
||||
"Tomato___Late_blight",
|
||||
"Tomato___Leaf_Mold",
|
||||
"Tomato___Septoria_leaf_spot",
|
||||
"Tomato___Spider_mites Two-spotted_spider_mite",
|
||||
"Tomato___Target_Spot",
|
||||
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
|
||||
"Tomato___Tomato_mosaic_virus",
|
||||
"Tomato___healthy",
|
||||
] as const;
|
||||
|
||||
// ─── PlantVillage → App disease ID mapping ──────────────────────────────
|
||||
|
||||
/**
|
||||
* Maps PlantVillage class names (in the form "Plant___Disease") to
|
||||
* this app's disease IDs. Unmapped classes resolve to "unknown".
|
||||
*/
|
||||
function plantVillageNameToDiseaseId(pvName: string): string {
|
||||
const parts = pvName.split("___");
|
||||
if (parts.length !== 2) {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
const disease = parts[1];
|
||||
|
||||
// Detect "healthy" variants
|
||||
if (disease === "healthy") {
|
||||
return "healthy";
|
||||
}
|
||||
|
||||
// Map exact PlantVillage class names to our disease IDs.
|
||||
// Only map classes where we're confident the correspondence holds.
|
||||
const exactMap: Record<string, string> = {
|
||||
Squash___Powdery_mildew: "squash-powdery-mildew",
|
||||
Strawberry___Leaf_scorch: "strawberry-leaf-scorch",
|
||||
Potato___Early_blight: "early-blight",
|
||||
Potato___Late_blight: "late-blight",
|
||||
Tomato___Bacterial_spot: "bacterial-leaf-spot-tomato",
|
||||
Tomato___Early_blight: "early-blight",
|
||||
Tomato___Late_blight: "late-blight",
|
||||
Tomato___Septoria_leaf_spot: "septoria-leaf-spot",
|
||||
};
|
||||
|
||||
return exactMap[pvName] ?? "unknown";
|
||||
}
|
||||
|
||||
// ─── Constants ──────────────────────────────────────────────────────────
|
||||
|
||||
/** Total number of model output classes */
|
||||
export const NUM_CLASSES = PLANTVILLAGE_CLASSES.length; // 38
|
||||
|
||||
/** Index for the "healthy" class — multiple PV indices map to this */
|
||||
export const HEALTHY_INDEX = 0; // First PV healthy class, others also map to this string
|
||||
|
||||
/** First disease index (unused in PV mapping, kept for compatibility) */
|
||||
export const FIRST_DISEASE_INDEX = 0;
|
||||
|
||||
/** Index for the "unknown" catch-all — PV classes we can't map */
|
||||
export const UNKNOWN_INDEX = NUM_CLASSES - 1; // 37 (Tomato___healthy maps to "healthy", not unknown)
|
||||
|
||||
// ─── Index → Disease ID mapping ─────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Map from model output index to app disease ID string.
|
||||
* Built dynamically from PlantVillage class names.
|
||||
*/
|
||||
export const INDEX_TO_DISEASE_ID: Record<number, string> = Object.freeze(
|
||||
(() => {
|
||||
const map: Record<number, string> = {};
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
map[i] = plantVillageNameToDiseaseId(PLANTVILLAGE_CLASSES[i]);
|
||||
}
|
||||
return map;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ─── Disease ID → Index mapping ─────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Map from disease ID string to model output index.
|
||||
* For duplicates (e.g., both potato and tomato "Early_blight" → "early-blight"),
|
||||
* returns the first matching index.
|
||||
*/
|
||||
export const DISEASE_ID_TO_INDEX: Record<string, number> = Object.freeze(
|
||||
(() => {
|
||||
const map: Record<string, number> = {};
|
||||
for (let i = 0; i < NUM_CLASSES; i++) {
|
||||
const id = INDEX_TO_DISEASE_ID[i];
|
||||
// First occurrence wins (potato before tomato for early/late blight)
|
||||
if (map[id] === undefined) {
|
||||
map[id] = i;
|
||||
}
|
||||
}
|
||||
return map;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ─── Lookup helpers ─────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get the disease ID for a given model output index.
|
||||
* Returns "unknown" for out-of-range indices.
|
||||
*/
|
||||
export function getDiseaseIdForIndex(index: number): string {
|
||||
return INDEX_TO_DISEASE_ID[index] ?? "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the model output index for a given disease ID.
|
||||
* Returns -1 if not found.
|
||||
*/
|
||||
export function getIndexForDiseaseId(diseaseId: string): number {
|
||||
return DISEASE_ID_TO_INDEX[diseaseId.toLowerCase()] ?? -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a disease ID is a real disease (not "healthy" or "unknown").
|
||||
*/
|
||||
export function isRealDisease(diseaseId: string): boolean {
|
||||
return diseaseId !== "healthy" && diseaseId !== "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the PlantVillage display name for a given model output index.
|
||||
*/
|
||||
export function getPlantVillageClassName(index: number): string {
|
||||
return PLANTVILLAGE_CLASSES[index] ?? "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all known disease IDs (excluding "healthy" and "unknown").
|
||||
*/
|
||||
export function getAllDiseaseIds(): string[] {
|
||||
const ids = new Set<string>();
|
||||
for (const id of Object.values(INDEX_TO_DISEASE_ID)) {
|
||||
if (id !== "healthy" && id !== "unknown") {
|
||||
ids.add(id);
|
||||
}
|
||||
}
|
||||
return Array.from(ids);
|
||||
}
|
||||
395
src/lib/ml/model-loader.ts
Normal file
395
src/lib/ml/model-loader.ts
Normal file
@@ -0,0 +1,395 @@
|
||||
/**
|
||||
* Singleton model loader for the plant disease classifier.
|
||||
*
|
||||
* Lazy-loads the TF.js or ONNX model on first call and caches it in memory
|
||||
* via globalThis for subsequent requests. Supports graceful fallback to
|
||||
* mock mode when no model file is present.
|
||||
*
|
||||
* Model files expected at: public/models/plant-disease-classifier/model.json
|
||||
*/
|
||||
|
||||
import fs from "fs/promises";
|
||||
import fsSync from "fs";
|
||||
import path from "path";
|
||||
|
||||
// ─── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Model runtime backend */
|
||||
export type ModelBackend = "tfjs" | "onnx" | "mock";
|
||||
|
||||
/** Model loading status */
|
||||
export interface ModelStatus {
|
||||
/** Whether a real model is loaded */
|
||||
loaded: boolean;
|
||||
/** Backend being used */
|
||||
backend: ModelBackend;
|
||||
/** Model identifier string */
|
||||
modelId: string;
|
||||
/** Number of output classes */
|
||||
numClasses: number;
|
||||
/** Error message if loading failed */
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/** Result from running the model on input data */
|
||||
export interface ModelOutput {
|
||||
/** Raw logits or probabilities from the model */
|
||||
logits: Float32Array;
|
||||
/** Inference time in milliseconds */
|
||||
inferenceTimeMs: number;
|
||||
}
|
||||
|
||||
/** Model interface abstracted over TF.js / ONNX / mock */
|
||||
export interface PlantDiseaseModel {
|
||||
/** Run inference on a preprocessed image tensor */
|
||||
predict(tensor: Float32Array): Promise<ModelOutput>;
|
||||
/** Get model metadata */
|
||||
getStatus(): ModelStatus;
|
||||
}
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
/** Path to model files relative to project root */
|
||||
const MODEL_DIR = path.join(process.cwd(), "public", "models", "plant-disease-classifier");
|
||||
const MODEL_JSON_PATH = path.join(MODEL_DIR, "model.json");
|
||||
|
||||
/** Model identifier */
|
||||
export const MODEL_ID = "plant-classifier-v1";
|
||||
|
||||
/** Maximum model load time (ms) */
|
||||
const MODEL_LOAD_TIMEOUT = 30_000;
|
||||
|
||||
// ─── Global cache ────────────────────────────────────────────────────────────
|
||||
|
||||
declare global {
|
||||
var __plantDiseaseModel__: PlantDiseaseModel | undefined;
|
||||
var __plantDiseaseModelLoading__: Promise<PlantDiseaseModel> | undefined;
|
||||
}
|
||||
|
||||
// ─── Model Loader ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get the cached model instance, loading it lazily on first call.
|
||||
* Uses globalThis to persist across serverless invocations (within same container).
|
||||
*
|
||||
* @returns Promise resolving to the model (real or mock)
|
||||
*/
|
||||
export async function getModel(): Promise<PlantDiseaseModel> {
|
||||
// Return cached model if available
|
||||
if (globalThis.__plantDiseaseModel__) {
|
||||
return globalThis.__plantDiseaseModel__;
|
||||
}
|
||||
|
||||
// If already loading, wait for the existing promise
|
||||
if (globalThis.__plantDiseaseModelLoading__) {
|
||||
return globalThis.__plantDiseaseModelLoading__;
|
||||
}
|
||||
|
||||
// Start loading
|
||||
const loadingPromise = loadModel();
|
||||
globalThis.__plantDiseaseModelLoading__ = loadingPromise;
|
||||
|
||||
try {
|
||||
const model = await Promise.race([
|
||||
loadingPromise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Model load timed out after ${MODEL_LOAD_TIMEOUT}ms`)),
|
||||
MODEL_LOAD_TIMEOUT,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
globalThis.__plantDiseaseModel__ = model;
|
||||
return model;
|
||||
} finally {
|
||||
globalThis.__plantDiseaseModelLoading__ = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the model, attempting TF.js first, then ONNX, then falling back to mock.
|
||||
*/
|
||||
async function loadModel(): Promise<PlantDiseaseModel> {
|
||||
// Check if model files exist
|
||||
const modelExists = await checkModelFiles();
|
||||
|
||||
if (!modelExists) {
|
||||
console.warn(
|
||||
`[model-loader] Model files not found at ${MODEL_DIR}. Using mock model. ` +
|
||||
`Place TF.js model (model.json + weight shards) in public/models/plant-disease-classifier/`,
|
||||
);
|
||||
return createMockModel();
|
||||
}
|
||||
|
||||
// Try TF.js first
|
||||
try {
|
||||
const tfModel = await tryLoadTFJS();
|
||||
if (tfModel) {
|
||||
console.info(`[model-loader] Loaded TF.js model: ${MODEL_ID}`);
|
||||
return tfModel;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
`[model-loader] TF.js load failed (${err instanceof Error ? err.message : "unknown"}). Trying ONNX...`,
|
||||
);
|
||||
}
|
||||
|
||||
// Try ONNX Runtime
|
||||
try {
|
||||
const onnxModel = await tryLoadONNX();
|
||||
if (onnxModel) {
|
||||
console.info(`[model-loader] Loaded ONNX model: ${MODEL_ID}`);
|
||||
return onnxModel;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
`[model-loader] ONNX load failed (${err instanceof Error ? err.message : "unknown"}). Falling back to mock.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Fall back to mock
|
||||
console.warn(`[model-loader] All backends failed. Using mock model.`);
|
||||
return createMockModel();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model files exist on disk.
|
||||
*/
|
||||
async function checkModelFiles(): Promise<boolean> {
|
||||
try {
|
||||
await fs.access(MODEL_JSON_PATH);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── TensorFlow.js Backend ───────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Try to load the model using TensorFlow.js.
|
||||
* Attempts @tensorflow/tfjs-node first (server), falls back to @tensorflow/tfjs.
|
||||
*/
|
||||
async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let tf: any;
|
||||
|
||||
// Monkey-patch: add util.isNullOrUndefined for Node.js 26 compatibility.
|
||||
// @tensorflow/tfjs-node references this function which was removed in Node 15+.
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
const nodeUtil = require("util");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
if (typeof (nodeUtil as any).isNullOrUndefined !== "function") {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(nodeUtil as any).isNullOrUndefined = function (x: unknown): boolean {
|
||||
return x === null || x === undefined;
|
||||
};
|
||||
}
|
||||
|
||||
// Try tfjs-node first (server-side, uses native bindings).
|
||||
// Use dynamic strings so bundlers (Turbopack/webpack) don't trace these
|
||||
// as required dependencies — they are truly optional.
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
const tfjsNode = await import("@tensorflow/tfjs-node" + "");
|
||||
tf = tfjsNode;
|
||||
} catch {
|
||||
// Fall back to browser tfjs
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
tf = await import("@tensorflow/tfjs" + "");
|
||||
} catch {
|
||||
return null; // Neither tfjs package available
|
||||
}
|
||||
}
|
||||
|
||||
// Load the model from file path
|
||||
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
|
||||
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
|
||||
// Reshape NCHW flat array [3*160*160] → [3, 160, 160] → NHWC [1, 160, 160, 3]
|
||||
const inputTensor = tf
|
||||
.tensor3d(Array.from(tensor), [3, 160, 160])
|
||||
.transpose([1, 2, 0])
|
||||
.expandDims(0);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const outputTensor = (await model.predict(inputTensor)) as any;
|
||||
const logits = new Float32Array(await outputTensor.data());
|
||||
|
||||
inputTensor.dispose();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
|
||||
outputTensor.dispose();
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: performance.now() - startTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: true,
|
||||
backend: "tfjs",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 38, // Original PlantVillage model
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ─── ONNX Runtime Backend ────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Try to load the model using ONNX Runtime.
|
||||
*/
|
||||
async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let ort: any;
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
ort = await import("onnxruntime-node" + "");
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Look for .onnx file in model directory
|
||||
const onnxPath = path.join(MODEL_DIR, "model.onnx");
|
||||
const onnxExists = fsSync.existsSync(onnxPath);
|
||||
|
||||
if (!onnxExists) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const session = await ort.InferenceSession.create(onnxPath);
|
||||
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// ONNX expects NCHW format: [1, 3, 160, 160]
|
||||
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
|
||||
const feeds = { [session.inputNames[0]]: inputTensor };
|
||||
const results = await session.run(feeds);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const outputValues = Object.values(results) as any[];
|
||||
const logits = new Float32Array(outputValues[0].data);
|
||||
|
||||
inputTensor.dispose();
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: performance.now() - startTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: true,
|
||||
backend: "onnx",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 38,
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Mock Model ──────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Create a deterministic mock model for development/demo mode.
|
||||
*
|
||||
* Generates reproducible predictions based on input tensor hash.
|
||||
* This allows the UI to work without a real model file.
|
||||
*/
|
||||
function createMockModel(): PlantDiseaseModel {
|
||||
return {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
// Simulate inference time (50-200ms)
|
||||
const simulatedTime = 50 + Math.random() * 150;
|
||||
await sleep(simulatedTime);
|
||||
|
||||
// Generate deterministic logits from input hash
|
||||
const logits = generateMockLogits(tensor);
|
||||
|
||||
return {
|
||||
logits,
|
||||
inferenceTimeMs: simulatedTime,
|
||||
};
|
||||
},
|
||||
|
||||
getStatus(): ModelStatus {
|
||||
return {
|
||||
loaded: false,
|
||||
backend: "mock",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 38,
|
||||
error: "Model files not found. Running in demo mode with mock predictions.",
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate deterministic mock logits from input tensor.
|
||||
* Uses a simple hash of the first few tensor values to create
|
||||
* reproducible but varied predictions.
|
||||
*/
|
||||
function generateMockLogits(tensor: Float32Array): Float32Array {
|
||||
const numClasses = 38;
|
||||
const logits = new Float32Array(numClasses);
|
||||
|
||||
// Simple hash of input for deterministic output
|
||||
let hash = 0;
|
||||
const sampleSize = Math.min(100, tensor.length);
|
||||
for (let i = 0; i < sampleSize; i++) {
|
||||
hash = ((hash << 5) - hash + Math.floor(tensor[i] * 1000)) | 0;
|
||||
}
|
||||
|
||||
// Generate logits using hash as seed
|
||||
// Class 0 (healthy) gets a moderate score
|
||||
logits[0] = (Math.abs(hash % 10) / 10) * 2;
|
||||
|
||||
// Give some disease classes higher scores
|
||||
// This creates a realistic-looking distribution
|
||||
for (let i = 1; i < numClasses - 1; i++) {
|
||||
const seed = ((hash * (i + 1) * 7) % 1000) / 1000;
|
||||
logits[i] = seed * 4 - 1; // Range roughly -1 to 3
|
||||
}
|
||||
|
||||
// Make the top prediction more confident
|
||||
const topIndex = Math.abs(hash % (numClasses - 2)) + 1;
|
||||
logits[topIndex] = 3.5;
|
||||
|
||||
// Second highest
|
||||
const secondIndex = ((topIndex + Math.abs(hash % 10) + 1) % (numClasses - 1)) + 1;
|
||||
logits[secondIndex] = 2.5;
|
||||
|
||||
logits[numClasses - 1] = -2; // "unknown" gets low score
|
||||
|
||||
return logits;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep for a given number of milliseconds.
|
||||
*/
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
// ─── Reset (for testing) ─────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Reset the model cache. Useful for testing.
|
||||
*/
|
||||
export function resetModelCache(): void {
|
||||
globalThis.__plantDiseaseModel__ = undefined;
|
||||
globalThis.__plantDiseaseModelLoading__ = undefined;
|
||||
}
|
||||
55
src/lib/server/image-processing-server.test.ts
Normal file
55
src/lib/server/image-processing-server.test.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { mimeTypeToExtension } from "./image-processing-server";
|
||||
|
||||
// Mock sharp dynamically
|
||||
const mockSharp = vi.fn(() => ({
|
||||
resize: vi.fn().mockReturnThis(),
|
||||
jpeg: vi.fn().mockReturnThis(),
|
||||
toBuffer: vi.fn().mockResolvedValue(Buffer.from("resized-image-data")),
|
||||
}));
|
||||
|
||||
vi.doMock("sharp", () => ({
|
||||
default: mockSharp,
|
||||
}));
|
||||
|
||||
describe("mimeTypeToExtension", () => {
|
||||
it("maps image/png to png", () => {
|
||||
expect(mimeTypeToExtension("image/png")).toBe("png");
|
||||
});
|
||||
|
||||
it("maps image/jpeg to jpg", () => {
|
||||
expect(mimeTypeToExtension("image/jpeg")).toBe("jpg");
|
||||
});
|
||||
|
||||
it("maps image/jpg to jpg", () => {
|
||||
expect(mimeTypeToExtension("image/jpg")).toBe("jpg");
|
||||
});
|
||||
|
||||
it("maps image/webp to webp", () => {
|
||||
expect(mimeTypeToExtension("image/webp")).toBe("webp");
|
||||
});
|
||||
|
||||
it("returns jpg for unknown mime types", () => {
|
||||
expect(mimeTypeToExtension("image/bmp")).toBe("jpg");
|
||||
expect(mimeTypeToExtension("unknown/type")).toBe("jpg");
|
||||
});
|
||||
});
|
||||
|
||||
describe("resizeImageServer", () => {
|
||||
it("resizes image to specified dimensions", async () => {
|
||||
const { resizeImageServer } = await import("./image-processing-server");
|
||||
const buffer = Buffer.from("test-image-data");
|
||||
|
||||
const result = await resizeImageServer(buffer, 224, 224);
|
||||
expect(result).toBeInstanceOf(Buffer);
|
||||
expect(mockSharp).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("returns buffer for valid input", async () => {
|
||||
const { resizeImageServer } = await import("./image-processing-server");
|
||||
const buffer = Buffer.from("test-image-data");
|
||||
|
||||
const result = await resizeImageServer(buffer, 224, 224);
|
||||
expect(result).toBeInstanceOf(Buffer);
|
||||
});
|
||||
});
|
||||
51
src/lib/server/image-processing-server.ts
Normal file
51
src/lib/server/image-processing-server.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Server-only image processing helpers.
|
||||
*
|
||||
* These functions use Node.js native modules (sharp) and must NOT be
|
||||
* imported by client components. They are used exclusively by API
|
||||
* route handlers (server-side).
|
||||
*/
|
||||
|
||||
// ─── Resize ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Server-side image resize using sharp (if available) or a fallback.
|
||||
* This is used by the upload API route.
|
||||
*
|
||||
* @param buffer - Raw image buffer
|
||||
* @param size - Target dimension
|
||||
* @returns Promise<Buffer> resized image as JPEG
|
||||
*/
|
||||
export async function resizeImageServer(
|
||||
buffer: Buffer,
|
||||
size: number,
|
||||
): Promise<Buffer> {
|
||||
try {
|
||||
const sharpModule = await import("sharp");
|
||||
return sharpModule.default(buffer)
|
||||
.resize(size, size)
|
||||
.jpeg({ quality: 95 })
|
||||
.toBuffer();
|
||||
} catch {
|
||||
// Fallback: return original buffer if sharp is not available
|
||||
// In production, sharp should be installed
|
||||
throw new Error(
|
||||
"sharp is required for server-side image resizing. Install with: npm install sharp",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── MIME Type Helpers ───────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Generate a file extension from a MIME type.
|
||||
*/
|
||||
export function mimeTypeToExtension(mimeType: string): string {
|
||||
const map: Record<string, string> = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/jpg": "jpg",
|
||||
"image/webp": "webp",
|
||||
};
|
||||
return map[mimeType] || "jpg";
|
||||
}
|
||||
191
src/lib/types.ts
Normal file
191
src/lib/types.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
/**
|
||||
* Shared TypeScript interfaces for the Plant Disease Knowledge Base.
|
||||
* Used by seed data, API helpers, and API routes.
|
||||
*/
|
||||
|
||||
/** Types of causal agents that cause plant diseases */
|
||||
export type CausalAgentType = "fungal" | "bacterial" | "viral" | "environmental";
|
||||
|
||||
/** Severity levels for plant diseases */
|
||||
export type Severity = "low" | "moderate" | "high" | "critical";
|
||||
|
||||
/** How common/prevalent a disease is in the field */
|
||||
export type Prevalence = "common" | "uncommon" | "rare" | "very_rare";
|
||||
|
||||
/** Plant category for grouping and filtering */
|
||||
export type PlantCategory =
|
||||
| "vegetable"
|
||||
| "herb"
|
||||
| "houseplant"
|
||||
| "flower"
|
||||
| "fruit"
|
||||
| "succulent"
|
||||
| "tree";
|
||||
|
||||
/**
|
||||
* A plant entry in the knowledge base.
|
||||
* Each plant has 0+ associated diseases (linked via plantId in Disease).
|
||||
*/
|
||||
export interface Plant {
|
||||
/** Unique identifier (slug), e.g., "tomato", "monstera" */
|
||||
id: string;
|
||||
/** Common name, e.g., "Tomato" */
|
||||
commonName: string;
|
||||
/** Scientific (botanical) name, e.g., "Solanum lycopersicum" */
|
||||
scientificName: string;
|
||||
/** Plant family, e.g., "Solanaceae" */
|
||||
family: string;
|
||||
/** Category for filtering */
|
||||
category: PlantCategory;
|
||||
/** Brief care summary (light, water, humidity, temperature) */
|
||||
careSummary: string;
|
||||
/** URL to a representative image of the healthy plant */
|
||||
imageUrl: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A disease entry in the knowledge base.
|
||||
* Links to a plant via plantId and can reference lookalike diseases.
|
||||
*/
|
||||
export interface Disease {
|
||||
/** Unique identifier (slug), e.g., "early-blight" */
|
||||
id: string;
|
||||
/** ID of the affected plant */
|
||||
plantId: string;
|
||||
/** Common disease name, e.g., "Early Blight" */
|
||||
name: string;
|
||||
/** Scientific name of the pathogen (if applicable) */
|
||||
scientificName: string;
|
||||
/** Type of causal agent */
|
||||
causalAgentType: CausalAgentType;
|
||||
/** Detailed description of the disease */
|
||||
description: string;
|
||||
/** Observable symptoms (≥3) */
|
||||
symptoms: string[];
|
||||
/** Root causes or contributing factors (≥2) */
|
||||
causes: string[];
|
||||
/** Step-by-step treatment instructions (≥3) */
|
||||
treatment: string[];
|
||||
/** Prevention tips (≥2) */
|
||||
prevention: string[];
|
||||
/** IDs of diseases that look similar and may be confused with this one */
|
||||
lookalikeDiseaseIds: string[];
|
||||
/** Overall severity of the disease */
|
||||
severity: Severity;
|
||||
/** How common/prevalent this disease is */
|
||||
prevalence: Prevalence;
|
||||
/** URL to a representative image showing disease symptoms */
|
||||
imageUrl?: string;
|
||||
}
|
||||
|
||||
/** Query parameters for listing/searching plants */
|
||||
export interface PlantListParams {
|
||||
search?: string;
|
||||
category?: PlantCategory;
|
||||
}
|
||||
|
||||
/** Query parameters for listing/searching diseases */
|
||||
export interface DiseaseListParams {
|
||||
plantId?: string;
|
||||
search?: string;
|
||||
causalAgentType?: CausalAgentType;
|
||||
severity?: Severity;
|
||||
}
|
||||
|
||||
/** Response wrapper for a single plant with its diseases */
|
||||
export interface PlantWithDiseases {
|
||||
plant: Plant;
|
||||
diseases: Disease[];
|
||||
}
|
||||
|
||||
/** Response wrapper for a single disease with its plant */
|
||||
export interface DiseaseWithPlant {
|
||||
disease: Disease;
|
||||
plant: Plant;
|
||||
}
|
||||
|
||||
/** Standard error response shape */
|
||||
export interface ApiError {
|
||||
error: string;
|
||||
message: string;
|
||||
status: number;
|
||||
}
|
||||
|
||||
/** Paginated list response (future-proof, currently returns all) */
|
||||
export interface PaginatedResponse<T> {
|
||||
items: T[];
|
||||
total: number;
|
||||
}
|
||||
|
||||
// ─── ML / Inference types ────────────────────────────────────────────────────
|
||||
|
||||
/** Confidence label based on calibrated score */
|
||||
export type ConfidenceLabel = "high" | "medium" | "low";
|
||||
|
||||
/** Raw prediction from model inference (before calibration) */
|
||||
export interface RawPrediction {
|
||||
/** Model output class index */
|
||||
classIndex: number;
|
||||
/** Raw softmax probability */
|
||||
probability: number;
|
||||
}
|
||||
|
||||
/** Calibrated confidence score with label */
|
||||
export interface ConfidenceResult {
|
||||
/** Raw softmax probability from model */
|
||||
raw: number;
|
||||
/** Calibrated/adjusted confidence score */
|
||||
adjusted: number;
|
||||
/** Human-readable confidence label */
|
||||
label: ConfidenceLabel;
|
||||
}
|
||||
|
||||
/** A single prediction in the identify API response */
|
||||
export interface PredictionResult {
|
||||
/** Disease ID matching knowledge base */
|
||||
diseaseId: string;
|
||||
/** Full disease object from knowledge base */
|
||||
disease: Disease;
|
||||
/** Calibrated confidence */
|
||||
confidence: ConfidenceResult;
|
||||
/** IDs of lookalike diseases that could be confused with this one */
|
||||
lookalikes: string[];
|
||||
/** Full disease objects for lookalikes, pre-resolved server-side */
|
||||
lookalikeDiseases: Disease[];
|
||||
/** The plant this disease affects (included for client convenience) */
|
||||
plant: Plant | null;
|
||||
}
|
||||
|
||||
/** Metadata about the inference run */
|
||||
export interface InferenceMetadata {
|
||||
/** Model identifier/version */
|
||||
model: string;
|
||||
/** Inference time in milliseconds */
|
||||
inferenceTimeMs: number;
|
||||
/** Image ID that was analyzed */
|
||||
imageId: string;
|
||||
}
|
||||
|
||||
/** Response from POST /api/identify */
|
||||
export interface IdentifyResponse {
|
||||
/** Ranked predictions */
|
||||
predictions: PredictionResult[];
|
||||
/** Inference metadata */
|
||||
metadata: InferenceMetadata;
|
||||
/** True when running in demo/mock mode (no real model loaded) */
|
||||
demo_mode?: boolean;
|
||||
}
|
||||
|
||||
/** Request body for POST /api/identify */
|
||||
export interface IdentifyRequest {
|
||||
/** Image ID from a previous /api/upload call */
|
||||
imageId: string;
|
||||
}
|
||||
|
||||
/** Result from runInference() */
|
||||
export interface InferenceResult {
|
||||
/** Top-K raw predictions sorted by probability descending */
|
||||
predictions: RawPrediction[];
|
||||
/** Inference time in milliseconds */
|
||||
inferenceTimeMs: number;
|
||||
}
|
||||
Reference in New Issue
Block a user