beepboop
This commit is contained in:
@@ -0,0 +1,284 @@
|
||||
# 06. Plant-Context-Aware Identification
|
||||
|
||||
meta:
|
||||
id: production-ml-pipeline-06
|
||||
feature: production-ml-pipeline
|
||||
priority: P2
|
||||
depends_on: [production-ml-pipeline-05]
|
||||
tags: [implementation, ux, tests-required]
|
||||
|
||||
objective:
|
||||
|
||||
- Allow users to optionally specify which plant they're diagnosing before identification
|
||||
- Boost predictions for the selected plant's diseases (multiply confidence by plant-context factor)
|
||||
- Update the upload flow to include optional plant selection
|
||||
- Improve prediction accuracy when plant context is known
|
||||
|
||||
deliverables:
|
||||
|
||||
- `src/app/api/identify/route.ts` — accept optional `plantId` parameter
|
||||
- `src/lib/ml/plant-context.ts` — new module for plant-context scoring adjustment
|
||||
- `src/components/PlantSelector.tsx` — new component for optional plant selection
|
||||
- `src/app/upload/page.tsx` — integrate PlantSelector before upload
|
||||
- `src/lib/api/identify.ts` — client API updated to pass plantId
|
||||
|
||||
steps:
|
||||
|
||||
1. **Create plant-context scoring module** `src/lib/ml/plant-context.ts`:
|
||||
|
||||
```typescript
|
||||
import { PLANTVILLAGE_CLASSES } from "./plantvillage-classes";
|
||||
|
||||
/**
|
||||
* Adjust prediction scores based on plant context.
|
||||
* If plantId is provided, boost predictions for diseases of that plant.
|
||||
*
|
||||
* @param predictions - Top-K predictions with classIndex and probability
|
||||
* @param plantId - Optional plant ID from user selection
|
||||
* @param boostFactor - Multiplier for matching plant diseases (default 1.5)
|
||||
* @returns Adjusted predictions with updated probabilities
|
||||
*/
|
||||
export function applyPlantContext(
|
||||
predictions: Array<{ classIndex: number; probability: number }>,
|
||||
plantId: string | null,
|
||||
boostFactor: number = 1.5,
|
||||
): Array<{ classIndex: number; probability: number; contextBoosted: boolean }> {
|
||||
if (!plantId) {
|
||||
return predictions.map((p) => ({ ...p, contextBoosted: false }));
|
||||
}
|
||||
|
||||
// Find which class indices belong to this plant
|
||||
const plantIndices = new Set(
|
||||
PLANTVILLAGE_CLASSES.filter((c) => c.plantId === plantId && !c.isHealthy).map(
|
||||
(c) => c.index,
|
||||
),
|
||||
);
|
||||
|
||||
return predictions.map((pred) => {
|
||||
const matchesPlant = plantIndices.has(pred.classIndex);
|
||||
return {
|
||||
classIndex: pred.classIndex,
|
||||
probability: matchesPlant
|
||||
? Math.min(1.0, pred.probability * boostFactor)
|
||||
: pred.probability,
|
||||
contextBoosted: matchesPlant,
|
||||
};
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
2. **Update `/api/identify` route** to accept `plantId`:
|
||||
|
||||
```typescript
|
||||
export async function POST(request: NextRequest) {
|
||||
const body = await request.json();
|
||||
const { imageId, plantId } = body; // plantId is optional
|
||||
|
||||
// ... existing preprocessing ...
|
||||
|
||||
const { probabilities, inferenceTimeMs } = await runInference(tensor);
|
||||
|
||||
// Get top-K predictions
|
||||
const topK = getTopKFloat32(probabilities, 5);
|
||||
|
||||
// Apply plant context if provided
|
||||
const adjusted = applyPlantContext(topK, plantId ?? null);
|
||||
|
||||
// Enrich with KB data
|
||||
const predictions = await enrichPredictions(adjusted);
|
||||
|
||||
return NextResponse.json({
|
||||
predictions,
|
||||
metadata: { model: MODEL_ID, inferenceTimeMs, imageId, plantContext: plantId ?? null },
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
3. **Update `IdentifyRequest` type**:
|
||||
|
||||
```typescript
|
||||
// src/lib/types.ts
|
||||
export interface IdentifyRequest {
|
||||
imageId: string;
|
||||
plantId?: string; // Optional plant context
|
||||
}
|
||||
```
|
||||
|
||||
4. **Create `PlantSelector` component** `src/components/PlantSelector.tsx`:
|
||||
|
||||
```tsx
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
interface Plant {
|
||||
id: string;
|
||||
commonName: string;
|
||||
imageUrl?: string;
|
||||
}
|
||||
|
||||
export default function PlantSelector({
|
||||
value,
|
||||
onChange,
|
||||
}: {
|
||||
value: string | null;
|
||||
onChange: (plantId: string | null) => void;
|
||||
}) {
|
||||
const [plants, setPlants] = useState<Plant[]>([]);
|
||||
const [search, setSearch] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
fetch("/api/plants?limit=50")
|
||||
.then((r) => r.json())
|
||||
.then((data) => setPlants(data.items ?? []));
|
||||
}, []);
|
||||
|
||||
const filtered = plants.filter((p) =>
|
||||
p.commonName.toLowerCase().includes(search.toLowerCase()),
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="...">
|
||||
<label>Plant (optional)</label>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search plants..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
/>
|
||||
{value && (
|
||||
<div className="...">
|
||||
Selected: {plants.find((p) => p.id === value)?.commonName}
|
||||
<button onClick={() => onChange(null)}>Clear</button>
|
||||
</div>
|
||||
)}
|
||||
<ul>
|
||||
{filtered.slice(0, 10).map((plant) => (
|
||||
<li key={plant.id} onClick={() => onChange(plant.id)}>
|
||||
{plant.commonName}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
5. **Update upload page** to include plant selector:
|
||||
|
||||
```tsx
|
||||
// src/app/upload/page.tsx
|
||||
export default function UploadPage() {
|
||||
const [selectedPlant, setSelectedPlant] = useState<string | null>(null);
|
||||
|
||||
const handleUpload = useCallback(
|
||||
async (file: File) => {
|
||||
// 1. Upload image
|
||||
const uploadResponse = await uploadImage(file);
|
||||
|
||||
// 2. Identify with plant context
|
||||
const identifyResponse = await identifyPlant(uploadResponse.imageId, selectedPlant);
|
||||
|
||||
// 3. Navigate to results
|
||||
router.push(`/results/${uploadResponse.imageId}`);
|
||||
},
|
||||
[selectedPlant],
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<PlantSelector value={selectedPlant} onChange={setSelectedPlant} />
|
||||
<ImageUpload onUpload={handleUpload} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
6. **Update client-side API** to pass plantId:
|
||||
|
||||
```typescript
|
||||
// src/lib/api/identify.ts
|
||||
export async function identifyPlant(
|
||||
imageId: string,
|
||||
plantId?: string,
|
||||
): Promise<IdentifyResponse> {
|
||||
const body: IdentifyRequest = { imageId };
|
||||
if (plantId) body.plantId = plantId;
|
||||
|
||||
const response = await fetch("/api/identify", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
```
|
||||
|
||||
7. **Update `PredictionResult` type** to include context boost info:
|
||||
|
||||
```typescript
|
||||
export interface PredictionResult {
|
||||
diseaseId: string;
|
||||
disease: Disease;
|
||||
confidence: ConfidenceResult;
|
||||
lookalikes: string[];
|
||||
plant: Plant | null;
|
||||
contextBoosted?: boolean; // true if boosted by plant context
|
||||
}
|
||||
```
|
||||
|
||||
8. **Update `ResultsDashboard`** to show context boost indicator:
|
||||
|
||||
```tsx
|
||||
{
|
||||
prediction.contextBoosted && (
|
||||
<span className="text-xs text-leaf-green-600">✓ Matches selected plant</span>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
9. **Store plant context in results page** — pass plantId through URL or state:
|
||||
```typescript
|
||||
// src/app/results/[imageId]/page.tsx
|
||||
const plantId = searchParams.get("plant"); // optional
|
||||
const response = await identifyPlant(imageId, plantId);
|
||||
```
|
||||
|
||||
tests:
|
||||
|
||||
- Unit: `applyPlantContext()` with no plantId returns predictions unchanged
|
||||
- Unit: `applyPlantContext()` with plantId="tomato" boosts tomato disease predictions
|
||||
- Unit: boosted probabilities are capped at 1.0
|
||||
- Unit: non-matching plant predictions are unchanged
|
||||
- Unit: `contextBoosted` flag is set correctly
|
||||
- Integration: POST `/api/identify` with plantId returns boosted predictions
|
||||
- Integration: POST `/api/identify` without plantId returns normal predictions
|
||||
- E2E: select "Tomato" in UI → upload tomato leaf → tomato diseases appear first
|
||||
|
||||
acceptance_criteria:
|
||||
|
||||
- Plant context is optional — identification works without it
|
||||
- When plantId is provided, predictions for that plant's diseases are boosted by 1.5x
|
||||
- Boosted probabilities are capped at 1.0
|
||||
- `contextBoosted` flag is set on boosted predictions
|
||||
- UI shows "Matches selected plant" indicator on boosted predictions
|
||||
- Plant selector component works (search, select, clear)
|
||||
- Upload flow includes optional plant selection step
|
||||
- Results page receives and displays plant context
|
||||
|
||||
validation:
|
||||
|
||||
- `npx vitest run src/lib/ml/plant-context.test.ts`
|
||||
- `npx vitest run src/components/PlantSelector.test.tsx`
|
||||
- Manual: select "Tomato" → upload image → tomato diseases appear with boost indicator
|
||||
- Manual: don't select plant → upload image → normal predictions (no boost)
|
||||
- Check API response: `predictions[0].contextBoosted` is true when plant matches
|
||||
|
||||
notes:
|
||||
|
||||
- Plant context is a scoring heuristic, not a hard filter. It boosts confidence but doesn't exclude other predictions.
|
||||
- The default boost factor is 1.5 — this can be tuned based on user feedback.
|
||||
- Plant selector is optional — users can skip it and get unboosted predictions.
|
||||
- The plant context feature is most useful when the user knows what plant they're diagnosing but the model is uncertain between multiple diseases.
|
||||
- For PlantVillage, each plant has 1–9 diseases, so the boost is specific enough to be useful without being overly restrictive.
|
||||
Reference in New Issue
Block a user