06. Plant-Context-Aware Identification
meta:
- id: production-ml-pipeline-06
- feature: production-ml-pipeline
- priority: P2
- depends_on: [production-ml-pipeline-05]
- tags: [implementation, ux, tests-required]
objective:
- Allow users to optionally specify which plant they're diagnosing before identification
- Boost predictions for the selected plant's diseases (multiply their confidence by a plant-context boost factor)
- Update the upload flow to include optional plant selection
- Improve prediction accuracy when plant context is known
deliverables:
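- `src/app/api/identify/route.ts`: accept an optional `plantId` parameter
- `src/lib/ml/plant-context.ts`: new module for plant-context scoring adjustment
- `src/components/PlantSelector.tsx`: new component for optional plant selection
- `src/app/upload/page.tsx`: integrate `PlantSelector` before upload
- `src/lib/api/identify.ts`: client API updated to pass `plantId`
- `src/lib/types.ts`: add `IdentifyRequest.plantId` and `PredictionResult.contextBoosted`
- `ResultsDashboard`: show the context-boost indicator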
steps:
- Create the plant-context scoring module `src/lib/ml/plant-context.ts`:
```ts
// src/lib/ml/plant-context.ts
import { PLANTVILLAGE_CLASSES } from "./plantvillage-classes";

/**
 * Adjust prediction scores based on plant context.
 * If plantId is provided, boost predictions for diseases of that plant.
 *
 * @param predictions - Top-K predictions with classIndex and probability
 * @param plantId - Optional plant ID from user selection
 * @param boostFactor - Multiplier for matching plant diseases (default 1.5)
 * @returns Adjusted predictions, re-sorted by descending (boosted) probability
 */
export function applyPlantContext(
  predictions: Array<{ classIndex: number; probability: number }>,
  plantId: string | null,
  boostFactor: number = 1.5,
): Array<{ classIndex: number; probability: number; contextBoosted: boolean }> {
  if (!plantId) {
    return predictions.map((p) => ({ ...p, contextBoosted: false }));
  }

  // Disease classes (healthy class excluded) that belong to this plant
  const plantIndices = new Set(
    PLANTVILLAGE_CLASSES.filter((c) => c.plantId === plantId && !c.isHealthy).map(
      (c) => c.index,
    ),
  );

  return predictions
    .map((pred) => {
      const matchesPlant = plantIndices.has(pred.classIndex);
      return {
        classIndex: pred.classIndex,
        probability: matchesPlant
          ? Math.min(1.0, pred.probability * boostFactor)
          : pred.probability,
        contextBoosted: matchesPlant,
      };
    })
    // Re-rank so a boosted disease can overtake a higher raw score
    .sort((a, b) => b.probability - a.probability);
}
```
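The boost-and-rerank behaviour can be checked in isolation. The sketch below is self-contained: a hypothetical three-class table (`CLASSES`, with made-up indices) stands in for `PLANTVILLAGE_CLASSES`, and the class table is passed in as a parameter instead of imported. It shows a tomato disease with a lower raw score overtaking a potato disease once boosted and re-sorted.

```typescript
// Hypothetical stand-in for PLANTVILLAGE_CLASSES: only the fields the boost needs.
interface ClassInfo {
  index: number;
  plantId: string;
  isHealthy: boolean;
}

interface Prediction {
  classIndex: number;
  probability: number;
}

interface AdjustedPrediction extends Prediction {
  contextBoosted: boolean;
}

// Toy class table (assumed ids and indices, not the real PlantVillage ones).
const CLASSES: ClassInfo[] = [
  { index: 0, plantId: "potato", isHealthy: false }, // e.g. potato early blight
  { index: 1, plantId: "tomato", isHealthy: false }, // e.g. tomato early blight
  { index: 2, plantId: "tomato", isHealthy: true }, // tomato healthy
];

function applyPlantContext(
  predictions: Prediction[],
  plantId: string | null,
  classes: ClassInfo[],
  boostFactor = 1.5,
): AdjustedPrediction[] {
  if (!plantId) {
    return predictions.map((p) => ({ ...p, contextBoosted: false }));
  }
  // Disease classes (not the healthy class) that belong to the selected plant.
  const plantIndices = new Set(
    classes.filter((c) => c.plantId === plantId && !c.isHealthy).map((c) => c.index),
  );
  return predictions
    .map((p) => {
      const matches = plantIndices.has(p.classIndex);
      return {
        classIndex: p.classIndex,
        probability: matches ? Math.min(1, p.probability * boostFactor) : p.probability,
        contextBoosted: matches,
      };
    })
    // Re-rank so a boosted class can overtake a higher raw score.
    .sort((a, b) => b.probability - a.probability);
}

const raw: Prediction[] = [
  { classIndex: 0, probability: 0.45 },
  { classIndex: 1, probability: 0.35 },
  { classIndex: 2, probability: 0.1 },
];

const adjusted = applyPlantContext(raw, "tomato", CLASSES);
console.log(adjusted);
```

With `plantId = "tomato"`, class 1 rises from 0.35 to about 0.525 and moves ahead of class 0 (0.45); the healthy tomato class is left at 0.1 and unflagged.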
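- Update the `/api/identify` route to accept `plantId`:
```ts
// src/app/api/identify/route.ts
import { NextRequest, NextResponse } from "next/server";
import { applyPlantContext } from "../../../lib/ml/plant-context";

export async function POST(request: NextRequest) {
  const body = await request.json();
  const { imageId } = body;
  // plantId is optional; anything other than a non-empty string means "no context"
  const plantId: string | null =
    typeof body.plantId === "string" && body.plantId !== "" ? body.plantId : null;

  // ... existing preprocessing ...

  const { probabilities, inferenceTimeMs } = await runInference(tensor);

  // Get top-K predictions (only classes already in the top K can be boosted)
  const topK = getTopKFloat32(probabilities, 5);

  // Apply plant context if provided (no-op when plantId is null)
  const adjusted = applyPlantContext(topK, plantId);

  // Enrich with KB data; enrichment must carry contextBoosted into each PredictionResult
  const predictions = await enrichPredictions(adjusted);

  return NextResponse.json({
    predictions,
    metadata: { model: MODEL_ID, inferenceTimeMs, imageId, plantContext: plantId },
  });
}
```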
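Because `request.json()` returns untrusted input, the route benefits from validating the body before using it. The following is a minimal sketch of that step; `parseIdentifyBody` and `ParsedIdentifyBody` are hypothetical names, not part of the existing codebase, and the choice to ignore (rather than reject) a malformed `plantId` is an assumption in keeping with plant context being optional.

```typescript
// Hypothetical normalised form of IdentifyRequest used inside the route.
type ParsedIdentifyBody = { imageId: string; plantId: string | null };

// Validate an untrusted JSON body. Returns null when the body is unusable,
// in which case the caller would respond with HTTP 400.
function parseIdentifyBody(body: unknown): ParsedIdentifyBody | null {
  if (typeof body !== "object" || body === null) return null;
  const { imageId, plantId } = body as Record<string, unknown>;
  if (typeof imageId !== "string" || imageId.length === 0) return null;
  // Missing, non-string, or blank plantId all mean "no plant context".
  const normalized =
    typeof plantId === "string" && plantId.trim() !== "" ? plantId.trim() : null;
  return { imageId, plantId: normalized };
}

console.log(parseIdentifyBody({ imageId: "img1", plantId: "tomato" }));
console.log(parseIdentifyBody({ imageId: "img1" }));
```

In the route, a `null` result could be answered with `NextResponse.json({ error: "imageId is required" }, { status: 400 })`, and the parsed `plantId` passed straight to `applyPlantContext`.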
- Update the `IdentifyRequest` type:
```ts
// src/lib/types.ts
export interface IdentifyRequest {
  imageId: string;
  plantId?: string; // Optional plant context
}
```
- Create the `PlantSelector` component `src/components/PlantSelector.tsx`:
```tsx
"use client";

import { useState, useEffect } from "react";

interface Plant {
  id: string;
  commonName: string;
  imageUrl?: string;
}

export default function PlantSelector({
  value,
  onChange,
}: {
  value: string | null;
  onChange: (plantId: string | null) => void;
}) {
  const [plants, setPlants] = useState<Plant[]>([]);
  const [search, setSearch] = useState("");

  useEffect(() => {
    // Abort the request if the component unmounts before it resolves
    const controller = new AbortController();
    fetch("/api/plants?limit=50", { signal: controller.signal })
      .then((r) => r.json())
      .then((data) => setPlants(data.items ?? []))
      .catch(() => {
        // Selection is optional: on failure (or abort) leave the list empty
      });
    return () => controller.abort();
  }, []);

  // Client-side filter over the (at most 50) plants returned by the API
  const filtered = plants.filter((p) =>
    p.commonName.toLowerCase().includes(search.toLowerCase()),
  );

  return (
    <div className="...">
      <label htmlFor="plant-search">Plant (optional)</label>
      <input
        id="plant-search"
        type="text"
        placeholder="Search plants..."
        value={search}
        onChange={(e) => setSearch(e.target.value)}
      />
      {value && (
        <div className="...">
          Selected: {plants.find((p) => p.id === value)?.commonName ?? value}
          <button type="button" onClick={() => onChange(null)}>
            Clear
          </button>
        </div>
      )}
      <ul>
        {filtered.slice(0, 10).map((plant) => (
          <li key={plant.id}>
            {/* A button keeps each option keyboard-accessible */}
            <button type="button" onClick={() => onChange(plant.id)}>
              {plant.commonName}
            </button>
          </li>
        ))}
      </ul>
    </div>
  );
}
```
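The search logic inside `PlantSelector` is easier to unit-test as a pure function. Below is a sketch of that extraction; `filterPlants` is a hypothetical helper name, and trimming the query is a small assumption beyond what the component does. It performs the same case-insensitive substring match on `commonName` and the same cap of 10 results.

```typescript
interface Plant {
  id: string;
  commonName: string;
  imageUrl?: string;
}

// Case-insensitive substring match on commonName, capped at `limit` results.
// An empty (or whitespace-only) query returns the first `limit` plants.
function filterPlants(plants: Plant[], search: string, limit = 10): Plant[] {
  const query = search.trim().toLowerCase();
  const matches =
    query === "" ? plants : plants.filter((p) => p.commonName.toLowerCase().includes(query));
  return matches.slice(0, limit);
}

const plants: Plant[] = [
  { id: "tomato", commonName: "Tomato" },
  { id: "potato", commonName: "Potato" },
  { id: "apple", commonName: "Apple" },
];

// "ato" matches both Tomato and Potato, but not Apple.
console.log(filterPlants(plants, "ato").map((p) => p.id));
```

Inside the component, `filtered.slice(0, 10)` would then become `filterPlants(plants, search)`, and `PlantSelector.test.tsx` could cover the filtering without rendering.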
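- Update the upload page to include the plant selector:
```tsx
// src/app/upload/page.tsx
"use client";

import { useCallback, useState } from "react";
import { useRouter } from "next/navigation";
import PlantSelector from "../../components/PlantSelector";
import { identifyPlant } from "../../lib/api/identify";
// uploadImage and ImageUpload: existing imports, unchanged

export default function UploadPage() {
  const router = useRouter();
  const [selectedPlant, setSelectedPlant] = useState<string | null>(null);

  const handleUpload = useCallback(
    async (file: File) => {
      // 1. Upload image
      const uploadResponse = await uploadImage(file);

      // 2. Identify with plant context (null means no context)
      await identifyPlant(uploadResponse.imageId, selectedPlant ?? undefined);

      // 3. Navigate to results, carrying the plant context in the URL
      const query = selectedPlant ? `?plant=${encodeURIComponent(selectedPlant)}` : "";
      router.push(`/results/${uploadResponse.imageId}${query}`);
    },
    [selectedPlant, router],
  );

  return (
    <div>
      <PlantSelector value={selectedPlant} onChange={setSelectedPlant} />
      <ImageUpload onUpload={handleUpload} />
    </div>
  );
}
```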
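Carrying the plant context to the results page through the URL means the navigation target must encode it. A small sketch of that, with `resultsUrl` as a hypothetical helper name: it omits the query string entirely when no plant is selected and uses `URLSearchParams` for encoding.

```typescript
// Build the results URL, carrying the optional plant context as ?plant=<plantId>.
function resultsUrl(imageId: string, plantId: string | null): string {
  const path = `/results/${encodeURIComponent(imageId)}`;
  if (!plantId) return path; // no selection: plain results URL
  const params = new URLSearchParams({ plant: plantId });
  return `${path}?${params.toString()}`;
}

console.log(resultsUrl("img_123", null)); // /results/img_123
console.log(resultsUrl("img_123", "tomato")); // /results/img_123?plant=tomato
```

`URLSearchParams` uses form encoding, so a space in an ID becomes `+`; reading the value back with `URLSearchParams.get("plant")` decodes it again.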
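- Update the client-side API to pass `plantId`:
```ts
// src/lib/api/identify.ts
// IdentifyResponse assumed to live alongside IdentifyRequest in src/lib/types.ts
import type { IdentifyRequest, IdentifyResponse } from "../types";

export async function identifyPlant(
  imageId: string,
  plantId?: string | null,
): Promise<IdentifyResponse> {
  const body: IdentifyRequest = { imageId };
  // Omit plantId entirely when no plant is selected
  if (plantId) body.plantId = plantId;

  const response = await fetch("/api/identify", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });
  if (!response.ok) {
    throw new Error(`Identify request failed with status ${response.status}`);
  }
  return response.json();
}
```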
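The key property of the client change is that requests without a plant look exactly as they did before this feature. This sketch isolates the payload construction so that property can be asserted; `buildIdentifyBody` is a hypothetical helper, and `IdentifyRequest` is redeclared locally so the snippet stands alone.

```typescript
// Local copy of IdentifyRequest so the sketch is self-contained.
interface IdentifyRequest {
  imageId: string;
  plantId?: string;
}

// Serialise the request body, leaving plantId out when it is null, undefined, or "".
function buildIdentifyBody(imageId: string, plantId?: string | null): string {
  const body: IdentifyRequest = { imageId };
  if (plantId) body.plantId = plantId;
  return JSON.stringify(body);
}

console.log(buildIdentifyBody("img1")); // {"imageId":"img1"}
console.log(buildIdentifyBody("img1", "tomato")); // {"imageId":"img1","plantId":"tomato"}
```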
- Update the `PredictionResult` type to include context-boost info:
```ts
export interface PredictionResult {
  diseaseId: string;
  disease: Disease;
  confidence: ConfidenceResult;
  lookalikes: string[];
  plant: Plant | null;
  contextBoosted?: boolean; // true if boosted by plant context
}
```
- Update `ResultsDashboard` to show the context-boost indicator:
```tsx
{prediction.contextBoosted && (
  <span className="text-xs text-leaf-green-600">✓ Matches selected plant</span>
)}
```
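- Store the plant context on the results page by passing `plantId` through the URL or state:
```ts
// src/app/results/[imageId]/page.tsx
// searchParams: the URLSearchParams returned by useSearchParams() from "next/navigation"
const plantId = searchParams.get("plant") ?? undefined; // optional
// Inside an async effect or event handler:
const response = await identifyPlant(imageId, plantId);
```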
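Reading `?plant=` back is where blank or missing values need care, since `URLSearchParams.get` returns `null` for a missing key and `""` for `?plant=`. A minimal sketch, with `readPlantParam` as a hypothetical helper that maps both cases to `undefined`, matching the optional `plantId` parameter of `identifyPlant`:

```typescript
// Read the optional ?plant= value from a query string such as location.search.
// A missing key (null) or a blank value both mean "no plant context".
function readPlantParam(search: string): string | undefined {
  const value = new URLSearchParams(search).get("plant");
  return value && value.trim() !== "" ? value : undefined;
}

console.log(readPlantParam("?plant=tomato")); // tomato
console.log(readPlantParam("?plant=")); // undefined
```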
tests:
- Unit: `applyPlantContext()` with no `plantId` returns predictions unchanged
- Unit: `applyPlantContext()` with `plantId="tomato"` boosts tomato disease predictions
- Unit: boosted probabilities are capped at 1.0
- Unit: non-matching plant predictions are unchanged
- Unit: `contextBoosted` flag is set correctly
- Integration: POST `/api/identify` with `plantId` returns boosted predictions
- Integration: POST `/api/identify` without `plantId` returns normal predictions
- E2E: select "Tomato" in UI → upload tomato leaf → tomato diseases appear first
acceptance_criteria:
- Plant context is optional — identification works without it
- When plantId is provided, predictions for that plant's diseases are boosted by 1.5x
- Boosted probabilities are capped at 1.0
- `contextBoosted` flag is set on boosted predictions
- UI shows "Matches selected plant" indicator on boosted predictions
- Plant selector component works (search, select, clear)
- Upload flow includes optional plant selection step
- Results page receives and displays plant context
validation:
- `npx vitest run src/lib/ml/plant-context.test.ts`
- `npx vitest run src/components/PlantSelector.test.tsx`
- Manual: select "Tomato" → upload image → tomato diseases appear with boost indicator
- Manual: don't select a plant → upload image → normal predictions (no boost)
- Check API response: `predictions[0].contextBoosted` is true when the top prediction is a disease of the selected plant
notes:
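- Plant context is a scoring heuristic, not a hard filter. It boosts confidence but doesn't exclude other predictions. Because only some scores are multiplied, the adjusted values no longer sum to 1 and should be read as ranking scores rather than calibrated probabilities.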
- The default boost factor is 1.5 — this can be tuned based on user feedback.
- Plant selector is optional — users can skip it and get unboosted predictions.
- The plant context feature is most useful when the user knows what plant they're diagnosing but the model is uncertain between multiple diseases.
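- In PlantVillage, each plant has between 0 and 9 disease classes (Blueberry, Raspberry, and Soybean have only a healthy class; Tomato has 9), so the boost is specific enough to be useful without being overly restrictive. For plants with no disease classes, selecting the plant leaves scores unchanged.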
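The boost rule can be written compactly. Here $p_i$ is the model's raw score for class $i$, $\beta$ the boost factor (default 1.5), and $D_{\text{plant}}$ the set of disease class indices (healthy class excluded) for the selected plant:

```latex
\tilde{p}_i =
\begin{cases}
  \min\left(1,\ \beta\, p_i\right) & \text{if } i \in D_{\text{plant}} \\
  p_i & \text{otherwise}
\end{cases}
\qquad \beta = 1.5
```

With illustrative numbers (not measured results): if a potato disease scores 0.45 and a tomato disease 0.35 with "Tomato" selected, the tomato disease becomes $0.35 \times 1.5 = 0.525$ and ranks first once results are ordered by adjusted score. In general a boosted class overtakes an unboosted competitor $j$ only when $p_i > p_j / \beta$ (here $0.35 > 0.45 / 1.5 = 0.30$), and any raw score of at least $1/\beta \approx 0.67$ saturates at 1.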