285 lines
8.9 KiB
Markdown
285 lines
8.9 KiB
Markdown
# 06. Plant-Context-Aware Identification
|
||
|
||
meta:
|
||
id: production-ml-pipeline-06
|
||
feature: production-ml-pipeline
|
||
priority: P2
|
||
depends_on: [production-ml-pipeline-05]
|
||
tags: [implementation, ux, tests-required]
|
||
|
||
objective:
|
||
|
||
- Allow users to optionally specify which plant they're diagnosing before identification
|
||
- Boost predictions for the selected plant's diseases (multiply confidence by plant-context factor)
|
||
- Update the upload flow to include optional plant selection
|
||
- Improve prediction accuracy when plant context is known
|
||
|
||
deliverables:
|
||
|
||
- `src/app/api/identify/route.ts` — accept optional `plantId` parameter
|
||
- `src/lib/ml/plant-context.ts` — new module for plant-context scoring adjustment
|
||
- `src/components/PlantSelector.tsx` — new component for optional plant selection
|
||
- `src/app/upload/page.tsx` — integrate PlantSelector before upload
|
||
- `src/lib/api/identify.ts` — client API updated to pass plantId
|
||
|
||
steps:
|
||
|
||
1. **Create plant-context scoring module** `src/lib/ml/plant-context.ts`:
|
||
|
||
```typescript
|
||
import { PLANTVILLAGE_CLASSES } from "./plantvillage-classes";
|
||
|
||
/**
 * Adjust prediction scores based on plant context.
 *
 * When `plantId` is provided, the probability of every prediction whose class
 * is a (non-healthy) disease of that plant is multiplied by `boostFactor` and
 * capped at 1.0. The result is then re-ranked by adjusted probability so that
 * boosted predictions can actually move ahead of unboosted ones; ties keep
 * their original relative order.
 *
 * When `plantId` is null or empty, probabilities and order are left untouched
 * and every prediction is flagged `contextBoosted: false`.
 *
 * The input array and its elements are never mutated.
 *
 * @param predictions - Top-K predictions with classIndex and probability
 * @param plantId - Optional plant ID from user selection
 * @param boostFactor - Multiplier for matching plant diseases (default 1.5)
 * @returns Adjusted predictions with updated probabilities, highest first when a plant is given
 * @throws RangeError if `boostFactor` is negative, NaN, or infinite
 */
export function applyPlantContext(
  predictions: Array<{ classIndex: number; probability: number }>,
  plantId: string | null,
  boostFactor: number = 1.5,
): Array<{ classIndex: number; probability: number; contextBoosted: boolean }> {
  if (!plantId) {
    return predictions.map((p) => ({ ...p, contextBoosted: false }));
  }

  // A NaN or negative factor would silently yield NaN/negative probabilities.
  if (!Number.isFinite(boostFactor) || boostFactor < 0) {
    throw new RangeError(
      `boostFactor must be a finite, non-negative number (got ${boostFactor})`,
    );
  }

  // Find which class indices belong to this plant (healthy classes are not boosted)
  const plantIndices = new Set(
    PLANTVILLAGE_CLASSES.filter((c) => c.plantId === plantId && !c.isHealthy).map(
      (c) => c.index,
    ),
  );

  const adjusted = predictions.map((pred) => {
    const matchesPlant = plantIndices.has(pred.classIndex);
    return {
      ...pred,
      probability: matchesPlant
        ? Math.min(1.0, pred.probability * boostFactor)
        : pred.probability,
      contextBoosted: matchesPlant,
    };
  });

  // `adjusted` is a fresh array, so sorting in place is safe. Array#sort is
  // stable, which keeps the model's original ranking among equal scores
  // (e.g. several predictions capped at 1.0).
  return adjusted.sort((a, b) => b.probability - a.probability);
}
|
||
```
|
||
|
||
2. **Update `/api/identify` route** to accept `plantId`:
|
||
|
||
```typescript
|
||
/**
 * POST /api/identify — run inference on an uploaded image and return the
 * top predictions, optionally re-scored with the user's plant selection.
 *
 * The request body carries `imageId` and an optional `plantId`. The response
 * echoes the plant context actually applied in `metadata.plantContext`
 * (`null` when none was supplied or the supplied value was unusable).
 */
export async function POST(request: NextRequest) {
  const body = await request.json();
  const { imageId } = body;

  // plantId is optional and comes from untrusted JSON: only a non-empty
  // string is treated as plant context, anything else means "no context".
  const plantId: string | null =
    typeof body?.plantId === "string" && body.plantId.length > 0
      ? body.plantId
      : null;

  // ... existing preprocessing ...

  const { probabilities, inferenceTimeMs } = await runInference(tensor);

  // Get top-K predictions
  const topK = getTopKFloat32(probabilities, 5);

  // Apply plant context if provided
  const adjusted = applyPlantContext(topK, plantId);

  // Enrich with KB data
  const predictions = await enrichPredictions(adjusted);

  return NextResponse.json({
    predictions,
    metadata: { model: MODEL_ID, inferenceTimeMs, imageId, plantContext: plantId },
  });
}
|
||
```
|
||
|
||
3. **Update `IdentifyRequest` type**:
|
||
|
||
```typescript
|
||
// src/lib/types.ts
|
||
/** JSON body sent to `POST /api/identify`. */
export interface IdentifyRequest {
  /** ID of the uploaded image to identify. */
  imageId: string;
  /** Optional plant context; leave unset when the user picked no plant. */
  plantId?: string;
}
|
||
```
|
||
|
||
4. **Create `PlantSelector` component** `src/components/PlantSelector.tsx`:
|
||
|
||
```tsx
|
||
"use client";
|
||
|
||
import { useState, useEffect } from "react";
|
||
|
||
/** Minimal plant record as read from the `/api/plants` response items. */
interface Plant {
  id: string;
  commonName: string;
  imageUrl?: string;
}

/**
 * Optional plant picker shown before upload.
 *
 * Loads up to 50 plants once on mount, filters them client-side by the search
 * text (case-insensitive match on the common name), and shows at most 10
 * matches. Selection is controlled by the parent via `value` / `onChange`;
 * `onChange(null)` clears it.
 *
 * Because the selector is optional, a failed plant fetch is logged and leaves
 * the list empty instead of breaking the upload flow.
 */
export default function PlantSelector({
  value,
  onChange,
}: {
  value: string | null;
  onChange: (plantId: string | null) => void;
}) {
  const [plants, setPlants] = useState<Plant[]>([]);
  const [search, setSearch] = useState("");

  useEffect(() => {
    // Abort the request on unmount so state is never set on a dead component.
    const controller = new AbortController();

    fetch("/api/plants?limit=50", { signal: controller.signal })
      .then((r) => {
        if (!r.ok) {
          throw new Error(`Failed to load plants (status ${r.status})`);
        }
        return r.json();
      })
      .then((data) => setPlants(data.items ?? []))
      .catch((err: unknown) => {
        // An abort is the expected result of the cleanup below, not a failure.
        if (err instanceof DOMException && err.name === "AbortError") return;
        console.error(err);
      });

    return () => controller.abort();
  }, []);

  // Lower-case the query once instead of once per plant.
  const query = search.toLowerCase();
  const filtered = plants.filter((p) =>
    p.commonName.toLowerCase().includes(query),
  );

  // Fall back to the raw id while the plant list is still loading.
  const selectedName = value
    ? plants.find((p) => p.id === value)?.commonName ?? value
    : null;

  return (
    <div className="...">
      <label htmlFor="plant-selector-search">Plant (optional)</label>
      <input
        id="plant-selector-search"
        type="text"
        placeholder="Search plants..."
        value={search}
        onChange={(e) => setSearch(e.target.value)}
      />
      {value && (
        <div className="...">
          Selected: {selectedName}
          {/* type="button" keeps this from submitting an enclosing form */}
          <button type="button" onClick={() => onChange(null)}>
            Clear
          </button>
        </div>
      )}
      <ul>
        {filtered.slice(0, 10).map((plant) => (
          <li key={plant.id}>
            {/* A real button makes each option focusable and keyboard-activatable */}
            <button type="button" onClick={() => onChange(plant.id)}>
              {plant.commonName}
            </button>
          </li>
        ))}
      </ul>
    </div>
  );
}
|
||
```
|
||
|
||
5. **Update upload page** to include plant selector:
|
||
|
||
```tsx
|
||
// src/app/upload/page.tsx
|
||
/**
 * Upload page: lets the user optionally pick a plant, then uploads the image,
 * runs identification with that plant context, and navigates to the results
 * page. The selected plant is forwarded as a `plant` query parameter so the
 * results page can read it back.
 */
export default function UploadPage() {
  const [selectedPlant, setSelectedPlant] = useState<string | null>(null);

  const handleUpload = useCallback(
    async (file: File) => {
      // 1. Upload image
      const uploadResponse = await uploadImage(file);

      // 2. Identify with plant context (null selection is sent as "no plant")
      await identifyPlant(uploadResponse.imageId, selectedPlant ?? undefined);

      // 3. Navigate to results, carrying the plant context in the URL
      const query = selectedPlant
        ? `?plant=${encodeURIComponent(selectedPlant)}`
        : "";
      router.push(`/results/${uploadResponse.imageId}${query}`);
    },
    // NOTE(review): `router` is not defined in this snippet; it is assumed to
    // come from the page's existing router hook, so it is listed as a dependency.
    [selectedPlant, router],
  );

  return (
    <div>
      <PlantSelector value={selectedPlant} onChange={setSelectedPlant} />
      <ImageUpload onUpload={handleUpload} />
    </div>
  );
}
|
||
```
|
||
|
||
6. **Update client-side API** to pass plantId:
|
||
|
||
```typescript
|
||
// src/lib/api/identify.ts
|
||
/**
 * Request disease identification for a previously uploaded image.
 *
 * @param imageId - ID of the uploaded image
 * @param plantId - Optional plant context. `undefined`, `null`, or an empty
 *   string all mean "no plant selected" and the field is omitted from the
 *   request body.
 * @returns The parsed JSON body of the `/api/identify` response
 * @throws Error if the server answers with a non-2xx status
 */
export async function identifyPlant(
  imageId: string,
  plantId?: string | null,
): Promise<IdentifyResponse> {
  const body: IdentifyRequest = { imageId };
  if (plantId) body.plantId = plantId;

  const response = await fetch("/api/identify", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });

  // Without this check an error payload would be returned as if it were a
  // valid IdentifyResponse.
  if (!response.ok) {
    throw new Error(`Identify request failed with status ${response.status}`);
  }

  return response.json();
}
|
||
```
|
||
|
||
7. **Update `PredictionResult` type** to include context boost info:
|
||
|
||
```typescript
|
||
/** One enriched prediction as returned to the client. */
export interface PredictionResult {
  /** Identifier of the predicted disease. */
  diseaseId: string;
  /** Disease record for this prediction. */
  disease: Disease;
  /** Confidence information for this prediction. */
  confidence: ConfidenceResult;
  /** Look-alike entries for this disease (presumably disease IDs — confirm against the enrichment code). */
  lookalikes: string[];
  /** Plant associated with the prediction, or null when there is none. */
  plant: Plant | null;
  /** True if this prediction was boosted by plant context. */
  contextBoosted?: boolean;
}
|
||
```
|
||
|
||
8. **Update `ResultsDashboard`** to show context boost indicator:
|
||
|
||
```tsx
|
||
{/* Shown only for predictions that were boosted by the selected plant */}
{prediction.contextBoosted && (
  <span className="text-xs text-leaf-green-600">✓ Matches selected plant</span>
)}
|
||
```
|
||
|
||
9. **Store plant context in results page** — pass plantId through URL or state:
|
||
```typescript
|
||
// src/app/results/[imageId]/page.tsx
|
||
// `get` returns null when the `plant` query parameter is absent.
const plantId = searchParams.get("plant"); // optional

// Map a missing parameter to "no plant context".
const response = await identifyPlant(imageId, plantId ?? undefined);
|
||
```
|
||
|
||
tests:
|
||
|
||
- Unit: `applyPlantContext()` with no plantId returns predictions with probabilities and order unchanged and `contextBoosted: false` on every entry
|
||
- Unit: `applyPlantContext()` with plantId="tomato" boosts tomato disease predictions
|
||
- Unit: boosted probabilities are capped at 1.0
|
||
- Unit: non-matching plant predictions are unchanged
|
||
- Unit: `contextBoosted` flag is set correctly
|
||
- Integration: POST `/api/identify` with plantId returns boosted predictions
|
||
- Integration: POST `/api/identify` without plantId returns normal predictions
|
||
- E2E: select "Tomato" in UI → upload tomato leaf → tomato diseases appear first
|
||
|
||
acceptance_criteria:
|
||
|
||
- Plant context is optional — identification works without it
|
||
- When plantId is provided, predictions for that plant's diseases are boosted by 1.5x
|
||
- Boosted probabilities are capped at 1.0
|
||
- `contextBoosted` flag is set on boosted predictions
|
||
- UI shows "Matches selected plant" indicator on boosted predictions
|
||
- Plant selector component works (search, select, clear)
|
||
- Upload flow includes optional plant selection step
|
||
- Results page receives and displays plant context
|
||
|
||
validation:
|
||
|
||
- `npx vitest run src/lib/ml/plant-context.test.ts`
|
||
- `npx vitest run src/components/PlantSelector.test.tsx`
|
||
- Manual: select "Tomato" → upload image → tomato diseases appear with boost indicator
|
||
- Manual: don't select plant → upload image → normal predictions (no boost)
|
||
- Check API response: `predictions[0].contextBoosted` is true when plant matches
|
||
|
||
notes:
|
||
|
||
- Plant context is a scoring heuristic, not a hard filter. It boosts confidence but doesn't exclude other predictions.
|
||
- The default boost factor is 1.5 — this can be tuned based on user feedback.
|
||
- Plant selector is optional — users can skip it and get unboosted predictions.
|
||
- The plant context feature is most useful when the user knows what plant they're diagnosing but the model is uncertain between multiple diseases.
|
||
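- For PlantVillage, each plant has at most 9 disease classes, so the boost is specific enough to be useful without being overly restrictive. Healthy classes are never boosted, so plants that only have a healthy class in the dataset (e.g. blueberry, raspberry, soybean) have nothing to boost and selecting them leaves predictions unchanged.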
|