plant-disease-id/tasks/production-ml-pipeline/06-plant-context-identification.md
2026-06-08 16:42:04 -04:00

# 06. Plant-Context-Aware Identification
meta:
id: production-ml-pipeline-06
feature: production-ml-pipeline
priority: P2
depends_on: [production-ml-pipeline-05]
tags: [implementation, ux, tests-required]
objective:
- Allow users to optionally specify which plant they're diagnosing before identification
- Boost predictions for the selected plant's diseases (multiply confidence by plant-context factor)
- Update the upload flow to include optional plant selection
- Improve prediction accuracy when plant context is known
deliverables:
- `src/app/api/identify/route.ts` — accept optional `plantId` parameter
- `src/lib/ml/plant-context.ts` — new module for plant-context scoring adjustment
- `src/components/PlantSelector.tsx` — new component for optional plant selection
- `src/app/upload/page.tsx` — integrate PlantSelector before upload
- `src/lib/api/identify.ts` — client API updated to pass plantId
steps:
1. **Create plant-context scoring module** `src/lib/ml/plant-context.ts`:
```typescript
import { PLANTVILLAGE_CLASSES } from "./plantvillage-classes";

/**
 * Adjust prediction scores based on plant context.
 * If plantId is provided, boost predictions for diseases of that plant.
 *
 * @param predictions - Top-K predictions with classIndex and probability
 * @param plantId - Optional plant ID from user selection
 * @param boostFactor - Multiplier for matching plant diseases (default 1.5)
 * @returns Adjusted predictions with updated probabilities
 */
export function applyPlantContext(
  predictions: Array<{ classIndex: number; probability: number }>,
  plantId: string | null,
  boostFactor: number = 1.5,
): Array<{ classIndex: number; probability: number; contextBoosted: boolean }> {
  if (!plantId) {
    return predictions.map((p) => ({ ...p, contextBoosted: false }));
  }

  // Find which class indices belong to this plant (healthy classes are excluded)
  const plantIndices = new Set(
    PLANTVILLAGE_CLASSES.filter((c) => c.plantId === plantId && !c.isHealthy).map(
      (c) => c.index,
    ),
  );

  return predictions.map((pred) => {
    const matchesPlant = plantIndices.has(pred.classIndex);
    return {
      classIndex: pred.classIndex,
      probability: matchesPlant
        ? Math.min(1.0, pred.probability * boostFactor)
        : pred.probability,
      contextBoosted: matchesPlant,
    };
  });
}
```
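Boosting changes the probabilities but leaves the array in the model's original order, so a boosted class can still sit below an unboosted one. If the results should list the highest adjusted score first (as the E2E test expects), a re-ranking pass is needed after the boost. The following is a minimal sketch under that assumption. `rerankByProbability` is a hypothetical helper, not one of the deliverables, and the sample values are illustrative only.

```typescript
// Shape returned by applyPlantContext in step 1.
interface AdjustedPrediction {
  classIndex: number;
  probability: number;
  contextBoosted: boolean;
}

/**
 * Hypothetical helper: return a new array sorted by probability, highest first.
 * Array.prototype.sort is stable, so ties keep their original model order.
 */
function rerankByProbability(
  predictions: AdjustedPrediction[],
): AdjustedPrediction[] {
  return [...predictions].sort((a, b) => b.probability - a.probability);
}

// Illustrative values only: class 7 was boosted from 0.30 to 0.45.
const adjustedSample: AdjustedPrediction[] = [
  { classIndex: 3, probability: 0.4, contextBoosted: false },
  { classIndex: 7, probability: 0.45, contextBoosted: true },
  { classIndex: 9, probability: 0.1, contextBoosted: false },
];

const ranked = rerankByProbability(adjustedSample);
console.log(ranked.map((p) => p.classIndex)); // [7, 3, 9]
```

Copying before sorting keeps the input untouched, which makes the helper safe to call on the array returned by `applyPlantContext`.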
2. **Update `/api/identify` route** to accept `plantId`:
```typescript
export async function POST(request: NextRequest) {
  const body = await request.json();
  const { imageId, plantId } = body; // plantId is optional

  // ... existing preprocessing ...
  const { probabilities, inferenceTimeMs } = await runInference(tensor);

  // Get top-K predictions
  const topK = getTopKFloat32(probabilities, 5);

  // Apply plant context if provided
  const adjusted = applyPlantContext(topK, plantId ?? null);

  // Enrich with KB data
  const predictions = await enrichPredictions(adjusted);

  return NextResponse.json({
    predictions,
    metadata: { model: MODEL_ID, inferenceTimeMs, imageId, plantContext: plantId ?? null },
  });
}
```
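Because `plantId` arrives in an untrusted JSON body, it may be worth normalizing it before it reaches `applyPlantContext`. The sketch below shows one way to do that. `parsePlantId` is a hypothetical helper introduced here for illustration, and treating blank strings as "no plant selected" is an assumption rather than a stated requirement.

```typescript
/**
 * Hypothetical helper: coerce an untrusted request field into a plant ID.
 * Returns null for anything that is not a non-empty string.
 */
function parsePlantId(raw: unknown): string | null {
  if (typeof raw !== "string") return null;
  const trimmed = raw.trim();
  return trimmed.length > 0 ? trimmed : null;
}

console.log(parsePlantId(" tomato ")); // "tomato"
console.log(parsePlantId(42)); // null
console.log(parsePlantId("")); // null
```

In the route this could replace the `plantId ?? null` expression, for example `const plantId = parsePlantId(body.plantId);`.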
3. **Update `IdentifyRequest` type**:
```typescript
// src/lib/types.ts
export interface IdentifyRequest {
  imageId: string;
  plantId?: string; // Optional plant context
}
```
4. **Create `PlantSelector` component** `src/components/PlantSelector.tsx`:
```tsx
"use client";

import { useState, useEffect } from "react";

interface Plant {
  id: string;
  commonName: string;
  imageUrl?: string;
}

export default function PlantSelector({
  value,
  onChange,
}: {
  value: string | null;
  onChange: (plantId: string | null) => void;
}) {
  const [plants, setPlants] = useState<Plant[]>([]);
  const [search, setSearch] = useState("");

  // Load the plant list once on mount
  useEffect(() => {
    fetch("/api/plants?limit=50")
      .then((r) => r.json())
      .then((data) => setPlants(data.items ?? []))
      .catch(() => setPlants([])); // leave the list empty if the request fails
  }, []);

  // Case-insensitive match on the common name
  const filtered = plants.filter((p) =>
    p.commonName.toLowerCase().includes(search.toLowerCase()),
  );

  return (
    <div className="...">
      <label>Plant (optional)</label>
      <input
        type="text"
        placeholder="Search plants..."
        value={search}
        onChange={(e) => setSearch(e.target.value)}
      />
      {value && (
        <div className="...">
          Selected: {plants.find((p) => p.id === value)?.commonName}
          <button type="button" onClick={() => onChange(null)}>
            Clear
          </button>
        </div>
      )}
      <ul>
        {filtered.slice(0, 10).map((plant) => (
          <li key={plant.id} onClick={() => onChange(plant.id)}>
            {plant.commonName}
          </li>
        ))}
      </ul>
    </div>
  );
}
```
5. **Update upload page** to include plant selector:
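```tsx
// src/app/upload/page.tsx
"use client";

import { useCallback, useState } from "react";
import { useRouter } from "next/navigation";
import PlantSelector from "../../components/PlantSelector";
import { identifyPlant } from "../../lib/api/identify";
// ... existing imports (ImageUpload, uploadImage) ...

export default function UploadPage() {
  const router = useRouter();
  const [selectedPlant, setSelectedPlant] = useState<string | null>(null);

  const handleUpload = useCallback(
    async (file: File) => {
      // 1. Upload image
      const uploadResponse = await uploadImage(file);

      // 2. Identify with plant context (null means no plant was selected)
      await identifyPlant(uploadResponse.imageId, selectedPlant ?? undefined);

      // 3. Navigate to results, carrying the plant in the URL for step 9
      const query = selectedPlant ? `?plant=${encodeURIComponent(selectedPlant)}` : "";
      router.push(`/results/${uploadResponse.imageId}${query}`);
    },
    [selectedPlant, router],
  );

  return (
    <div>
      <PlantSelector value={selectedPlant} onChange={setSelectedPlant} />
      <ImageUpload onUpload={handleUpload} />
    </div>
  );
}
```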
6. **Update client-side API** to pass plantId:
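```typescript
// src/lib/api/identify.ts
export async function identifyPlant(
  imageId: string,
  plantId?: string | null, // null is accepted so callers can pass searchParams.get() directly
): Promise<IdentifyResponse> {
  const body: IdentifyRequest = { imageId };
  if (plantId) body.plantId = plantId;

  const response = await fetch("/api/identify", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });

  // Surface HTTP errors instead of parsing an error payload as a result
  if (!response.ok) {
    throw new Error(`Identify request failed with status ${response.status}`);
  }
  return response.json();
}
```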
7. **Update `PredictionResult` type** to include context boost info:
```typescript
export interface PredictionResult {
  diseaseId: string;
  disease: Disease;
  confidence: ConfidenceResult;
  lookalikes: string[];
  plant: Plant | null;
  contextBoosted?: boolean; // true if boosted by plant context
}
```
8. **Update `ResultsDashboard`** to show context boost indicator:
```tsx
{prediction.contextBoosted && (
  <span className="text-xs text-leaf-green-600">✓ Matches selected plant</span>
)}
```
9. **Store plant context in results page** — pass plantId through URL or state:
```typescript
// src/app/results/[imageId]/page.tsx
const plantId = searchParams.get("plant"); // optional
const response = await identifyPlant(imageId, plantId);
```
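If the URL route is chosen for carrying plant context, the results link needs an optional `plant` query parameter that matches the `searchParams.get("plant")` read in step 9. A minimal sketch follows. `buildResultsUrl` is a hypothetical helper and the image ID is a made-up example.

```typescript
/**
 * Hypothetical helper: build the results URL, adding ?plant=... only when
 * a plant was selected. Both parts are URL-encoded.
 */
function buildResultsUrl(imageId: string, plantId: string | null): string {
  const path = `/results/${encodeURIComponent(imageId)}`;
  if (!plantId) return path;
  const query = new URLSearchParams({ plant: plantId });
  return `${path}?${query.toString()}`;
}

console.log(buildResultsUrl("img_123", "tomato")); // "/results/img_123?plant=tomato"
console.log(buildResultsUrl("img_123", null)); // "/results/img_123"
```

Omitting the parameter entirely when no plant is selected means `searchParams.get("plant")` returns `null` on the results page, which the identification call can treat as "no context".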
tests:
- Unit: `applyPlantContext()` with no plantId returns predictions unchanged
- Unit: `applyPlantContext()` with plantId="tomato" boosts tomato disease predictions
- Unit: boosted probabilities are capped at 1.0
- Unit: non-matching plant predictions are unchanged
- Unit: `contextBoosted` flag is set correctly
- Integration: POST `/api/identify` with plantId returns boosted predictions
- Integration: POST `/api/identify` without plantId returns normal predictions
- E2E: select "Tomato" in UI → upload tomato leaf → tomato diseases appear first
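The unit cases above can be sketched without the real class table. The block below uses a condensed stand-in for `applyPlantContext` and a hypothetical three-class table in place of `PLANTVILLAGE_CLASSES`, with plain checks instead of vitest so it runs on its own. The real tests would import the module from step 1 and use `describe`/`it`/`expect`.

```typescript
interface PlantClass { index: number; plantId: string; isHealthy: boolean }
interface Prediction { classIndex: number; probability: number }

// Hypothetical three-class table standing in for PLANTVILLAGE_CLASSES.
const CLASSES: PlantClass[] = [
  { index: 0, plantId: "tomato", isHealthy: false },
  { index: 1, plantId: "tomato", isHealthy: true },
  { index: 2, plantId: "potato", isHealthy: false },
];

// Condensed stand-in for applyPlantContext, following the same rules as step 1.
function applyPlantContextStub(preds: Prediction[], plantId: string | null, boost = 1.5) {
  const matching = new Set(
    CLASSES.filter((c) => plantId !== null && c.plantId === plantId && !c.isHealthy)
      .map((c) => c.index),
  );
  return preds.map((p) => {
    const hit = matching.has(p.classIndex);
    return {
      ...p,
      probability: hit ? Math.min(1.0, p.probability * boost) : p.probability,
      contextBoosted: hit,
    };
  });
}

function check(condition: boolean, label: string): void {
  if (!condition) throw new Error(`failed: ${label}`);
}

const input: Prediction[] = [
  { classIndex: 0, probability: 0.8 },
  { classIndex: 1, probability: 0.1 },
  { classIndex: 2, probability: 0.1 },
];
const boosted = applyPlantContextStub(input, "tomato");
const unboosted = applyPlantContextStub(input, null);

// 0.8 * 1.5 exceeds 1.0, so the matching disease is capped at 1.0.
check(boosted.map((p) => p.probability).join(",") === "1,0.1,0.1", "cap and non-matching unchanged");
// Only the tomato disease is flagged; the healthy tomato class is not.
check(boosted.map((p) => p.contextBoosted).join(",") === "true,false,false", "flags");
// Without a plantId nothing changes and nothing is flagged.
check(unboosted.every((p, i) => p.probability === input[i]?.probability && !p.contextBoosted), "no plantId");
```

Note that the healthy class for the selected plant is deliberately left unboosted, matching the `!c.isHealthy` filter in step 1.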
acceptance_criteria:
- Plant context is optional — identification works without it
- When plantId is provided, predictions for that plant's diseases are boosted by 1.5x
- Boosted probabilities are capped at 1.0
- `contextBoosted` flag is set on boosted predictions
- UI shows "Matches selected plant" indicator on boosted predictions
- Plant selector component works (search, select, clear)
- Upload flow includes optional plant selection step
- Results page receives and displays plant context
validation:
- `npx vitest run src/lib/ml/plant-context.test.ts`
- `npx vitest run src/components/PlantSelector.test.tsx`
- Manual: select "Tomato" → upload image → tomato diseases appear with boost indicator
- Manual: don't select plant → upload image → normal predictions (no boost)
- Check API response: `predictions[0].contextBoosted` is true when plant matches
notes:
- Plant context is a scoring heuristic, not a hard filter. It boosts confidence but doesn't exclude other predictions.
- The default boost factor is 1.5 — this can be tuned based on user feedback.
- Plant selector is optional — users can skip it and get unboosted predictions.
- The plant context feature is most useful when the user knows what plant they're diagnosing but the model is uncertain between multiple diseases.
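- PlantVillage covers 38 classes across 14 crop species, so each plant maps to only a few classes (tomato has the most, with nine diseases plus healthy). The boost is therefore specific enough to be useful without being overly restrictive.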
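
When tuning the boost factor, it helps to see when a boosted prediction outranks an unboosted one: a matching class with probability p overtakes a non-matching class with probability q only if min(1.0, p × boostFactor) > q. The sketch below works through that arithmetic. `overtakes` is a hypothetical helper and the probabilities are illustrative, not measured.

```typescript
/**
 * Hypothetical helper: does a boosted matching prediction outrank a
 * non-matching one? Mirrors the cap at 1.0 used by applyPlantContext.
 */
function overtakes(pMatch: number, pOther: number, boostFactor = 1.5): boolean {
  return Math.min(1.0, pMatch * boostFactor) > pOther;
}

console.log(overtakes(0.3, 0.4)); // true: about 0.45 beats 0.40
console.log(overtakes(0.3, 0.5)); // false: about 0.45 does not beat 0.50
console.log(overtakes(0.3, 0.5, 2.0)); // true: 0.60 beats 0.50
```

With the default factor of 1.5, plant context settles close calls but cannot overturn a prediction the model favors by a wide margin, which is consistent with treating it as a heuristic and not a hard filter.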