# Phase 2 — Hierarchical Model Training
**Blocked by**: Phase 1 (dataset reorganization)
**Blocks**: Phase 3 (export)
**Est. time**: 3-5 days on Strix Halo (ROCm), or 4-6 days on RTX 3090 (CUDA)
**Machine**: Strix Halo preferred (128GB unified memory + 8TB NVMe at 7,300 MB/s read; fast enough to read the entire ~70GB resized dataset sequentially in ~10s)
## Objective
Train a hierarchical Swin-Tiny model with two levels of classification heads:
1. **Species head** (~320 classes) — identifies the plant
2. **Disease heads** (one per species, 30-300 classes each) — identifies the disease
## Architecture
```
Input Image (224×224×3)
┌──────────────────────┐
│ Swin-Tiny Backbone │ ← pretrained on ImageNet-21K
│ (timm library) │ optional: fine-tune on iNaturalist
│ output: 768-dim │
└──────────┬───────────┘
┌──────┴──────┐
▼ ▼
┌─────────┐ ┌───────────────┐
│ Species │ │ Disease Head │
│ Head │ │ (routed by │
│ 320 cls │ │ species ID) │
└────┬────┘ └───────┬───────┘
│ │
▼ ▼
Species ID Disease ID
```
## Environment Setup
```bash
# On Strix Halo (ROCm)
python3 -m venv .hierarchical-venv
source .hierarchical-venv/bin/activate
# ROCm PyTorch (install from https://pytorch.org/get-started/locally/)
# ROCm 6.x + PyTorch 2.5+
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2
# Training libs
pip install pytorch-lightning timm transformers wandb
pip install albumentations opencv-python pillow
# Data loading (for SSD-optimized streaming)
pip install webdataset fsspec # optional benchmarks
```
**Alternative (RTX 3090 CUDA path)**:
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```
## Training Protocol
### Stage A: Species Classifier (2 days)
| Step | Epochs | LR | Batch Size | Details |
| -------------- | ------ | ----------- | ------------------------ | ----------------------------------------------- |
| Head warmup | 5 | 1e-3 | 512 (Strix) / 256 (3090) | Backbone frozen, train only species head |
| Full fine-tune | 20 | 1e-4 → 1e-6 | 512 / 256 | Unfreeze backbone, cosine LR schedule |
| Stage final | 5 | 5e-6 | 512 / 256 | Discriminative LR: backbone layers 0.1× head LR |
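
The learning-rate columns above can be sketched in plain Python. This is a minimal illustration, assuming the "1e-4 → 1e-6" entry means cosine annealing stepped once per epoch and reaching the minimum on the last epoch; the `cosine_lr` helper and its argument names are hypothetical. In a real run, `torch.optim.lr_scheduler.CosineAnnealingLR` with `eta_min=1e-6` and two optimizer parameter groups would likely do the same job.

```python
import math

def cosine_lr(epoch: int, total_epochs: int,
              lr_max: float = 1e-4, lr_min: float = 1e-6) -> float:
    """Cosine-annealed learning rate for a 0-based epoch index."""
    # progress runs from 0.0 at the first epoch to 1.0 at the last one
    progress = epoch / max(total_epochs - 1, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))

# Full fine-tune step: 20 epochs, decaying from 1e-4 to 1e-6
schedule = [cosine_lr(epoch, 20) for epoch in range(20)]

# Final step: head at 5e-6, backbone layers at 0.1x the head LR
head_lr = 5e-6
backbone_lr = 0.1 * head_lr

print(f"first={schedule[0]:.2e} mid={schedule[10]:.2e} last={schedule[-1]:.2e}")
print(f"head={head_lr:.1e} backbone={backbone_lr:.1e}")
```
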
**Loss**: Focal Loss (γ=2.0, α=0.25) — handles any class imbalance in species distribution.
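A minimal sketch of a multi-class focal loss with the parameters above, assuming α is applied as a single scalar to every class (in which case it only rescales the loss; per-class α weights would be a different design). The `focal_loss` function is illustrative and not taken from the project code.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    """Mean of alpha * (1 - p_t)^gamma * CE over the batch."""
    # Per-sample cross-entropy equals -log(p_t), where p_t is the
    # predicted probability of the true class.
    ce = F.cross_entropy(logits, targets, reduction="none")
    p_t = torch.exp(-ce)
    # Easy examples (p_t close to 1) are down-weighted by (1 - p_t)^gamma.
    return (alpha * (1.0 - p_t) ** gamma * ce).mean()

# Toy batch: 8 samples, 320 species classes
logits = torch.randn(8, 320)
targets = torch.randint(0, 320, (8,))
loss = focal_loss(logits, targets)
```

With `gamma=0` and `alpha=1` the function reduces to ordinary cross-entropy, which is a convenient sanity check before training.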
**Augmentation — Image Jittering, Degradation & Robustness**:
Real-world plant photos vary dramatically: different cameras, lighting conditions, angles, weather, focus quality, and compression artifacts. **Augmentation is not optional — it's essential for generalization.** The more varied your augmentation, the more robust your model will be when deployed.
**Three tiers of augmentation**, all applied on-the-fly (never pre-generated):
#### Tier 1 — Core Geometric & Photometric (applied to every image)
These simulate the most common real-world variations:
| Augmentation | Parameter | Simulates |
| ------------------------ | ----------------------------------------------------- | -------------------------------------------------- |
| RandomResizedCrop | scale=(0.6, 1.0), ratio=(0.75, 1.33) | Different shooting distances, zoom levels, framing |
| HorizontalFlip | p=0.5 | Different leaf orientations (left/right symmetry) |
| Rotate | limit=45°, p=0.5 | Off-angle photos, tilted camera |
| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 | Different lighting — sunny, overcast, shade, dusk |
| RandomBrightnessContrast | brightness_limit=0.2, contrast_limit=0.2, p=0.5 | Exposure variations from auto-exposure cameras |
#### Tier 2 — Degradation & Quality Simulation (applied to ~30% of images)
These make the model robust to poor-quality inputs that real users will upload:
| Augmentation | Parameter | Simulates |
| ---------------- | ----------------------------------------- | --------------------------------------------------------- |
| GaussianBlur | blur_limit=(3, 7), p=0.2 | Out-of-focus photos, motion blur |
| GaussNoise       | var_limit=(10, 50), p=0.15                | Low-light sensor noise, phone camera noise                |
| ImageCompression | quality_lower=60, quality_upper=95, p=0.2 | JPEG artifacts from compression, social media re-encoding |
| ToGray           | p=0.05                                    | Monochrome cameras, infrared plant imaging                |
| RandomShadow     | shadow_roi=(0, 0.5, 0.5, 1), p=0.15       | Leaves in shadow of other leaves/structures outdoors      |
#### Tier 3 — Advanced Regularization (applied at batch level)
These are well-established regularization techniques that tend to improve generalization on fine-grained classification:
| Technique | Parameter | Effect |
| ------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **MixUp** | α=0.2 | Blends two random images + labels linearly. Forces the model to learn smooth decision boundaries. Proven +4.8% improvement on rare plant diseases. |
| **CutMix** | α=1.0 | Replaces a random patch of one image with another. Forces the model to focus on local lesion features rather than overall leaf shape. |
| **RandAugment** | N=2, M=9 | Auto-selected augmentation policy. 14 operations randomly chosen (shear, translate, rotate, contrast, etc.). N=2 ops per image, magnitude 9 (on 0-10 scale). |
| **Label Smoothing** | ε=0.1 | Prevents overconfidence on training classes, improves calibration on unseen diseases. |
**Implementation (albumentations)**:
```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Standard ImageNet normalization statistics (RGB)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Argument names follow the albumentations 1.3-style API;
# newer releases rename some of them.
# Core spatial + photometric (Tier 1+2)
train_transform = A.Compose([
    A.RandomResizedCrop(224, 224, scale=(0.6, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=45, p=0.5),
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # Degradation (Tier 2): only some images get these
    A.OneOf([
        A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        A.GaussNoise(var_limit=(10, 50), p=1.0),  # albumentations spells it GaussNoise
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
    ], p=0.3),
    A.ImageCompression(quality_lower=60, quality_upper=95, p=0.2),
    A.RandomShadow(shadow_roi=(0, 0.5, 0.5, 1), num_shadows_lower=1, num_shadows_upper=2, p=0.15),
    A.ToGray(p=0.05),  # albumentations' random grayscale transform
    # Normalize
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation: minimal, deterministic
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])
```
**MixUp/CutMix (implemented in training loop, applied on GPU)**:
```python
import numpy as np
import torch

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both label sets with the weight used for the images."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# model, criterion, dataloader, device and use_mixup are defined elsewhere
for images, labels in dataloader:
    images, labels = images.to(device), labels.to(device)
    if use_mixup and np.random.random() < 0.5:
        # MixUp: blend images and labels
        lam = np.random.beta(0.2, 0.2)
        indices = torch.randperm(images.size(0), device=device)
        mixed_images = lam * images + (1 - lam) * images[indices]
        logits = model(mixed_images)
        loss = mixup_criterion(criterion, logits, labels, labels[indices], lam)
    else:
        logits = model(images)
        loss = criterion(logits, labels)
    # backward pass and optimizer step follow as usual
```
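The loop above only shows MixUp. A CutMix counterpart could look like the following sketch, which assumes NCHW image tensors and α=1.0 from the Tier 3 table; `rand_bbox` and `cutmix` are illustrative names, not functions from the project. The returned labels and `lam` can be passed to a blended loss of the form `lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)`.

```python
import numpy as np
import torch

def rand_bbox(height: int, width: int, lam: float, rng: np.random.Generator):
    """Sample a box covering roughly (1 - lam) of the image area."""
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = int(rng.integers(height)), int(rng.integers(width))
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y1, y2, x1, x2

def cutmix(images: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0, rng=None):
    """Paste a patch from a shuffled copy of the batch into each image."""
    if rng is None:
        rng = np.random.default_rng()
    lam = float(rng.beta(alpha, alpha))
    indices = torch.randperm(images.size(0), device=images.device)
    height, width = images.size(2), images.size(3)
    y1, y2, x1, x2 = rand_bbox(height, width, lam, rng)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[indices, :, y1:y2, x1:x2]
    # The box may be clipped at the border, so recompute lambda from its real area.
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (height * width)
    return mixed, labels, labels[indices], lam

# Toy batch to show the call shape
images = torch.rand(4, 3, 224, 224)
labels = torch.arange(4)
mixed, y_a, y_b, lam = cutmix(images, labels, rng=np.random.default_rng(0))
```
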
## SSD Data Loading Strategy
**Important**: The full dataset is ~70GB (after Phase 1 resizing). It would fit in the 128GB of unified memory, but only with little headroom once the OS, GPU memory, workspace, and model weights are accounted for. **Your 8TB NVMe at 7,300 MB/s read makes loading it into RAM unnecessary.**
| Metric | Value |
| ----------------------------------------- | --------------- |
| Dataset size (after resize) | ~70 GB |
| NVMe read speed | 7,300 MB/s |
| Sequential read time (full dataset) | **~10 seconds** |
| Random read (1000 random files, 4KB each) | ~0.5ms seek |
**Key insight**: The GPU consumes batches slower than the SSD can deliver them. With `num_workers=8`, each worker reads ~35 images/s from random positions. At 7,300 MB/s sequential and ~0.2 MB per image (the ~50MB per 256-image batch assumed below), the SSD can serve roughly 37,000 images/s. The bottleneck is **JPEG decode + augmentation**, not disk I/O.
**Recommended DataLoader configuration**:
```python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    # SSD-optimized settings:
    num_workers=8,            # 8 parallel readers; raise if GPU utilization stays low
    prefetch_factor=4,        # Each worker prefetches 4 batches ahead
    pin_memory=True,          # Page-locked host memory for faster CPU→GPU transfer
    persistent_workers=True,  # Keep workers alive between epochs (avoid fork overhead)
    drop_last=True,           # Drop incomplete final batch so every step sees a full batch
)
```
**Why not load everything into RAM?**
- 128GB total memory — after OS (~8GB), GPU reserved (~4-8GB ROCm), model weights + optimizer states (~4GB), augmentation workspace (~2GB), you have ~100GB free
- 70GB dataset would barely fit, but leaves no room for caching augmentation results or handling spikes
- Better approach: let the NVMe + DataLoader pipeline stream data. At 7,300 MB/s, reading a batch of 256 images (~50MB) takes ~7ms. Meanwhile, the GPU takes ~200ms to process that batch. **The disk is 30× faster than the GPU — you will never be I/O bound.**
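
The arithmetic behind these figures can be checked directly. The sketch below uses only the numbers quoted above (7,300 MB/s, ~50MB per 256-image batch, ~200ms GPU time per batch, 8 workers); the variable names are illustrative.

```python
NVME_MB_PER_S = 7_300   # sequential read speed
BATCH_SIZE = 256
BATCH_MB = 50           # approximate size of one batch on disk
GPU_BATCH_S = 0.200     # approximate GPU time per batch
NUM_WORKERS = 8

disk_batch_s = BATCH_MB / NVME_MB_PER_S         # time to read one batch from disk
disk_vs_gpu = GPU_BATCH_S / disk_batch_s        # how many times faster the disk is
gpu_images_per_s = BATCH_SIZE / GPU_BATCH_S     # rate the input pipeline must sustain
per_worker_images_per_s = gpu_images_per_s / NUM_WORKERS

print(f"disk read per batch: {disk_batch_s * 1e3:.1f} ms")
print(f"disk is {disk_vs_gpu:.0f}x faster than the GPU step")
print(f"pipeline must deliver {gpu_images_per_s:.0f} images/s "
      f"({per_worker_images_per_s:.0f} per worker)")
```

Under these numbers each of the 8 workers has to decode and augment about 160 images/s, which is well above the ~35 images/s per worker quoted earlier. Disk I/O is clearly not the limit, but the worker count is worth validating with the profiling check below.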
**Optional: Use WebDataset for maximum throughput**
WebDataset shards the dataset into large tar files (~1GB each), which are sequentially read. This eliminates random seek overhead entirely — ideal when running at massive scale. For your setup it's optional (raw files on NVMe are already fast enough), but worth considering if you scale to multi-GPU:
```bash
pip install webdataset
```
```python
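import webdataset as wds

urls = "data/organized/train/shard-{000000..000099}.tar"
# augment is assumed to take and return an (image, label) tuple,
# since it runs after to_tuple
dataset = (
    wds.WebDataset(urls)
    .shuffle(10000)
    .decode("pil")
    .to_tuple("jpg", "cls")
    .map(augment)
)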
```
**Profiling check**: During training, monitor GPU utilization:
- `nvidia-smi` / `rocm-smi` — GPU-Util should be >90%
- If <70%, GPU is waiting for data → increase `num_workers` or `prefetch_factor`
- If >95%, data pipeline is keeping up → optimal
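
GPU utilization can be cross-checked from inside the training script by timing how long the loop waits for the next batch. The helper below is a hypothetical sketch (`data_wait_fraction` is not part of PyTorch); the sleeping loader and step only stand in for a real DataLoader and training step. On a real GPU the step should end with `torch.cuda.synchronize()` so that asynchronous kernels are included in the timing.

```python
import time
from typing import Callable, Iterable

def data_wait_fraction(batches: Iterable, step: Callable[[object], None],
                       max_steps: int = 50) -> float:
    """Fraction of wall-clock time spent waiting for the next batch."""
    wait_s = 0.0
    iterator = iter(batches)
    total_start = time.perf_counter()
    for _ in range(max_steps):
        fetch_start = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        wait_s += time.perf_counter() - fetch_start
        step(batch)  # forward/backward/optimizer step in real training
    total_s = time.perf_counter() - total_start
    return wait_s / total_s if total_s > 0 else 0.0

def slow_loader(n: int, delay_s: float):
    """Stand-in for a data pipeline that cannot keep up."""
    for i in range(n):
        time.sleep(delay_s)
        yield i

# Data-bound case: slow loader, instant step
data_bound = data_wait_fraction(slow_loader(5, 0.02), lambda batch: None)
# Compute-bound case: instant loader, slow step
compute_bound = data_wait_fraction(range(5), lambda batch: time.sleep(0.02))
```

A wait fraction close to 0 means the pipeline keeps up; a large fraction points to the same remedy as low GPU-Util, namely more workers or a larger `prefetch_factor`.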
### Stage B: Disease Classifiers (2-3 days)
| Step | Epochs | LR | Details |
| ----------------- | ------ | ---- | ------------------------------------------------------- |
| All disease heads | 15 | 1e-3 | Backbone frozen, train all disease heads simultaneously |
| Rare-class boost | 5 | 5e-4 | Oversample classes with <80 images |
| End-to-end | 10 | 1e-5 | Unfreeze backbone, joint species + disease loss |
**Key design**: Disease heads are simple linear layers (768 → num_diseases_for_species). Since they share the backbone, inference is efficient — one forward pass through Swin-Tiny, then route the 768-dim feature vector to the correct head.
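The routing described above might be sketched as follows. This is an illustration under assumptions, not the project's implementation: the class name `HierarchicalClassifier` is hypothetical, the demo uses a tiny stand-in backbone so it runs without downloading weights, and using ground-truth species IDs for routing during training is an assumed choice. In the real setup the backbone would presumably come from `timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0)`, which returns pooled 768-dim features.

```python
import torch
from torch import nn

class HierarchicalClassifier(nn.Module):
    """Shared backbone, one species head, one linear disease head per species."""

    def __init__(self, backbone: nn.Module, diseases_per_species: list,
                 feat_dim: int = 768):
        super().__init__()
        self.backbone = backbone
        self.species_head = nn.Linear(feat_dim, len(diseases_per_species))
        self.disease_heads = nn.ModuleList(
            [nn.Linear(feat_dim, n) for n in diseases_per_species]
        )

    def forward(self, images: torch.Tensor, species_ids=None):
        feats = self.backbone(images)              # (B, feat_dim), single pass
        species_logits = self.species_head(feats)
        if species_ids is None:                    # inference: route by prediction
            species_ids = species_logits.argmax(dim=1)
        # Each sample goes through the head of its species; heads differ in
        # output size, so the result is a list rather than a single tensor.
        disease_logits = [
            self.disease_heads[int(s)](f) for f, s in zip(feats, species_ids)
        ]
        return species_logits, disease_logits

# Stand-in backbone for the demo only: flattens 32x32 images into 768-dim features
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 768))
model = HierarchicalClassifier(backbone, diseases_per_species=[30, 300, 45])
species_logits, disease_logits = model(
    torch.randn(4, 3, 32, 32), species_ids=torch.tensor([0, 1, 2, 1])
)
```

The per-sample loop keeps the sketch short; grouping samples by species before calling each head would be the obvious optimization for large batches.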
**Class balancing**: Use weighted sampler for disease heads — classes with <80 images get 3× sampling weight, classes with <50 images get 5×.
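One way to turn this rule into per-sample weights is sketched below, assuming the multipliers apply to every sample of a class and that classes with 80 or more images keep weight 1.0; `class_weight` and `sample_weights` are illustrative names. The resulting list could then be handed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`.

```python
from collections import Counter

def class_weight(count: int) -> float:
    """Sampling weight multiplier for a class with `count` training images."""
    if count < 50:
        return 5.0
    if count < 80:
        return 3.0
    return 1.0

def sample_weights(labels: list) -> list:
    """One weight per sample, looked up from the size of its class."""
    counts = Counter(labels)
    return [class_weight(counts[label]) for label in labels]

# Toy label list: a common class (200), a rare class (60), a very rare class (20)
labels = [0] * 200 + [1] * 60 + [2] * 20
weights = sample_weights(labels)
```
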
**Loss weighting**: `L_total = L_species + 0.7 * L_disease` — species loss has higher weight since disease prediction depends on correct species ID.
## Model Checkpointing
```
checkpoints/
├── species_only/ # Stage A checkpoints
│ ├── epoch=05-val_loss=0.42.ckpt
│ ├── epoch=15-val_loss=0.18.ckpt
│ └── epoch=25-best.ckpt
├── disease_heads/ # Stage B initial disease heads
│ ├── disease_heads_epoch=15.pt
│ └── disease_heads_final.pt
└── hierarchical_full/ # End-to-end
├── epoch=05.ckpt
└── epoch=10-best.ckpt
```
Save every checkpoint with species accuracy, macro F1, and per-disease F1 for the tail classes.
## Expected Training Time (RTX 3090 baseline)
| Stage | Epochs | Time/Epoch | Total |
| ----------------------- | ------ | ---------- | ---------- |
| Species head warmup | 5 | ~18 min | 1.5 hr |
| Species full fine-tune | 20 | ~45 min | 15 hr |
| Species fine-tune final | 5 | ~45 min | 3.75 hr |
| Disease heads | 15 | ~30 min | 7.5 hr |
| Disease rare-class | 5 | ~35 min | 3 hr |
| End-to-end              | 10     | ~50 min    | ~8.3 hr    |
| **Total** | **60** | | **~39 hr** |
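
The totals in the table follow from epochs × time per epoch, which can be recomputed as below. The dictionary simply restates the table rows.

```python
# (epochs, minutes per epoch) for each row of the table above
stages = {
    "Species head warmup": (5, 18),
    "Species full fine-tune": (20, 45),
    "Species fine-tune final": (5, 45),
    "Disease heads": (15, 30),
    "Disease rare-class": (5, 35),
    "End-to-end": (10, 50),
}

total_epochs = sum(epochs for epochs, _ in stages.values())
total_hours = sum(epochs * minutes for epochs, minutes in stages.values()) / 60

for name, (epochs, minutes) in stages.items():
    print(f"{name:<24} {epochs * minutes / 60:5.2f} hr")
print(f"total: {total_epochs} epochs, {total_hours:.1f} hr")
```
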
On **Strix Halo with NVMe + ROCm** the time/epoch should be **significantly faster** due to:
- 7,300 MB/s NVMe: data loads faster than GPU can consume it (zero I/O wait)
- Larger batch sizes (512 vs 256): fewer iterations per epoch
- ROCm 6.x has strong PyTorch performance on AMD GPUs
- 128GB RAM allows large prefetch buffer for seamless streaming
Expect **~20-28 hours** total on Strix Halo.
## Evaluation Metrics
| Metric | Target | Measurement |
| -------------------------------- | ------ | ------------------------------------------------------- |
| Species top-1 accuracy           | ≥95%   | Fraction of correct species predictions                 |
| Disease top-1 accuracy           | ≥88%   | Across all species-conditioned heads                    |
| Disease top-3 accuracy           | ≥94%   | Model correct if disease is in top 3                    |
| Macro F1 (rare diseases)         | ≥80%   | Unweighted mean of per-class F1 across tail classes     |
| Species→Disease cascade accuracy | ≥90%   | P(correct species) × P(correct disease \| correct species) |
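
The cascade metric can be measured directly as the fraction of samples where both predictions are right. The sketch below uses a hypothetical `cascade_accuracy` helper and toy labels. It also evaluates the product at the minimum targets, assuming the disease top-1 figure is conditional on a correct species prediction.

```python
def cascade_accuracy(species_true, species_pred, disease_true, disease_pred) -> float:
    """Fraction of samples with both species and disease predicted correctly."""
    hits = sum(
        st == sp and dt == dp
        for st, sp, dt, dp in zip(species_true, species_pred, disease_true, disease_pred)
    )
    return hits / len(species_true)

# Toy example: 4 of 5 species correct, and 3 of those 4 also get the disease right
species_true = [0, 0, 1, 1, 2]
species_pred = [0, 0, 1, 1, 0]
disease_true = [3, 4, 5, 6, 7]
disease_pred = [3, 4, 5, 0, 7]
acc = cascade_accuracy(species_true, species_pred, disease_true, disease_pred)

# Product form from the table, evaluated at the minimum targets
floor_at_targets = 0.95 * 0.88
# Conditional disease accuracy needed for a 0.90 cascade at 95% species accuracy
required_disease_acc = 0.90 / 0.95
```

Under that assumption, hitting exactly the minimum species and disease targets gives a cascade of about 0.84, so the ≥90% cascade target requires at least one of the two components to exceed its minimum (for example, about 94.7% conditional disease accuracy at 95% species accuracy).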
## Edge Cases & Gotchas
- **GPU memory on RTX 3090 (24GB)**: Swin-Tiny at 224px with batch size 256 + mixed precision should fit. If not, reduce to 128 or use gradient accumulation (accumulate 2 steps).
- **Strix Halo ROCm quirks**: `torch.compile()` may have issues on ROCm 6.2 — test without it first. Some `timm` model ops may need ROCm kernel fallbacks; test the forward pass before starting training.
- **Checkpoint compatibility**: Save in pure PyTorch format (`.pt`), not Lightning-specific, so they're loadable outside Lightning for export.
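- **Disease head memory**: 320 separate linear layers sounds large, but each is only 768×N_diseases (at the 300-class upper bound, ~230K params). Total disease head params: at most ~74M (vs 28M for backbone), and less in practice since heads range from 30 to 300 classes. This is fine because each image passes through only one disease head, so compute is dominated by the backbone.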
- **Loss divergence on rare diseases**: Monitor individual disease loss curves; if a tail class diverges, reduce its learning rate or use gradient clipping (max_norm=1.0).
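The disease-head parameter estimate in the list above can be bounded with a few lines of arithmetic. The sketch assumes each head is a single `768 → N` linear layer with a bias term and uses the 30-300 class range and ~320 species from the Objective section; `head_params` is an illustrative helper.

```python
FEAT_DIM = 768
NUM_SPECIES = 320

def head_params(num_diseases: int, feat_dim: int = FEAT_DIM) -> int:
    """Weights plus biases of one linear disease head."""
    return feat_dim * num_diseases + num_diseases

per_head_max = head_params(300)                # largest head
total_max = NUM_SPECIES * per_head_max         # every head at 300 classes
total_min = NUM_SPECIES * head_params(30)      # every head at 30 classes

print(f"largest head: {per_head_max:,} params")
print(f"all heads: between {total_min:,} and {total_max:,} params")
```
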
## Verification
- [ ] Species classifier ≥95% top-1 on val set
- [ ] Disease top-3 accuracy ≥94% on val set
- [ ] Confusion matrix shows no systematic species misclassifications
- [ ] Per-species disease classifiers all converge (no NaN losses)
- [ ] Tail classes (<80 images) have F1 ≥70%
- [ ] Model can be loaded from checkpoint and run inference in PyTorch
- [ ] No OOM errors during training