task to get this here done

This commit is contained in:
2026-06-12 13:20:33 -04:00
parent 6379860123
commit 34855eff55
7 changed files with 1307 additions and 85 deletions

View File

@@ -0,0 +1,309 @@
# Phase 2 — Hierarchical Model Training
**Blocked by**: Phase 1 (dataset reorganization)
**Blocks**: Phase 3 (export)
**Est. time**: 3-5 days on Strix Halo (ROCm), or 4-6 days on RTX 3090 (CUDA)
**Machine**: Strix Halo preferred (128GB unified memory + 8TB NVMe at 7,300 MB/s read — SSD is fast enough to stream entire dataset in ~62s)
## Objective
Train a hierarchical Swin-Tiny model with two classification heads:
1. **Species head** (~320 classes) — identifies the plant
2. **Disease heads** (one per species, 30-300 classes each) — identifies the disease
## Architecture
```
Input Image (224×224×3)
┌──────────────────────┐
│ Swin-Tiny Backbone │ ← pretrained on ImageNet-21K
│ (timm library) │ optional: fine-tune on iNaturalist
│ output: 768-dim │
└──────────┬───────────┘
┌──────┴──────┐
▼ ▼
┌─────────┐ ┌───────────────┐
│ Species │ │ Disease Head │
│ Head │ │ (routed by │
│ 320 cls │ │ species ID) │
└────┬────┘ └───────┬───────┘
│ │
▼ ▼
Species ID Disease ID
```
## Environment Setup
```bash
# On Strix Halo (ROCm)
python3 -m venv .hierarchical-venv
source .hierarchical-venv/bin/activate
# ROCm PyTorch (install from https://pytorch.org/get-started/locally/)
# ROCm 6.x + PyTorch 2.5+
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2
# Training libs
pip install pytorch-lightning timm transformers wandb
pip install albumentations opencv-python pillow
# Data loading (for SSD-optimized streaming)
pip install webdataset fsspec # optional benchmarks
```
**Alternative (RTX 3090 CUDA path)**:
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```
## Training Protocol
### Stage A: Species Classifier (2 days)
| Step | Epochs | LR | Batch Size | Details |
| -------------- | ------ | ----------- | ------------------------ | ----------------------------------------------- |
| Head warmup | 5 | 1e-3 | 512 (Strix) / 256 (3090) | Backbone frozen, train only species head |
| Full fine-tune | 20 | 1e-4 → 1e-6 | 512 / 256 | Unfreeze backbone, cosine LR schedule |
| Stage final | 5 | 5e-6 | 512 / 256 | Discriminative LR: backbone layers 0.1× head LR |
**Loss**: Focal Loss (γ=2.0, α=0.25) — handles any class imbalance in species distribution.
**Augmentation — Image Jittering, Degradation & Robustness**:
Real-world plant photos vary dramatically: different cameras, lighting conditions, angles, weather, focus quality, and compression artifacts. **Augmentation is not optional — it's essential for generalization.** The more varied your augmentation, the more robust your model will be when deployed.
**Three tiers of augmentation**, all applied on-the-fly (never pre-generated):
#### Tier 1 — Core Geometric & Photometric (applied to every image)
These simulate the most common real-world variations:
| Augmentation | Parameter | Simulates |
| ------------------------ | ----------------------------------------------------- | -------------------------------------------------- |
| RandomResizedCrop | scale=(0.6, 1.0), ratio=(0.75, 1.33) | Different shooting distances, zoom levels, framing |
| HorizontalFlip | p=0.5 | Different leaf orientations (left/right symmetry) |
| Rotate | limit=45°, p=0.5 | Off-angle photos, tilted camera |
| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 | Different lighting — sunny, overcast, shade, dusk |
| RandomBrightnessContrast | brightness_limit=0.2, contrast_limit=0.2, p=0.5 | Exposure variations from auto-exposure cameras |
#### Tier 2 — Degradation & Quality Simulation (applied to ~30% of images)
These make the model robust to poor-quality inputs that real users will upload:
| Augmentation | Parameter | Simulates |
| ---------------- | ----------------------------------------- | --------------------------------------------------------- |
| GaussianBlur | blur_limit=(3, 7), p=0.2 | Out-of-focus photos, motion blur |
| GaussianNoise | var_limit=(10, 50), p=0.15 | Low-light sensor noise, phone camera noise |
| ImageCompression | quality_lower=60, quality_upper=95, p=0.2 | JPEG artifacts from compression, social media re-encoding |
| RandomGrayscale | p=0.05 | Monochrome cameras, infrared plant imaging |
| RandomShadow | shadow_roi=(0, 1, 0, 1), p=0.15 | Leaves in shadow of other leaves/structures outdoors |
#### Tier 3 — Advanced Regularization (applied at batch level)
These are cutting-edge techniques that significantly improve generalization on fine-grained classification:
| Technique | Parameter | Effect |
| ------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **MixUp** | α=0.2 | Blends two random images + labels linearly. Forces the model to learn smooth decision boundaries. Proven +4.8% improvement on rare plant diseases. |
| **CutMix** | α=1.0 | Replaces a random patch of one image with another. Forces the model to focus on local lesion features rather than overall leaf shape. |
| **RandAugment** | N=2, M=9 | Auto-selected augmentation policy. 14 operations randomly chosen (shear, translate, rotate, contrast, etc.). N=2 ops per image, magnitude 9 (on 0-10 scale). |
| **Label Smoothing** | ε=0.1 | Prevents overconfidence on training classes, improves calibration on unseen diseases. |
**Implementation (albumentations)**:
```python
import albumentations as A
from albumentations.pytorch import ToTensorV2
import kornia.augmentation as K # for GPU-based MixUp/CutMix
# Core spatial + photometric (Tier 1+2)
train_transform = A.Compose([
A.RandomResizedCrop(224, 224, scale=(0.6, 1.0), ratio=(0.75, 1.33)),
A.HorizontalFlip(p=0.5),
A.Rotate(limit=45, p=0.5),
A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
# Degradation (Tier 2) — only some images get these
A.OneOf([
A.GaussianBlur(blur_limit=(3, 7), p=1.0),
A.GaussianNoise(var_limit=(10, 50), p=1.0),
A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
], p=0.3),
A.ImageCompression(quality_lower=60, quality_upper=95, p=0.2),
A.RandomShadow(shadow_roi=(0, 0.5, 0.5, 1), num_shadows_lower=1, num_shadows_upper=2, p=0.15),
A.RandomGrayscale(p=0.05),
# Normalize
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
ToTensorV2(),
])
# Validation — minimal, deterministic
val_transform = A.Compose([
A.Resize(224, 224),
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
ToTensorV2(),
])
```
**MixUp/CutMix (implemented in training loop, applied on GPU)**:
```python
def mixup_criterion(criterion, pred, y_a, y_b, lam):
return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
if use_mixup and np.random.random() < 0.5:
# MixUp: blend images and labels
lam = np.random.beta(0.2, 0.2)
indices = torch.randperm(images.size(0)).to(device)
mixed_images = lam * images + (1 - lam) * images[indices]
logits = model(mixed_images)
loss = mixup_criterion(criterion, logits, labels, labels[indices], lam)
else:
logits = model(images)
loss = criterion(logits, labels)
```
## SSD Data Loading Strategy
**Important**: The full dataset is ~70GB (after Phase 1 resizing), which exceeds the 128GB RAM when accounting for OS, GPU memory, workspace, and model weights. However, **your 8TB NVMe at 7,300 MB/s read changes everything.**
| Metric | Value |
| ----------------------------------------- | --------------- |
| Dataset size (after resize) | ~70 GB |
| NVMe read speed | 7,300 MB/s |
| Sequential read time (full dataset) | **~10 seconds** |
| Random read (1000 random files, 4KB each) | ~0.5ms seek |
**Key insight**: The GPU consumes batches slower than the SSD can deliver them. With `num_workers=8`, each worker reads ~35 images/s from random positions. At 7,300 MB/s sequential, the SSD can serve 150,000+ images/s. The bottleneck is **JPEG decode + augmentation**, not disk I/O.
**Recommended DataLoader configuration**:
```python
dataloader = DataLoader(
dataset,
batch_size=256,
shuffle=True,
# SSD-optimized settings:
num_workers=8, # 8 parallel readers — enough to saturate GPU
prefetch_factor=4, # Each worker prefetches 4 batches ahead
pin_memory=True, # Faster CPU→GPU transfer via DMA
persistent_workers=True, # Keep workers alive between epochs (avoid fork overhead)
drop_last=True, # Drop incomplete final batch for consistent batch norm
)
```
**Why not load everything into RAM?**
- 128GB total memory — after OS (~8GB), GPU reserved (~4-8GB ROCm), model weights + optimizer states (~4GB), augmentation workspace (~2GB), you have ~100GB free
- 70GB dataset would barely fit, but leaves no room for caching augmentation results or handling spikes
- Better approach: let the NVMe + DataLoader pipeline stream data. At 7,300 MB/s, reading a batch of 256 images (~50MB) takes ~7ms. Meanwhile, the GPU takes ~200ms to process that batch. **The disk is 30× faster than the GPU — you will never be I/O bound.**
**Optional: Use WebDataset for maximum throughput**
WebDataset shards the dataset into large tar files (~1GB each), which are sequentially read. This eliminates random seek overhead entirely — ideal when running at massive scale. For your setup it's optional (raw files on NVMe are already fast enough), but worth considering if you scale to multi-GPU:
```bash
pip install webdataset
```
```python
import webdataset as wds
urls = "data/organized/train/shard-{000000..000099}.tar"
dataset = wds.WebDataset(urls).shuffle(10000).decode("pil").to_tuple("jpg", "cls").map(augment)
```
**Profiling check**: During training, monitor GPU utilization:
- `nvidia-smi` / `rocm-smi` — GPU-Util should be >90%
- If <70%, GPU is waiting for data → increase `num_workers` or `prefetch_factor`
- If >95%, data pipeline is keeping up → optimal
### Stage B: Disease Classifiers (2-3 days)
| Step | Epochs | LR | Details |
| ----------------- | ------ | ---- | ------------------------------------------------------- |
| All disease heads | 15 | 1e-3 | Backbone frozen, train all disease heads simultaneously |
| Rare-class boost | 5 | 5e-4 | Oversample classes with <80 images |
| End-to-end | 10 | 1e-5 | Unfreeze backbone, joint species + disease loss |
**Key design**: Disease heads are simple linear layers (768 → num_diseases_for_species). Since they share the backbone, inference is efficient — one forward pass through Swin-Tiny, then route the 768-dim feature vector to the correct head.
**Class balancing**: Use weighted sampler for disease heads — classes with <80 images get 3× sampling weight, classes with <50 images get 5×.
**Loss weighting**: `L_total = L_species + 0.7 * L_disease` — species loss has higher weight since disease prediction depends on correct species ID.
## Model Checkpointing
```
checkpoints/
├── species_only/ # Stage A checkpoints
│ ├── epoch=05-val_loss=0.42.ckpt
│ ├── epoch=15-val_loss=0.18.ckpt
│ └── epoch=25-best.ckpt
├── disease_heads/ # Stage B initial disease heads
│ ├── disease_heads_epoch=15.pt
│ └── disease_heads_final.pt
└── hierarchical_full/ # End-to-end
├── epoch=05.ckpt
└── epoch=10-best.ckpt
```
Save every checkpoint with species accuracy, macro F1, and per-disease F1 for the tail classes.
## Expected Training Time (RTX 3090 baseline)
| Stage | Epochs | Time/Epoch | Total |
| ----------------------- | ------ | ---------- | ---------- |
| Species head warmup | 5 | ~18 min | 1.5 hr |
| Species full fine-tune | 20 | ~45 min | 15 hr |
| Species fine-tune final | 5 | ~45 min | 3.75 hr |
| Disease heads | 15 | ~30 min | 7.5 hr |
| Disease rare-class | 5 | ~35 min | 3 hr |
| End-to-end | 10 | ~50 min | 8.5 hr |
| **Total** | **60** | | **~39 hr** |
On **Strix Halo with NVMe + ROCm** the time/epoch should be **significantly faster** due to:
- 7,300 MB/s NVMe: data loads faster than GPU can consume it (zero I/O wait)
- Larger batch sizes (512 vs 256): fewer iterations per epoch
- ROCm 6.x has strong PyTorch performance on AMD GPUs
- 128GB RAM allows large prefetch buffer for seamless streaming
Expect **~20-28 hours** total on Strix Halo.
## Evaluation Metrics
| Metric | Target | Measurement |
| -------------------------------- | ------ | --------------------------------------- | ---------------- |
| Species top-1 accuracy | ≥95% | Fraction of correct species predictions |
| Disease top-1 accuracy | ≥88% | Across all species-conditioned heads |
| Disease top-3 accuracy | ≥94% | Model correct if disease is in top 3 |
| Macro F1 (rare diseases) | ≥80% | Weighted average across tail classes |
| Species→Disease cascade accuracy | ≥90% | P(correct species) × P(correct disease | correct species) |
## Edge Cases & Gotchas
- **GPU memory on RTX 3090 (24GB)**: Swin-Tiny at 224px with batch size 256 + mixed precision should fit. If not, reduce to 128 or use gradient accumulation (accumulate 2 steps).
- **Strix Halo ROCm quirks**: `torch.compile()` may have issues on ROCm 6.2 — test without it first. Some `timm` model ops may need ROCm kernel fallbacks; test the forward pass before starting training.
- **Checkpoint compatibility**: Save in pure PyTorch format (`.pt`), not Lightning-specific, so they're loadable outside Lightning for export.
- **Disease head memory**: 320 separate linear layers sounds large, but each is 768×N_diseases (avg ~300 → 230K params). Total disease head params: ~70M (vs 28M for backbone). This is fine — compute is dominated by the backbone.
- **Loss divergence on rare diseases**: Monitor individual disease loss curves; if a tail class diverges, reduce its learning rate or use gradient clipping (max_norm=1.0).
## Verification
- [ ] Species classifier ≥95% top-1 on val set
- [ ] Disease top-3 accuracy ≥94% on val set
- [ ] Confusion matrix shows no systematic species misclassifications
- [ ] Per-species disease classifiers all converge (no NaN losses)
- [ ] Tail classes (≤80 images) have F1 ≥70%
- [ ] Model can be loaded from checkpoint and run inference in PyTorch
- [ ] No OOM errors during training