plant-disease-id/apps/web/tasks/production-ml-pipeline/06-plant-context-identification.md
2026-06-06 15:09:46 -04:00

06. Plant-Context-Aware Identification

meta:

  • id: production-ml-pipeline-06
  • feature: production-ml-pipeline
  • priority: P2
  • depends_on: [production-ml-pipeline-05]
  • tags: [implementation, ux, tests-required]

objective:

  • Allow users to optionally specify which plant they're diagnosing before identification
  • Boost predictions for the selected plant's diseases by multiplying their probability by a plant-context factor (default 1.5, capped at 1.0)
  • Update the upload flow to include optional plant selection
  • Improve prediction accuracy when plant context is known

deliverables:

  • src/app/api/identify/route.ts — accept optional plantId parameter
  • src/lib/ml/plant-context.ts — new module for plant-context scoring adjustment
  • src/components/PlantSelector.tsx — new component for optional plant selection
  • src/app/upload/page.tsx — integrate PlantSelector before upload
  • src/lib/api/identify.ts — client API updated to pass plantId

steps:

  1. Create plant-context scoring module src/lib/ml/plant-context.ts:

    import { PLANTVILLAGE_CLASSES } from "./plantvillage-classes";
    
    /**
     * Adjust prediction scores based on plant context.
     * If plantId is provided, boost predictions for diseases of that plant.
     *
     * @param predictions - Top-K predictions with classIndex and probability
     * @param plantId - Optional plant ID from user selection
     * @param boostFactor - Multiplier for matching plant diseases (default 1.5)
     * @returns Adjusted predictions with updated probabilities
     */
    export function applyPlantContext(
      predictions: Array<{ classIndex: number; probability: number }>,
      plantId: string | null,
      boostFactor: number = 1.5,
    ): Array<{ classIndex: number; probability: number; contextBoosted: boolean }> {
      if (!plantId) {
        return predictions.map((p) => ({ ...p, contextBoosted: false }));
      }
    
      // Find which class indices belong to this plant
      const plantIndices = new Set(
        PLANTVILLAGE_CLASSES.filter((c) => c.plantId === plantId && !c.isHealthy).map(
          (c) => c.index,
        ),
      );
    
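      // Boost matching classes, then re-sort so boosted predictions can move up the ranking
      return predictions
        .map((pred) => {
          const matchesPlant = plantIndices.has(pred.classIndex);
          return {
            classIndex: pred.classIndex,
            probability: matchesPlant
              ? Math.min(1.0, pred.probability * boostFactor)
              : pred.probability,
            contextBoosted: matchesPlant,
          };
        })
        .sort((a, b) => b.probability - a.probability);
    }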
    
  2. Update /api/identify route to accept plantId:

    export async function POST(request: NextRequest) {
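      const body = await request.json();
      const { imageId } = body;
      // plantId is optional; treat anything that is not a non-empty string as "no plant"
      const plantId: string | null =
        typeof body.plantId === "string" && body.plantId !== "" ? body.plantId : null;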
    
      // ... existing preprocessing ...
    
      const { probabilities, inferenceTimeMs } = await runInference(tensor);
    
      // Get top-K predictions
      const topK = getTopKFloat32(probabilities, 5);
    
      // Apply plant context if provided
      const adjusted = applyPlantContext(topK, plantId ?? null);
    
      // Enrich with KB data
      const predictions = await enrichPredictions(adjusted);
    
      return NextResponse.json({
        predictions,
        metadata: { model: MODEL_ID, inferenceTimeMs, imageId, plantContext: plantId ?? null },
      });
    }
    
  3. Update IdentifyRequest type:

    // src/lib/types.ts
    export interface IdentifyRequest {
      imageId: string;
      plantId?: string; // Optional plant context
    }
    
  4. Create PlantSelector component src/components/PlantSelector.tsx:

    "use client";
    
    import { useState, useEffect } from "react";
    
    interface Plant {
      id: string;
      commonName: string;
      imageUrl?: string;
    }
    
    export default function PlantSelector({
      value,
      onChange,
    }: {
      value: string | null;
      onChange: (plantId: string | null) => void;
    }) {
      const [plants, setPlants] = useState<Plant[]>([]);
      const [search, setSearch] = useState("");
    
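      useEffect(() => {
        // Load the plant list once; fall back to an empty list if the request fails
        fetch("/api/plants?limit=50")
          .then((r) => (r.ok ? r.json() : { items: [] }))
          .then((data) => setPlants(data.items ?? []))
          .catch(() => setPlants([]));
      }, []);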
    
      const filtered = plants.filter((p) =>
        p.commonName.toLowerCase().includes(search.toLowerCase()),
      );
    
      return (
        <div className="...">
          <label>Plant (optional)</label>
          <input
            type="text"
            placeholder="Search plants..."
            value={search}
            onChange={(e) => setSearch(e.target.value)}
          />
          {value && (
            <div className="...">
              Selected: {plants.find((p) => p.id === value)?.commonName}
              <button onClick={() => onChange(null)}>Clear</button>
            </div>
          )}
          <ul>
            {filtered.slice(0, 10).map((plant) => (
              <li key={plant.id} onClick={() => onChange(plant.id)}>
                {plant.commonName}
              </li>
            ))}
          </ul>
        </div>
      );
    }
    
  5. Update upload page to include plant selector:

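    // src/app/upload/page.tsx
    "use client";
    
    import { useCallback, useState } from "react";
    import { useRouter } from "next/navigation";
    import PlantSelector from "../../components/PlantSelector";
    import { identifyPlant } from "../../lib/api/identify";
    // ... existing imports (ImageUpload, uploadImage) ...
    
    export default function UploadPage() {
      const router = useRouter();
      const [selectedPlant, setSelectedPlant] = useState<string | null>(null);
    
      const handleUpload = useCallback(
        async (file: File) => {
          // 1. Upload image
          const uploadResponse = await uploadImage(file);
    
          // 2. Identify with plant context (null means no plant was selected)
          await identifyPlant(uploadResponse.imageId, selectedPlant ?? undefined);
    
          // 3. Navigate to results, carrying the plant context in the URL (read back in step 9)
          const query = selectedPlant ? `?plant=${encodeURIComponent(selectedPlant)}` : "";
          router.push(`/results/${uploadResponse.imageId}${query}`);
        },
        [selectedPlant, router],
      );
    
      return (
        <div>
          <PlantSelector value={selectedPlant} onChange={setSelectedPlant} />
          <ImageUpload onUpload={handleUpload} />
        </div>
      );
    }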
    
  6. Update client-side API to pass plantId:

    // src/lib/api/identify.ts
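    export async function identifyPlant(
      imageId: string,
      plantId?: string | null, // null or undefined means no plant context
    ): Promise<IdentifyResponse> {
      const body: IdentifyRequest = { imageId };
      if (plantId) body.plantId = plantId;
    
      const response = await fetch("/api/identify", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(body),
      });
    
      // Surface HTTP errors instead of parsing an error payload as a result
      if (!response.ok) {
        throw new Error(`Identify request failed with status ${response.status}`);
      }
    
      return response.json();
    }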
    
  7. Update PredictionResult type to include context boost info:

    export interface PredictionResult {
      diseaseId: string;
      disease: Disease;
      confidence: ConfidenceResult;
      lookalikes: string[];
      plant: Plant | null;
      contextBoosted?: boolean; // true if boosted by plant context
    }
    
  8. Update ResultsDashboard to show context boost indicator:

    {prediction.contextBoosted && (
      <span className="text-xs text-leaf-green-600"> Matches selected plant</span>
    )}
    
  9. Store plant context in results page — pass plantId through URL or state:

    // src/app/results/[imageId]/page.tsx
    const plantId = searchParams.get("plant"); // null when no plant was selected
    const response = await identifyPlant(imageId, plantId ?? undefined);
    
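Steps 5 and 9 only work together if the upload page and the results page agree on how plantId travels in the URL. The following is a minimal sketch of one way to keep that in a single place, assuming the `plant` query parameter used in step 9. `buildResultsUrl` is a hypothetical helper, not something that exists in the codebase.

```typescript
// Hypothetical helper: builds the results URL and appends ?plant=... only
// when a plant was selected. The "plant" parameter name follows step 9.
function buildResultsUrl(imageId: string, plantId: string | null): string {
  const base = `/results/${encodeURIComponent(imageId)}`;
  return plantId ? `${base}?plant=${encodeURIComponent(plantId)}` : base;
}
```

The upload page could then call `router.push(buildResultsUrl(uploadResponse.imageId, selectedPlant))`, and plant IDs containing spaces or other reserved characters would still round-trip through the query string.
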

tests:

  • Unit: applyPlantContext() with no plantId returns probabilities unchanged and contextBoosted: false on every prediction
  • Unit: applyPlantContext() with plantId="tomato" boosts tomato disease predictions
  • Unit: boosted probabilities are capped at 1.0
  • Unit: non-matching plant predictions are unchanged
  • Unit: contextBoosted flag is set correctly
  • Integration: POST /api/identify with plantId returns boosted predictions
  • Integration: POST /api/identify without plantId returns normal predictions
  • E2E: select "Tomato" in UI → upload tomato leaf → tomato diseases appear first
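
The unit cases above can be sketched without any test framework. This is only an illustration under stated assumptions: the three-entry class table, the local copy of `applyPlantContext`, and the `check` and `byClass` helpers are all stand-ins so that the block runs on its own. In the real suite these would be Vitest assertions against the module from step 1. Probabilities are chosen as binary fractions so that exact equality is safe.

```typescript
// Stand-in class table; the real one is PLANTVILLAGE_CLASSES from step 1.
interface PlantClass {
  index: number;
  plantId: string;
  isHealthy: boolean;
}
interface Prediction {
  classIndex: number;
  probability: number;
}
interface AdjustedPrediction extends Prediction {
  contextBoosted: boolean;
}

const CLASSES: PlantClass[] = [
  { index: 0, plantId: "tomato", isHealthy: false },
  { index: 1, plantId: "tomato", isHealthy: true },
  { index: 2, plantId: "potato", isHealthy: false },
];

// Simplified local copy of the step 1 logic (ordering is not asserted here).
function applyPlantContext(
  predictions: Prediction[],
  plantId: string | null,
  boostFactor = 1.5,
): AdjustedPrediction[] {
  const plantIndices = new Set(
    CLASSES.filter((c) => plantId !== null && c.plantId === plantId && !c.isHealthy).map(
      (c) => c.index,
    ),
  );
  return predictions.map((p) => {
    const matches = plantIndices.has(p.classIndex);
    return {
      classIndex: p.classIndex,
      probability: matches ? Math.min(1.0, p.probability * boostFactor) : p.probability,
      contextBoosted: matches,
    };
  });
}

// Look up one adjusted prediction by class index; throws if it is missing.
function byClass(preds: AdjustedPrediction[], classIndex: number): AdjustedPrediction {
  const found = preds.find((p) => p.classIndex === classIndex);
  if (!found) throw new Error(`no prediction for class ${classIndex}`);
  return found;
}

function check(name: string, condition: boolean): void {
  if (!condition) throw new Error(`failed: ${name}`);
}

const input: Prediction[] = [
  { classIndex: 0, probability: 0.5 }, // tomato disease
  { classIndex: 1, probability: 0.25 }, // tomato healthy
  { classIndex: 2, probability: 0.25 }, // potato disease
];

const noContext = applyPlantContext(input, null);
check("no plantId: nothing is flagged", noContext.every((p) => !p.contextBoosted));
check("no plantId: probability unchanged", byClass(noContext, 0).probability === 0.5);

const tomato = applyPlantContext(input, "tomato");
check("tomato disease is boosted by 1.5", byClass(tomato, 0).probability === 0.75);
check("boosted prediction is flagged", byClass(tomato, 0).contextBoosted);
check("healthy class is not boosted", byClass(tomato, 1).probability === 0.25);
check("other plants are unchanged", byClass(tomato, 2).probability === 0.25);

const capped = applyPlantContext(input, "tomato", 4);
check("boosted probability is capped at 1.0", byClass(capped, 0).probability === 1.0);
```
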

acceptance_criteria:

  • Plant context is optional — identification works without it
  • When plantId is provided, predictions for that plant's diseases are boosted by 1.5x
  • Boosted probabilities are capped at 1.0
  • contextBoosted flag is set on boosted predictions
  • UI shows "Matches selected plant" indicator on boosted predictions
  • Plant selector component works (search, select, clear)
  • Upload flow includes optional plant selection step
  • Results page receives and displays plant context

validation:

  • npx vitest run src/lib/ml/plant-context.test.ts
  • npx vitest run src/components/PlantSelector.test.tsx
  • Manual: select "Tomato" → upload image → tomato diseases appear with boost indicator
  • Manual: don't select plant → upload image → normal predictions (no boost)
  • Check API response: predictions[0].contextBoosted is true when the top prediction belongs to the selected plant
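
The API response check can also be scripted. The sketch below assumes the response shape implied by steps 2 and 7 (a `predictions` array whose items carry `diseaseId` and an optional `contextBoosted`, plus `metadata.plantContext`). `summarizeContext` and the sample disease IDs are hypothetical and only illustrate the idea.

```typescript
// Response shape assumed from steps 2 and 7; only the fields used here are typed.
interface ApiPrediction {
  diseaseId: string;
  contextBoosted?: boolean;
}
interface IdentifyApiResponse {
  predictions: ApiPrediction[];
  metadata: { plantContext: string | null };
}

// Hypothetical helper: lists boosted predictions and flags the one case that
// is always wrong, namely boosted predictions when no plant context was sent.
function summarizeContext(res: IdentifyApiResponse): { ok: boolean; boostedIds: string[] } {
  const boostedIds = res.predictions
    .filter((p) => p.contextBoosted === true)
    .map((p) => p.diseaseId);
  const ok = !(res.metadata.plantContext === null && boostedIds.length > 0);
  return { ok, boostedIds };
}

// Made-up sample payload for a request sent with plantId "tomato".
const sample: IdentifyApiResponse = {
  predictions: [
    { diseaseId: "tomato-late-blight", contextBoosted: true },
    { diseaseId: "potato-late-blight", contextBoosted: false },
  ],
  metadata: { plantContext: "tomato" },
};

const summary = summarizeContext(sample);
```

A response with a plant context but no boosted predictions is not treated as an error here, because the selected plant's diseases may simply not appear in the top-K.
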

notes:

  • Plant context is a scoring heuristic, not a hard filter. It boosts confidence but doesn't exclude other predictions.
  • The default boost factor is 1.5 — this can be tuned based on user feedback.
  • Plant selector is optional — users can skip it and get unboosted predictions.
  • The plant context feature is most useful when the user knows what plant they're diagnosing but the model is uncertain between multiple diseases.
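  • The standard PlantVillage dataset has 38 classes across 14 crop species, so each plant maps to only a few disease classes (tomato has the most, with nine diseases plus a healthy class). The boost is therefore specific enough to be useful without being overly restrictive.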
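
To make the "heuristic, not a hard filter" point concrete, here is a small sketch that contrasts a soft boost with a hard filter and shows how the boost factor changes the ranking. The class labels are PlantVillage-style, but the probabilities are made up, and `boost` and `hardFilter` are illustrative helpers rather than the project's code.

```typescript
interface Scored {
  label: string;
  plantId: string;
  probability: number;
}

// Illustrative top-K output; probabilities are invented for the example.
const topK: Scored[] = [
  { label: "Potato late blight", plantId: "potato", probability: 0.5 },
  { label: "Tomato late blight", plantId: "tomato", probability: 0.25 },
  { label: "Tomato early blight", plantId: "tomato", probability: 0.125 },
];

// Soft boost: multiply matching scores, cap at 1.0, then re-rank.
function boost(preds: Scored[], plantId: string, factor: number): Scored[] {
  return preds
    .map((p) =>
      p.plantId === plantId
        ? { ...p, probability: Math.min(1.0, p.probability * factor) }
        : p,
    )
    .sort((a, b) => b.probability - a.probability);
}

// Hard filter, shown only for contrast: drops every other plant outright.
function hardFilter(preds: Scored[], plantId: string): Scored[] {
  return preds.filter((p) => p.plantId === plantId);
}

const mild = boost(topK, "tomato", 1.5); // potato (0.5) still leads tomato (0.375)
const strong = boost(topK, "tomato", 2.5); // tomato (0.625) now leads potato (0.5)
const filtered = hardFilter(topK, "tomato"); // potato is gone entirely
```

In this made-up case a factor of 1.5 leaves the model's strong potato prediction on top, while 2.5 overrides it. That is the trade-off to watch when tuning the factor: a larger value trusts the user's selection more and the model's evidence less.
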