plant-disease-id/tasks/hierarchical-model-upgrade/02-hierarchical-training.md


Phase 2 — Hierarchical Model Training

Blocked by: Phase 1 (dataset reorganization)
Blocks: Phase 3 (export)
Est. time: 3-5 days on Strix Halo (ROCm), or 4-6 days on RTX 3090 (CUDA)
Machine: Strix Halo preferred (128GB unified memory + 8TB NVMe at 7,300 MB/s read; at that rate the ~70GB dataset can be read sequentially in ~10s)

Objective

Train a hierarchical Swin-Tiny model with two levels of classification heads:

  1. Species head (~320 classes) — identifies the plant
  2. Disease heads (one per species, 30-300 classes each) — identifies the disease

Architecture

Input Image (224×224×3)
        │
        ▼
┌──────────────────────┐
│ Swin-Tiny Backbone   │  ← pretrained on ImageNet-21K
│ (timm library)       │     optional: fine-tune on iNaturalist
│ output: 768-dim      │
└──────────┬───────────┘
           │
    ┌──────┴──────┐
    ▼             ▼
┌─────────┐ ┌───────────────┐
│ Species │ │ Disease Head  │
│ Head    │ │ (routed by    │
│ 320 cls │ │ species ID)   │
└────┬────┘ └───────┬───────┘
     │              │
     ▼              ▼
 Species ID      Disease ID

Environment Setup

```shell
# On Strix Halo (ROCm)
python3 -m venv .hierarchical-venv
source .hierarchical-venv/bin/activate

# ROCm PyTorch (install from https://pytorch.org/get-started/locally/)
# ROCm 6.x + PyTorch 2.5+
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2

# Training libs
pip install pytorch-lightning timm transformers wandb
# albumentations pinned below 2.0: the augmentation code below uses 1.x argument names
pip install "albumentations<2" opencv-python pillow

# Data loading (for SSD-optimized streaming)
pip install webdataset fsspec  # optional benchmarks
```

Alternative (RTX 3090 CUDA path):

```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```
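After either install, a quick check confirms which PyTorch build is active. This is a minimal sketch, and the helper name `describe_backend` is hypothetical. ROCm builds of PyTorch expose AMD GPUs through the `torch.cuda` API and report a HIP version in `torch.version.hip`. CUDA builds report `torch.version.cuda` instead.

```python
import torch


def describe_backend() -> dict:
    """Summarize the installed PyTorch build and whether a GPU is visible."""
    gpu = torch.cuda.is_available()  # True for both CUDA and ROCm GPUs
    return {
        "torch": torch.__version__,
        "hip": torch.version.hip,    # HIP/ROCm version string on ROCm builds, else None
        "cuda": torch.version.cuda,  # CUDA version string on CUDA builds, else None
        "gpu_available": gpu,
        "device_name": torch.cuda.get_device_name(0) if gpu else None,
    }


print(describe_backend())
```

On the Strix Halo venv, `hip` should be non-empty. On the RTX 3090 path, `cuda` should be non-empty.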

Training Protocol

Stage A: Species Classifier (2 days)

| Step | Epochs | LR | Batch size | Details |
|---|---|---|---|---|
| Head warmup | 5 | 1e-3 | 512 (Strix) / 256 (3090) | Backbone frozen, train only species head |
| Full fine-tune | 20 | 1e-4 → 1e-6 | 512 / 256 | Unfreeze backbone, cosine LR schedule |
| Final stage | 5 | 5e-6 | 512 / 256 | Discriminative LR: backbone layers at 0.1× head LR |
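The three Stage A rows differ only in which parameters train and at what learning rate. The sketch below assumes AdamW, since the table does not name an optimizer, and uses the hypothetical helpers `set_backbone_trainable` and `make_optimizer`. Frozen backbone weights are left out for the head warmup, and `backbone_mult=0.1` gives the discriminative LR of the final row.

```python
import torch
import torch.nn as nn


def set_backbone_trainable(backbone: nn.Module, trainable: bool) -> None:
    """Freeze (head warmup) or unfreeze (full fine-tune, final stage) all backbone weights."""
    for p in backbone.parameters():
        p.requires_grad = trainable


def make_optimizer(backbone: nn.Module, head: nn.Module, head_lr: float,
                   backbone_mult: float = 1.0) -> torch.optim.Optimizer:
    """AdamW (assumed) with separate param groups for head and backbone.

    backbone_mult=1.0 for the full fine-tune row, 0.1 for the final discriminative-LR row.
    Frozen backbone parameters are skipped, so the warmup optimizer only updates the head.
    """
    groups = [{"params": list(head.parameters()), "lr": head_lr}]
    backbone_params = [p for p in backbone.parameters() if p.requires_grad]
    if backbone_params:
        groups.append({"params": backbone_params, "lr": head_lr * backbone_mult})
    return torch.optim.AdamW(groups)
```

Rebuild the optimizer at each stage boundary, after changing `requires_grad`. For the full fine-tune row, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)`, stepped once per epoch, decays the LR from 1e-4 to 1e-6 over the 20 epochs.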

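Loss: Focal Loss (γ=2.0, α=0.25). It down-weights easy, well-classified examples so training concentrates on hard and under-represented species, which helps with class imbalance in the species distribution.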
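Below is a sketch of the standard multi-class focal loss, FL = α (1 - p_t)^γ · CE, with the γ and α values above. The class name `FocalLoss` is illustrative. With a single scalar α shared by all classes, α only rescales the loss, much like the learning rate. The re-weighting of easy versus hard examples comes from the (1 - p_t)^γ term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss with a scalar alpha: mean of alpha * (1 - p_t)^gamma * CE."""

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        ce = F.cross_entropy(logits, target, reduction="none")  # -log p_t per sample
        p_t = torch.exp(-ce)                                     # probability of the true class
        return (self.alpha * (1 - p_t) ** self.gamma * ce).mean()
```

With `gamma=0, alpha=1` it reduces to ordinary cross-entropy, which is a convenient sanity check. Because it takes integer targets, it drops straight into `mixup_criterion` in the training loop.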

Augmentation — Image Jittering, Degradation & Robustness:

Real-world plant photos vary dramatically: different cameras, lighting conditions, angles, weather, focus quality, and compression artifacts. Augmentation is essential for generalization, not optional. The more of this real-world variation it covers, the more robust the deployed model will be, provided the transforms stay realistic enough to preserve the lesion features the disease heads depend on.

Three tiers of augmentation, all applied on-the-fly (never pre-generated):

Tier 1 — Core Geometric & Photometric (applied to every image)

These simulate the most common real-world variations:

| Augmentation | Parameter | Simulates |
|---|---|---|
| RandomResizedCrop | scale=(0.6, 1.0), ratio=(0.75, 1.33) | Different shooting distances, zoom levels, framing |
| HorizontalFlip | p=0.5 | Different leaf orientations (left/right symmetry) |
| Rotate | limit=45°, p=0.5 | Off-angle photos, tilted camera |
| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 | Different lighting: sunny, overcast, shade, dusk |
| RandomBrightnessContrast | brightness_limit=0.2, contrast_limit=0.2, p=0.5 | Exposure variations from auto-exposure cameras |

Tier 2 — Degradation & Quality Simulation (applied to ~30% of images)

These make the model robust to poor-quality inputs that real users will upload:

| Augmentation | Parameter | Simulates |
|---|---|---|
| GaussianBlur | blur_limit=(3, 7), p=0.2 | Out-of-focus photos, motion blur |
| GaussNoise | var_limit=(10, 50), p=0.15 | Low-light sensor noise, phone camera noise |
| ImageCompression | quality_lower=60, quality_upper=95, p=0.2 | JPEG artifacts from compression, social media re-encoding |
| ToGray | p=0.05 | Monochrome cameras, infrared plant imaging |
| RandomShadow | shadow_roi=(0, 0.5, 0.5, 1), p=0.15 | Leaves in shadow of other leaves/structures outdoors |

Tier 3 — Advanced Regularization (applied at batch level)

These are widely used regularizers for fine-grained classification. MixUp and CutMix operate on whole batches, RandAugment is applied per image, and label smoothing is applied in the loss:

| Technique | Parameter | Effect |
|---|---|---|
| MixUp | α=0.2 | Blends two random images and their labels linearly, which encourages smoother decision boundaries. Reported to give a +4.8% improvement on rare plant diseases. |
| CutMix | α=1.0 | Replaces a random patch of one image with a patch from another and mixes the labels in proportion to patch area. Forces the model to focus on local lesion features rather than overall leaf shape. |
| RandAugment | N=2, M=9 | Randomly applies N=2 of 14 operations (shear, translate, rotate, contrast, etc.) per image at magnitude 9 on a 0-10 scale (timm's rand-m9-n2 convention; torchvision's RandAugment uses 31 magnitude bins, so magnitude=9 there covers a smaller fraction of each operation's range). |
| Label smoothing | ε=0.1 | Prevents overconfidence on training classes and improves calibration on held-out images. |

Implementation (albumentations):

```python
import albumentations as A  # 1.x API (pinned "albumentations<2"); 2.x renamed several arguments used here
from albumentations.pytorch import ToTensorV2

# ImageNet statistics, matching the ImageNet-pretrained Swin-Tiny backbone
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Core spatial + photometric (Tier 1)
train_transform = A.Compose([
    A.RandomResizedCrop(height=224, width=224, scale=(0.6, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=45, p=0.5),
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # Degradation (Tier 2): at most one blur/noise op, applied to ~30% of images
    A.OneOf([
        A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        A.GaussNoise(var_limit=(10, 50), p=1.0),
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
    ], p=0.3),
    A.ImageCompression(quality_lower=60, quality_upper=95, p=0.2),
    # shadow_roi = (x_min, y_min, x_max, y_max) as fractions of the image size
    A.RandomShadow(shadow_roi=(0, 0.5, 0.5, 1), num_shadows_lower=1, num_shadows_upper=2, p=0.15),
    A.ToGray(p=0.05),
    # Normalize, then convert the HWC NumPy array to a CHW torch tensor
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation: minimal, deterministic
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])
```

MixUp (implemented in the training loop, applied on GPU):

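```python
import numpy as np
import torch

# model, criterion, optimizer, dataloader, device and use_mixup come from the training script


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Mix the two losses with the same weight used to blend the images
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


for images, labels in dataloader:
    images = images.to(device, non_blocking=True)  # non_blocking pairs with pin_memory=True
    labels = labels.to(device, non_blocking=True)

    if use_mixup and np.random.random() < 0.5:
        # MixUp: blend each image with a random partner from the same batch
        lam = float(np.random.beta(0.2, 0.2))
        indices = torch.randperm(images.size(0), device=device)
        mixed_images = lam * images + (1 - lam) * images[indices]

        logits = model(mixed_images)
        loss = mixup_criterion(criterion, logits, labels, labels[indices], lam)
    else:
        logits = model(images)
        loss = criterion(logits, labels)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```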
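The Tier 3 table also lists CutMix (α=1.0). Below is a minimal sketch of the usual box-pasting formulation. It assumes NCHW image batches already on the device, and `cutmix` is a hypothetical helper name. The mixing weight is recomputed from the clipped box so the label weights match the pixel areas, and the outputs plug into `mixup_criterion` unchanged.

```python
import numpy as np
import torch


def cutmix(images: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0):
    """Paste a random box from a shuffled copy of the batch into every image.

    Returns (mixed_images, labels_a, labels_b, lam), where lam is the fraction of
    each image that still comes from its original sample.
    """
    n, _, h, w = images.shape
    lam = float(np.random.beta(alpha, alpha))
    indices = torch.randperm(n, device=images.device)

    # Box covering ~(1 - lam) of the image area, centred at a random point
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = int(np.clip(cy - cut_h // 2, 0, h)), int(np.clip(cy + cut_h // 2, 0, h))
    x1, x2 = int(np.clip(cx - cut_w // 2, 0, w)), int(np.clip(cx + cut_w // 2, 0, w))

    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[indices, :, y1:y2, x1:x2]

    # Recompute lam from the clipped box so label weights match the pasted area
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    return mixed, labels, labels[indices], lam
```

Inside the loop, `mixed, y_a, y_b, lam = cutmix(images, labels)` followed by `loss = mixup_criterion(criterion, model(mixed), y_a, y_b, lam)` replaces the MixUp branch. A common pattern is to choose MixUp or CutMix at random per batch.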

SSD Data Loading Strategy

Important: The full dataset is ~70GB (after Phase 1 resizing). Held in RAM, it would take up most of the 128GB of unified memory once the OS, GPU allocation, model weights, and augmentation workspace are accounted for. The 8TB NVMe at 7,300 MB/s sequential read makes streaming from disk the better option.

| Metric | Value |
|---|---|
| Dataset size (after resize) | ~70 GB |
| NVMe read speed (sequential) | 7,300 MB/s |
| Sequential read time (full dataset) | ~10 seconds |
| Random read (1,000 random files, 4 KB each) | ~0.5 ms access latency (NVMe has no mechanical seek) |

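Key insight: The bottleneck is JPEG decode + augmentation on the CPU workers, not disk I/O. At ~50MB per 256-image batch (~200KB per image), 7,300 MB/s of sequential reads corresponds to ~37,000 images/s, far more than the GPU consumes (~1,280 images/s at ~200ms per 256-image batch). The worker pool is the part to size. If each worker decodes and augments only ~35 images/s, num_workers=8 supplies ~280 images/s, well short of the GPU's rate. Measure per-worker throughput and raise num_workers (or reduce per-image CPU work) until the workers outpace the GPU.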
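The arithmetic behind this comparison is captured in the small calculator below, which uses the figures quoted in this section: ~50MB per 256-image batch, ~200ms of GPU time per batch, and 35 images/s per worker. The helper `pipeline_budget` is hypothetical, and the per-worker rate is a placeholder to replace with a measured value.

```python
import math


def pipeline_budget(seq_read_mb_s: float = 7300.0, batch_mb: float = 50.0,
                    batch_size: int = 256, gpu_s_per_batch: float = 0.200,
                    num_workers: int = 8, worker_img_s: float = 35.0) -> dict:
    """Compare disk, CPU-worker, and GPU rates in images per second."""
    mb_per_image = batch_mb / batch_size
    disk = seq_read_mb_s / mb_per_image  # upper bound: sequential reads
    decode = num_workers * worker_img_s  # JPEG decode + augmentation across workers
    gpu = batch_size / gpu_s_per_batch   # what the training step consumes
    return {
        "disk_img_s": disk,
        "decode_img_s": decode,
        "gpu_img_s": gpu,
        "workers_needed": math.ceil(gpu / worker_img_s),
        "gpu_starved": min(disk, decode) < gpu,
    }
```

With the defaults, the disk supplies ~37,000 images/s, eight workers supply 280 images/s, and the GPU wants 1,280 images/s, so about 37 workers at that per-worker rate would be needed. The CPU side, not the NVMe, is what needs attention.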

Recommended DataLoader configuration:

```python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=256,             # 256 on RTX 3090; 512 on Strix Halo
    shuffle=True,
    # SSD-optimized settings:
    num_workers=8,              # parallel decode/augment processes; raise if GPU utilization stays low
    prefetch_factor=4,          # each worker keeps 4 batches loaded ahead
    pin_memory=True,            # page-locked host memory for faster, asynchronous CPU->GPU copies
    persistent_workers=True,    # keep workers alive between epochs (no re-spawning each epoch)
    drop_last=True,             # every step sees a full batch (Swin uses LayerNorm, so not a BatchNorm concern)
)
```

Why not load everything into RAM?

  • 128GB total memory — after OS (~8GB), GPU reserved (~4-8GB ROCm), model weights + optimizer states (~4GB), augmentation workspace (~2GB), you have ~100GB free
  • 70GB dataset would barely fit, but leaves no room for caching augmentation results or handling spikes
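  • Better approach: let the NVMe + DataLoader pipeline stream data. At 7,300 MB/s, reading a batch of 256 images (~50MB) takes ~7ms, while the GPU takes ~200ms to process that batch, so the disk is ~30× faster than the GPU. Random reads of individual files run below the sequential figure, but even a several-fold drop leaves ample headroom, so disk I/O is unlikely to be the bottleneck.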

Optional: Use WebDataset for maximum throughput

WebDataset packs the dataset into large tar shards (~1GB each) that are read sequentially, replacing many small random reads with a few large sequential ones. This matters most at large scale. For your setup it's optional (raw files on NVMe are already fast enough), but worth considering if you scale to multi-GPU:

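```shell
pip install webdataset
```

```python
import numpy as np
import webdataset as wds

urls = "data/organized/train/shard-{000000..000099}.tar"

def augment(img):
    # albumentations expects a NumPy array; train_transform is the pipeline defined above
    return train_transform(image=np.array(img))["image"]

dataset = (wds.WebDataset(urls).shuffle(10000).decode("pil")
           .to_tuple("jpg", "cls").map_tuple(augment, lambda cls: cls))
```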

Profiling check: During training, monitor GPU utilization:

  • nvidia-smi / rocm-smi: GPU utilization should stay above 90%
  • If it drops below 70%, the GPU is waiting for data → increase num_workers or prefetch_factor
  • If it holds above 90-95%, the data pipeline is keeping up → no change needed
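GPU utilization shows the symptom. Timing the DataLoader on its own, with no model attached, shows whether the input pipeline can keep up. Below is a minimal sketch, assuming each batch is an `(images, labels)` pair. The helper `loader_throughput` is hypothetical.

```python
import time
from typing import Iterable, Sequence


def loader_throughput(loader: Iterable[Sequence], max_batches: int = 50, warmup: int = 5) -> float:
    """Images per second delivered by a data loader alone (no model, no GPU work)."""
    it = iter(loader)
    for _ in range(warmup):
        next(it)  # let workers start and fill their prefetch queues before timing
    n_images, start = 0, time.perf_counter()
    for _, batch in zip(range(max_batches), it):
        n_images += len(batch[0])  # batch[0] holds the images; len() is the batch size
    elapsed = max(time.perf_counter() - start, 1e-9)
    return n_images / elapsed
```

Compare `loader_throughput(dataloader)` with the GPU's consumption rate (batch size divided by seconds per training step). If the loader is slower, add workers or lighten per-image augmentation.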

Stage B: Disease Classifiers (2-3 days)

| Step | Epochs | LR | Details |
|---|---|---|---|
| All disease heads | 15 | 1e-3 | Backbone frozen, train all disease heads simultaneously |
| Rare-class boost | 5 | 5e-4 | Oversample classes with <80 images |
| End-to-end | 10 | 1e-5 | Unfreeze backbone, joint species + disease loss |

Key design: Disease heads are simple linear layers (768 → num_diseases_for_species). Since they share the backbone, inference is efficient — one forward pass through Swin-Tiny, then route the 768-dim feature vector to the correct head.
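Below is a minimal PyTorch sketch of this layout. It assumes the backbone returns pooled 768-dim features, for example a timm Swin-Tiny created with `num_classes=0`, which returns pooled features instead of logits. `HierarchicalModel` and `diseases_per_species` are hypothetical names. At inference, each image's features are routed to the head of its predicted species.

```python
import torch
import torch.nn as nn


class HierarchicalModel(nn.Module):
    """Shared backbone, one species head, and one linear disease head per species."""

    def __init__(self, backbone: nn.Module, feat_dim: int, diseases_per_species: list):
        super().__init__()
        self.backbone = backbone  # must return pooled (B, feat_dim) features
        self.species_head = nn.Linear(feat_dim, len(diseases_per_species))
        self.disease_heads = nn.ModuleList([nn.Linear(feat_dim, n) for n in diseases_per_species])

    def forward(self, x: torch.Tensor):
        feats = self.backbone(x)  # the only expensive step: one backbone pass per image
        return feats, self.species_head(feats)

    @torch.no_grad()
    def predict(self, x: torch.Tensor):
        """Route each image's features to the disease head of its predicted species."""
        feats, species_logits = self(x)
        species = species_logits.argmax(dim=1)
        diseases = [int(self.disease_heads[s](f).argmax()) for s, f in zip(species.tolist(), feats)]
        return species, torch.tensor(diseases)
```

For the real model, the backbone would be `timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0)` with `feat_dim=768`. Call `model.eval()` before `predict`.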

Class balancing: Use weighted sampler for disease heads — classes with <80 images get 3× sampling weight, classes with <50 images get 5×.
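The balancing rule can be expressed as per-sample weights. This plain-Python sketch assumes each sample's disease label is a globally unique key, for example a hypothetical `(species, disease)` pair, because per-species disease indices repeat across heads. `disease_sample_weights` is a hypothetical helper.

```python
from collections import Counter
from typing import Hashable, List


def disease_sample_weights(labels: List[Hashable]) -> List[float]:
    """Per-sample weights: 5x for classes with <50 images, 3x for <80, otherwise 1x."""
    counts = Counter(labels)

    def class_weight(n: int) -> float:
        if n < 50:
            return 5.0
        if n < 80:
            return 3.0
        return 1.0

    return [class_weight(counts[y]) for y in labels]
```

Pass the result to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` through the DataLoader's `sampler=` argument, and drop `shuffle=True`, because DataLoader rejects `sampler` combined with `shuffle`.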

Loss weighting: L_total = L_species + 0.7 * L_disease — species loss has higher weight since disease prediction depends on correct species ID.
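Below is a sketch of this joint loss for a batch that mixes species. It assumes disease logits are routed by the ground-truth species during training (routing by the predicted species is the alternative) and that `disease_y` holds per-species local indices. Samples are grouped by species so each disease head runs once per batch. Plain cross-entropy is used for brevity, and the focal loss from Stage A could be swapped in. `hierarchical_loss` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def hierarchical_loss(feats: torch.Tensor, species_logits: torch.Tensor,
                      disease_heads: nn.ModuleList, species_y: torch.Tensor,
                      disease_y: torch.Tensor, disease_weight: float = 0.7) -> torch.Tensor:
    """L_total = L_species + disease_weight * L_disease, averaged over the batch."""
    loss_species = F.cross_entropy(species_logits, species_y)
    per_species = []
    for s in species_y.unique().tolist():
        mask = species_y == s                      # samples whose true species is s
        logits = disease_heads[s](feats[mask])     # route them to that species' head
        per_species.append(F.cross_entropy(logits, disease_y[mask], reduction="sum"))
    loss_disease = torch.stack(per_species).sum() / species_y.numel()
    return loss_species + disease_weight * loss_disease
```

Summing within each group and dividing by the batch size gives every sample equal weight, regardless of how many images of its species landed in the batch.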

Model Checkpointing

checkpoints/
├── species_only/          # Stage A checkpoints
│   ├── epoch=05-val_loss=0.42.ckpt
│   ├── epoch=15-val_loss=0.18.ckpt
│   └── epoch=25-best.ckpt
├── disease_heads/          # Stage B initial disease heads
│   ├── disease_heads_epoch=15.pt
│   └── disease_heads_final.pt
└── hierarchical_full/      # End-to-end
    ├── epoch=05.ckpt
    └── epoch=10-best.ckpt

Save every checkpoint with species accuracy, macro F1, and per-disease F1 for the tail classes.
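Below is a minimal sketch of saving in plain PyTorch format with the metrics attached, so checkpoints load outside Lightning. The function names and dictionary keys (`model_state_dict`, `metrics`, and so on) are assumptions, not a format the export phase already defines.

```python
import torch
import torch.nn as nn


def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer,
                    epoch: int, metrics: dict) -> None:
    """Plain-PyTorch checkpoint; metrics holds species accuracy, macro F1, tail-class F1."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "metrics": metrics,
    }, path)


def load_checkpoint(path: str, model: nn.Module, optimizer=None):
    # weights_only=True restricts unpickling to tensors and plain containers
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt["metrics"]
```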

Expected Training Time (RTX 3090 baseline)

| Stage | Epochs | Time/Epoch | Total |
|---|---|---|---|
| Species head warmup | 5 | ~18 min | 1.5 hr |
| Species full fine-tune | 20 | ~45 min | 15 hr |
| Species fine-tune final | 5 | ~45 min | 3.75 hr |
| Disease heads | 15 | ~30 min | 7.5 hr |
| Disease rare-class | 5 | ~35 min | ~3 hr |
| End-to-end | 10 | ~50 min | ~8.3 hr |
| Total | 60 | | ~39 hr |

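On Strix Halo with NVMe + ROCm, the time per epoch may be lower, for these reasons:

  • 7,300 MB/s NVMe: data loading should stay ahead of the GPU (no I/O wait)
  • Larger batch sizes (512 vs 256): half as many optimizer steps per epoch, which cuts per-step overhead (the compute per epoch is unchanged)
  • ROCm 6.x runs PyTorch on AMD GPUs, though kernel coverage and speed vary by operator
  • 128GB unified memory allows a large prefetch buffer for seamless streaming

Expect ~20-28 hours total on Strix Halo. The Strix Halo iGPU has much less memory bandwidth than an RTX 3090 (~256 GB/s vs ~936 GB/s), which can offset these gains, so confirm the estimate by timing a few hundred training steps before committing to the full run.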

Evaluation Metrics

| Metric | Target | Measurement |
|---|---|---|
| Species top-1 accuracy | ≥95% | Fraction of correct species predictions |
| Disease top-1 accuracy | ≥88% | Across all species-conditioned heads |
| Disease top-3 accuracy | ≥94% | Correct if the true disease is among the top 3 predictions |
| Macro F1 (rare diseases) | ≥80% | Unweighted mean of per-class F1 across tail classes |
| Species→Disease cascade accuracy | ≥90% | P(correct species) × P(correct disease given correct species) |
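The last row can be checked against the others. Assume disease top-1 is measured with ground-truth species routing, so that it estimates P(correct disease given correct species). Then the species and disease floors together give only 0.95 × 0.88 ≈ 83.6% cascade accuracy, below the 90% target. Reaching 90% at 95% species accuracy needs ≈94.7% conditional disease accuracy. The arithmetic is shown below with hypothetical helper names:

```python
def cascade_accuracy(p_species: float, p_disease_given_species: float) -> float:
    """P(species and disease both correct) = P(species) * P(disease | species)."""
    return p_species * p_disease_given_species


def required_conditional(p_species: float, cascade_target: float) -> float:
    """Conditional disease accuracy needed to reach a cascade target."""
    return cascade_target / p_species


print(cascade_accuracy(0.95, 0.88))      # about 0.836
print(required_conditional(0.95, 0.90))  # about 0.947
```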

Edge Cases & Gotchas

  • GPU memory on RTX 3090 (24GB): Swin-Tiny at 224px with batch size 256 + mixed precision should fit. If not, reduce to 128 or use gradient accumulation (accumulate 2 steps).
  • Strix Halo ROCm quirks: torch.compile() may have issues on ROCm 6.2 — test without it first. Some timm model ops may need ROCm kernel fallbacks; test the forward pass before starting training.
  • Checkpoint compatibility: Save in pure PyTorch format (.pt), not Lightning-specific, so they're loadable outside Lightning for export.
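  • Disease head memory: 320 separate linear layers sounds large, but each is 768×N_diseases (at most ~300 classes → ≤230K params). Upper bound for all disease heads: 320 × 230K ≈ 74M params (vs 28M for the backbone). This is fine: each image passes through only one disease head, so compute is dominated by the backbone. The heads do add their parameters and optimizer state to memory and checkpoint size.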
  • Loss divergence on rare diseases: Monitor individual disease loss curves; if a tail class diverges, reduce its learning rate or use gradient clipping (max_norm=1.0).

Verification

  • Species classifier ≥95% top-1 on val set
  • Disease top-3 accuracy ≥94% on val set
  • Confusion matrix shows no systematic species misclassifications
  • Per-species disease classifiers all converge (no NaN losses)
  • Tail classes (≤80 images) have F1 ≥70%
  • Model can be loaded from checkpoint and run inference in PyTorch
  • No OOM errors during training