Files
plant-disease-id/scripts/train_hierarchical.py

1038 lines
40 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Phase 2 — Hierarchical Model Training (Swin-Tiny).
Trains a hierarchical plant disease classifier with:
- Swin-Tiny backbone (timm, ImageNet-21K pretrained)
- Species head (~530 classes — identifies the plant)
- Disease heads (one per species, 30-300 classes each — identifies the disease)
Two-stage training protocol:
Stage A (Species): Train species classifier with frozen backbone, then full fine-tune.
Stage B (Disease): Train disease-specific heads, then end-to-end joint training.
Usage:
python3 scripts/train_hierarchical.py # Full training
python3 scripts/train_hierarchical.py --stage species # Stage A only
python3 scripts/train_hierarchical.py --stage disease # Stage B only (requires Stage A checkpoint)
python3 scripts/train_hierarchical.py --resume path.ckpt # Resume from checkpoint
Requirements:
pip install torch torchvision timm pytorch-lightning albumentations wandb
"""
import argparse
import json
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm
# Optional dependencies. pytorch-lightning is only needed for a multi-GPU path
# and wandb only for experiment tracking; each HAS_* flag records whether the
# corresponding import succeeded.
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers import WandbLogger
except ImportError:
    HAS_LIGHTNING = False
else:
    HAS_LIGHTNING = True

try:
    import wandb
except ImportError:
    HAS_WANDB = False
else:
    HAS_WANDB = True
# ─── Config ───────────────────────────────────────────────────────────────────
# Repository root: two directory levels above this script's resolved path.
BASE_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = BASE_DIR.joinpath("data", "organized")
TRAIN_DIR = DATA_DIR.joinpath("train")
VAL_DIR = DATA_DIR.joinpath("val")
CHECKPOINT_DIR = BASE_DIR.joinpath("checkpoints")
LOG_DIR = BASE_DIR.joinpath("logs")

# Per-channel RGB statistics of ImageNet, used to normalise every input image.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default hyper-parameters; main() copies this and applies CLI overrides.
DEFAULT_CONFIG = dict(
    # Model
    backbone="swin_tiny_patch4_window7_224",
    img_size=224,
    dropout=0.1,
    # Stage A (species): head warmup, full fine-tune, final fine-tune
    species_epochs_warmup=5,
    species_epochs_finetune=20,
    species_epochs_final=5,
    species_lr_warmup=1e-3,
    species_lr_finetune=1e-4,
    species_lr_final=5e-6,
    species_batch_size=512,
    # Stage B (disease): heads, rare-class boost, end-to-end
    disease_epochs_heads=15,
    disease_epochs_rare=5,
    disease_epochs_e2e=10,
    disease_lr_heads=1e-3,
    disease_lr_rare=5e-4,
    disease_lr_e2e=1e-5,
    disease_loss_weight=0.7,
    # Data loading
    num_workers=8,
    prefetch_factor=4,
    val_split_pct=0.15,
    # Augmentation / regularisation
    mixup_alpha=0.2,
    cutmix_alpha=1.0,
    label_smoothing=0.1,
    # Optimisation
    gradient_clip_val=1.0,
    weight_decay=0.05,
    warmup_epochs=3,
    # System
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    precision="16-mixed",  # mixed precision training
    accumulate_grad_batches=1,
)
# ─── Dataset ──────────────────────────────────────────────────────────────────
class HierarchicalDataset(Dataset):
    """
    Reads organized dataset from data/organized/{train,val}/{species}/{disease}/.

    Loads ``species_index.json`` and ``dataset_stats.json`` from *data_dir*,
    then walks ``data_dir / split`` to build the sample index.

    * Species indices follow the sorted keys of ``species_index.json``, so they
      are identical across splits. Species directories that are not in the
      index are skipped with a warning.
    * Disease indices are assigned per species in sorted directory order of
      *this split only*: a split that lacks a disease directory numbers the
      remaining diseases differently, so callers combining splits must align
      ``disease_to_idx`` themselves.
    * ``samples`` holds ``(image_path, species_idx, disease_name, species)``.
    """

    # Lower-cased file extensions recognised as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, data_dir: Path, split: str, transform=None, config=None):
        self.data_dir = data_dir / split
        self.split = split
        self.transform = transform
        self.config = config or DEFAULT_CONFIG
        # Load metadata
        self.species_index = self._load_json(data_dir / "species_index.json")
        self.dataset_stats = self._load_json(data_dir / "dataset_stats.json")
        # Build mappings
        self.species_list = sorted(self.species_index.keys())
        self.species_to_idx = {s: i for i, s in enumerate(self.species_list)}
        self.num_species = len(self.species_list)
        # Per-species disease mappings, filled by _build_index
        self.disease_to_idx = {}  # species → {disease_name → index}
        self.disease_counts = {}  # species → {disease_name → image count}
        self.samples = []  # (image_path, species_idx, disease_name, species)
        self._build_index()

    def _load_json(self, path):
        """Parse the JSON file at *path* (JSON is UTF-8 by specification)."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _visible_subdirs(path):
        """Return the sorted names of non-hidden subdirectories of *path*."""
        return sorted(
            name for name in os.listdir(path)
            if not name.startswith(".") and os.path.isdir(os.path.join(path, name))
        )

    def _build_index(self):
        """Walk the split directory and fill samples, disease_to_idx and disease_counts."""
        for species in self._visible_subdirs(self.data_dir):
            if species not in self.species_to_idx:
                print(f" [WARN] Species '{species}' not in species_index, skipping")
                continue
            species_idx = self.species_to_idx[species]
            for disease in self._visible_subdirs(self.data_dir / species):
                # Maps are created lazily so a species directory without any
                # disease subdirectory gets no (empty) entry.
                disease_map = self.disease_to_idx.setdefault(species, {})
                # Each disease directory is visited exactly once, so the next
                # free index is simply the current size of the map.
                disease_map[disease] = len(disease_map)
                disease_dir = self.data_dir / species / disease
                image_files = sorted(
                    name for name in os.listdir(disease_dir)
                    if name.lower().endswith(self.IMAGE_EXTENSIONS)
                    and os.path.isfile(disease_dir / name)
                )
                count_map = self.disease_counts.setdefault(species, {})
                count_map[disease] = count_map.get(disease, 0) + len(image_files)
                self.samples.extend(
                    (disease_dir / name, species_idx, disease, species)
                    for name in image_files
                )
        print(f" [{self.split}] {len(self.samples)} images, "
              f"{self.num_species} species")

    def get_species_class_distribution(self):
        """Return list of sample counts per species (indexed by species_idx)."""
        counts = [0] * self.num_species
        for _, sp_idx, _, _ in self.samples:
            counts[sp_idx] += 1
        return counts

    def get_disease_class_weights(self, species):
        """Return per-sample weights, keyed by disease, for disease class balancing.

        Every image of disease ``d`` gets weight proportional to
        ``1 / sqrt(count_d)``, which gives rare diseases a moderate boost (full
        inverse frequency would be ``1 / count_d``). The weights are normalised
        so that ``sum_d count_d * w_d`` equals the species' image count. Weights
        from different species can then share one sampler without changing how
        often each species is drawn.

        Returns:
            dict disease_name → float weight (0.0 for diseases with no images),
            or None if *species* has no disease directories in this split.
        """
        counts = self.disease_counts.get(species, {})
        if not counts:
            return None
        total = sum(counts.values())
        norm = sum(np.sqrt(count) for count in counts.values())
        weights = {}
        for disease, count in counts.items():
            # count == 0 implies no samples need this weight; avoid dividing by 0.
            weights[disease] = float(total / (np.sqrt(count) * norm)) if count > 0 else 0.0
        return weights

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one image and return it with its species and disease labels."""
        img_path, species_idx, disease_name, species = self.samples[idx]
        # Context manager releases the file handle as soon as pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image_tensor = self.transform(image=np.array(image))["image"]
        else:
            # Default: resize to the configured size + ImageNet normalisation
            import torchvision.transforms.functional as TF
            size = self.config.get("img_size", 224)
            image = image.resize((size, size), Image.BILINEAR)
            image_tensor = TF.to_tensor(image)
            image_tensor = TF.normalize(image_tensor, IMAGENET_MEAN, IMAGENET_STD)
        # Looked up at access time so a caller may swap in an aligned mapping;
        # diseases unknown to the mapping fall back to index 0.
        disease_map = self.disease_to_idx[species]
        return {
            "image": image_tensor,
            "species_idx": torch.tensor(species_idx, dtype=torch.long),
            "disease_idx": torch.tensor(disease_map.get(disease_name, 0), dtype=torch.long),
            "disease_name": disease_name,
            "species": species,
            "num_diseases": len(disease_map),
        }
# ─── Augmentations ────────────────────────────────────────────────────────────
def get_train_transform(config=None):
    """Build the training augmentation pipeline (Tier 1 + 2).

    Tier 1 applies geometric and colour perturbations. Tier 2 simulates capture
    degradation: blur/noise, JPEG artefacts, shadows and greyscale. The
    result is ImageNet-normalised and converted to a tensor.
    """
    side = (config or DEFAULT_CONFIG).get("img_size", 224)

    # Tier 1 — core geometric & photometric
    tier1 = [
        A.RandomResizedCrop(size=(side, side), scale=(0.6, 1.0), ratio=(0.75, 1.33)),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=45, p=0.5),
        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
    ]

    # Tier 2 — degradation & quality simulation (one of blur / noise types)
    degradations = [
        A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        A.GaussNoise(std_range=(0.01, 0.05), p=1.0),
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
    ]
    tier2 = [
        A.OneOf(degradations, p=0.3),
        A.ImageCompression(quality_range=(60, 95), p=0.2),
        A.RandomShadow(p=0.15),
        A.ToGray(p=0.05),
    ]

    finalize = [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]
    return A.Compose(tier1 + tier2 + finalize)
def get_val_transform(config=None):
    """Build the deterministic validation transform: resize, normalise, tensorise."""
    side = (config or DEFAULT_CONFIG).get("img_size", 224)
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# ─── Model ────────────────────────────────────────────────────────────────────
class HierarchicalSwin(nn.Module):
    """
    Hierarchical Swin-Tiny model with species and disease heads.

    Architecture:
        Input (224×224×3)
        → Swin-Tiny backbone (timm, 768-dim features)
        → Species head (Linear: 768 → num_species)
        → Disease heads (one Linear: 768 → num_diseases_per_species)

    Inference: one forward pass through the backbone, then the pooled feature
    is routed to the disease head of the given (training) or predicted
    (inference) species.

    Species with a single disease get no head. In ``forward_disease`` their
    logit row is 0 at index 0 and masked elsewhere, so index 0 is the only
    class argmax / cross-entropy can choose.
    """

    def __init__(self, config, species_index, disease_to_idx):
        super().__init__()
        self.config = config
        self.species_list = sorted(species_index.keys())
        self.num_species = len(self.species_list)
        self.disease_to_idx = disease_to_idx
        label_maps = disease_to_idx or {}

        # Backbone (timm imported lazily; pooled features only, no classifier)
        import timm
        self.backbone = timm.create_model(
            config["backbone"],
            pretrained=True,
            num_classes=0,  # No classification head — we use features only
            global_pool="avg",
        )
        self.feature_dim = self.backbone.num_features  # 768 for Swin-Tiny

        # Species head
        self.species_head = nn.Sequential(
            nn.Dropout(config.get("dropout", 0.1)),
            nn.Linear(self.feature_dim, self.num_species),
        )

        # Disease heads — one per species with more than one disease. The head
        # must cover every label the dataset can emit (disease_to_idx), not
        # just the diseases listed in species_index, or labels go out of range.
        self.disease_heads = nn.ModuleDict()
        for species in self.species_list:
            num_diseases = max(
                len(species_index[species]), len(label_maps.get(species, {}))
            )
            if num_diseases > 1:
                self.disease_heads[species] = nn.Linear(self.feature_dim, num_diseases)

        # Number of disease labels per species (indices are contiguous from 0).
        self.num_diseases_per_species = {
            species: len(mapping) for species, mapping in label_maps.items()
        }

        print(f" Model: Swin-Tiny backbone ({self.feature_dim}-dim features)")
        print(f" Species head: {self.num_species} classes")
        print(f" Disease heads: {len(self.disease_heads)} species-specific heads")

    def forward(self, x):
        """Return ``(features [B, feature_dim], species_logits [B, num_species])``."""
        features = self.backbone(x)
        species_logits = self.species_head(features)
        return features, species_logits

    def forward_disease(self, features, species_labels=None, species_logits=None):
        """
        Disease forward pass using species-specific heads.

        Args:
            features: [B, feature_dim] backbone features.
            species_labels: [B] ground-truth species indices (training routing).
            species_logits: [B, num_species] species scores; argmax is used for
                routing when ``species_labels`` is not given (inference).

        Returns:
            (disease_logits, species_names): ``disease_logits`` is [B, W], where
            W is the widest head used in the batch (at least 1). Columns beyond a
            row's own head are filled with the dtype's most negative value so
            they get ~zero probability and can never win the argmax.
            ``species_names`` lists the routed species per row.

        Raises:
            ValueError: if neither ``species_labels`` nor ``species_logits`` is given.
        """
        if species_labels is not None:
            species_ids = species_labels
        elif species_logits is not None:
            species_ids = species_logits.argmax(dim=1)
        else:
            raise ValueError("Need either species_labels (train) or species_logits (inference)")

        # One host transfer for the whole batch instead of one per sample.
        species_names = [self.species_list[i] for i in species_ids.tolist()]

        # Group rows by species so every head runs once on its sub-batch.
        rows_by_species = defaultdict(list)
        for row, name in enumerate(species_names):
            rows_by_species[name].append(row)

        head_outputs = []
        for name, rows in rows_by_species.items():
            if name in self.disease_heads:
                row_idx = torch.tensor(rows, device=features.device)
                head_outputs.append((row_idx, self.disease_heads[name](features[row_idx])))

        width = max((out.shape[1] for _, out in head_outputs), default=1)
        dtype = head_outputs[0][1].dtype if head_outputs else features.dtype
        disease_logits = features.new_full(
            (features.shape[0], width), torch.finfo(dtype).min, dtype=dtype
        )
        # Head-less (single-disease) species: class 0 is the only valid class.
        disease_logits[:, 0] = 0.0
        for row_idx, out in head_outputs:
            disease_logits[row_idx, :out.shape[1]] = out.to(dtype)
        return disease_logits, species_names

    def get_trainable_params(self, stage="species_head"):
        """Return optimizer parameter groups (with per-group LR) for a training stage.

        Stages: ``species_head``, ``species_full`` (backbone at 0.1× head LR),
        ``disease_heads``, ``disease_rare``, ``e2e`` (backbone at 0.1×).

        Raises:
            ValueError: for an unknown stage name.
        """
        cfg = self.config
        if stage == "species_head":
            # Only train species head, freeze backbone
            return [
                {"params": list(self.species_head.parameters()), "lr": cfg["species_lr_warmup"]},
            ]
        if stage == "species_full":
            # Train all species-related params with discriminative LR
            lr = cfg["species_lr_finetune"]
            return [
                {"params": list(self.backbone.parameters()), "lr": lr * 0.1},
                {"params": list(self.species_head.parameters()), "lr": lr},
            ]
        if stage == "disease_heads":
            return [
                {"params": list(self.disease_heads.parameters()), "lr": cfg["disease_lr_heads"]},
            ]
        if stage == "disease_rare":
            return [
                {"params": list(self.disease_heads.parameters()), "lr": cfg["disease_lr_rare"]},
            ]
        if stage == "e2e":
            # End-to-end: unfreeze everything
            lr = cfg["disease_lr_e2e"]
            return [
                {"params": list(self.backbone.parameters()), "lr": lr * 0.1},
                {"params": list(self.species_head.parameters()), "lr": lr},
                {"params": list(self.disease_heads.parameters()), "lr": lr},
            ]
        raise ValueError(f"Unknown training stage: {stage!r}")
# ─── MixUp / CutMix ──────────────────────────────────────────────────────────
def mixup_data(x, y_species, y_disease, alpha=0.2):
    """Apply MixUp augmentation to a batch.

    Always returns the same 6-tuple layout:
    ``(mixed_x, y_species_a, y_species_b, y_disease_a, y_disease_b, lam)``.

    With ``alpha <= 0`` MixUp is disabled. The batch is returned unchanged, both
    label slots of each pair hold the original labels, and ``lam`` is 1.0, so
    ``mixup_criterion`` reduces to the plain loss. Previously this path returned
    ``(x, y_species, y_disease, None, None, None)``, which put disease labels in
    the species-b slot.
    """
    if alpha <= 0:
        return x, y_species, y_species, y_disease, y_disease, 1.0
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y_species, y_species[index], y_disease, y_disease[index], lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both MixUp label sets: lam·L(y_a) + (1 − lam)·L(y_b)."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
# ─── Training Loop ────────────────────────────────────────────────────────────
class HierarchicalTrainer:
    """Custom training loop for hierarchical model (no Lightning dependency).

    Expected call order: ``prepare_data()`` → ``create_model()`` →
    ``train_species_stage()`` and/or ``train_disease_stage()``. Epoch metrics
    are written to ``LOG_DIR/training_log.json``. Checkpoints go to
    ``CHECKPOINT_DIR/<stage>/<stage>_epoch=NN.pt``.
    """

    def __init__(self, config=None):
        self.config = config or DEFAULT_CONFIG.copy()
        self.device = torch.device(self.config["device"])
        self.seed_everything()
        # Setup logging
        self.log_file = LOG_DIR / "training_log.json"
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        self.metrics_history = []

    def seed_everything(self):
        """Seed Python, NumPy and torch (CPU and every CUDA device) with config["seed"]."""
        seed = self.config["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # ─── Data ──────────────────────────────────────────────────────────────
    def _loader_kwargs(self, include_prefetch=True):
        """Return DataLoader worker options derived from the config.

        ``persistent_workers`` and ``prefetch_factor`` are only accepted when
        worker processes are used (DataLoader raises ValueError otherwise), so
        they are added only for ``num_workers > 0``. Memory pinning only helps
        host→GPU copies, so it is enabled only for CUDA devices.
        """
        num_workers = self.config["num_workers"]
        kwargs = {"num_workers": num_workers, "pin_memory": self.device.type == "cuda"}
        if num_workers > 0:
            kwargs["persistent_workers"] = True
            if include_prefetch:
                kwargs["prefetch_factor"] = self.config["prefetch_factor"]
        return kwargs

    def _align_val_disease_labels(self):
        """Make the validation split use the training split's disease indices.

        Each HierarchicalDataset numbers diseases from its own split's
        directories. A disease missing from val (likely for rare classes) would
        shift every later index for that species, so val labels would no longer
        match the heads trained on train labels.
        """
        train_map = self.train_dataset.disease_to_idx
        aligned = {}
        for species, val_diseases in self.val_dataset.disease_to_idx.items():
            train_diseases = train_map.get(species)
            if train_diseases is None:
                print(f" [WARN] Val species '{species}' has no training images; "
                      "keeping its val-only disease indices")
                aligned[species] = val_diseases
                continue
            unseen = sorted(set(val_diseases) - set(train_diseases))
            if unseen:
                print(f" [WARN] Val diseases of '{species}' absent from train "
                      f"(will be labelled index 0): {unseen}")
            aligned[species] = train_diseases
        self.val_dataset.disease_to_idx = aligned

    def prepare_data(self):
        """Load datasets and create dataloaders."""
        print("\n" + "=" * 60)
        print("Preparing data...")
        print("=" * 60)
        self.train_transform = get_train_transform(self.config)
        self.val_transform = get_val_transform(self.config)
        self.train_dataset = HierarchicalDataset(
            DATA_DIR, "train", transform=self.train_transform, config=self.config
        )
        self.val_dataset = HierarchicalDataset(
            DATA_DIR, "val", transform=self.val_transform, config=self.config
        )
        self._align_val_disease_labels()
        self.species_index = self.train_dataset.species_index
        self.disease_to_idx = self.train_dataset.disease_to_idx
        # Standard shuffle; the rare-class sub-stage builds its own weighted loader.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["species_batch_size"],
            shuffle=True,
            drop_last=True,
            **self._loader_kwargs(),
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config["species_batch_size"] * 2,
            shuffle=False,
            **self._loader_kwargs(include_prefetch=False),
        )
        print(f" Train: {len(self.train_dataset)} images, {len(self.train_loader)} batches")
        print(f" Val: {len(self.val_dataset)} images, {len(self.val_loader)} batches")

    def create_model(self):
        """Create the hierarchical model and the species / disease losses."""
        print("\n" + "=" * 60)
        print("Building model...")
        print("=" * 60)
        self.model = HierarchicalSwin(
            self.config,
            self.species_index,
            self.disease_to_idx,
        ).to(self.device)
        self.species_criterion = nn.CrossEntropyLoss(
            label_smoothing=self.config["label_smoothing"]
        )
        self.disease_criterion = nn.CrossEntropyLoss()

    # ─── Checkpointing & logging ───────────────────────────────────────────
    def save_checkpoint(self, path, stage, epoch, metrics):
        """Save model/optimizer/scheduler state plus label mappings to *path*."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        optimizer = getattr(self, "optimizer", None)
        scheduler = getattr(self, "scheduler", None)
        checkpoint = {
            "stage": stage,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
            "metrics": metrics,
            "config": self.config,
            "species_index": self.species_index,
            "disease_to_idx": self.disease_to_idx,
        }
        torch.save(checkpoint, path)
        print(f" ✓ Checkpoint saved: {path}")

    def load_checkpoint(self, path):
        """Load model weights and label mappings from a checkpoint; return the dict.

        ``weights_only=False`` unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.species_index = checkpoint.get("species_index", self.species_index)
        self.disease_to_idx = checkpoint.get("disease_to_idx", self.disease_to_idx)
        print(f" ✓ Loaded checkpoint: {path} (stage={checkpoint.get('stage')}, "
              f"epoch={checkpoint.get('epoch')})")
        return checkpoint

    def log_metrics(self, stage, epoch, metrics):
        """Record metrics (adds 'stage'/'epoch' keys in place), write the log file, print, and send to wandb."""
        metrics["stage"] = stage
        metrics["epoch"] = epoch
        self.metrics_history.append(metrics)
        # Write to log file
        with open(self.log_file, "w") as f:
            json.dump(self.metrics_history, f, indent=2)
        # Print
        print(f" [{stage}] Epoch {epoch}: "
              f"loss={metrics.get('train_loss', 0):.4f}, "
              f"val_loss={metrics.get('val_loss', 0):.4f}, "
              f"species_acc={metrics.get('val_species_acc', 0):.2%}")
        # Wandb logging. No explicit step: epochs restart at 0 in every
        # sub-stage and wandb discards logs whose step is below one already
        # recorded. The epoch is still logged as f"{stage}/epoch".
        if HAS_WANDB and wandb.run is not None:
            wandb.log({f"{stage}/{k}": v for k, v in metrics.items()})

    # ─── Stage A: Species Classifier ───────────────────────────────────────
    def train_species_stage(self, resume_from=None):
        """
        Stage A: Train species classifier.

        Sub-stages (epochs are numbered globally across A1–A3):
            1. Head warmup (backbone frozen)
            2. Full fine-tune (unfrozen backbone, warmup + cosine LR)
            3. Final fine-tune (discriminative LR at species_lr_final)
        """
        print("\n" + "=" * 60)
        print("Stage A — Species Classifier Training")
        print("=" * 60)
        start_epoch = 0
        if resume_from:
            checkpoint = self.load_checkpoint(resume_from)
            start_epoch = checkpoint.get("epoch", 0) + 1

        a2_start = self.config["species_epochs_warmup"]
        a2_end = a2_start + self.config["species_epochs_finetune"]
        a3_start = a2_end
        a3_end = a3_start + self.config["species_epochs_final"]

        # ── Sub-stage A1: Head warmup ──
        if start_epoch <= a2_start:
            print("\n── Sub-stage A1: Head warmup (backbone frozen) ──")
            self._run_training_stage(
                "species_warmup",
                self.model.get_trainable_params("species_head"),
                self.config["species_lr_warmup"],
                self.config["species_epochs_warmup"],
                start_epoch,
                freeze_backbone=True,
            )

        # ── Sub-stage A2: Full fine-tune ──
        if start_epoch <= a2_end:
            print("\n── Sub-stage A2: Full fine-tune (cosine LR) ──")
            self._run_training_stage(
                "species_finetune",
                self.model.get_trainable_params("species_full"),
                self.config["species_lr_finetune"],
                a2_end,
                max(start_epoch, a2_start),
                freeze_backbone=False,
                cosine_schedule=True,
                warmup_epochs=self.config["warmup_epochs"],
            )

        # ── Sub-stage A3: Final fine-tune ──
        if start_epoch <= a3_end:
            print("\n── Sub-stage A3: Final fine-tune (discriminative LR) ──")
            params = self.model.get_trainable_params("species_full")
            # "species_full" groups are built around species_lr_finetune; rescale
            # them so A3 really runs at species_lr_final (keeping the 10×
            # backbone/head ratio) instead of silently repeating A2's rates.
            lr_scale = self.config["species_lr_final"] / self.config["species_lr_finetune"]
            for group in params:
                group["lr"] *= lr_scale
            self._run_training_stage(
                "species_final",
                params,
                self.config["species_lr_final"],
                a3_end,
                max(start_epoch, a3_start),
                freeze_backbone=False,
            )
        print("\n ✓ Stage A complete!")

    # ─── Stage B: Disease Classifiers ──────────────────────────────────────
    def train_disease_stage(self, resume_from=None):
        """
        Stage B: Train disease-specific heads.

        Sub-stages:
            1. All disease heads (backbone frozen)
            2. Rare-class boost (oversampled rare classes)
            3. End-to-end (unfrozen backbone, joint loss)
        """
        print("\n" + "=" * 60)
        print("Stage B — Disease Classifier Training")
        print("=" * 60)
        if resume_from:
            self.load_checkpoint(resume_from)

        # ── Sub-stage B1: Disease heads ──
        print("\n── Sub-stage B1: Disease heads (backbone frozen) ──")
        self._run_training_stage(
            "disease_heads",
            self.model.get_trainable_params("disease_heads"),
            self.config["disease_lr_heads"],
            self.config["disease_epochs_heads"],
            freeze_backbone=True,
            use_disease_loss=True,
        )

        # ── Sub-stage B2: Rare-class boost ──
        print("\n── Sub-stage B2: Rare-class boost ──")
        self._run_training_stage(
            "disease_rare",
            self.model.get_trainable_params("disease_rare"),
            self.config["disease_lr_rare"],
            self.config["disease_epochs_rare"],
            freeze_backbone=True,
            use_disease_loss=True,
            oversample_rare=True,
        )

        # ── Sub-stage B3: End-to-end ──
        print("\n── Sub-stage B3: End-to-end joint training ──")
        self._run_training_stage(
            "disease_e2e",
            self.model.get_trainable_params("e2e"),
            self.config["disease_lr_e2e"],
            self.config["disease_epochs_e2e"],
            freeze_backbone=False,
            use_disease_loss=True,
        )
        print("\n ✓ Stage B complete!")

    # ─── Core Training Loop ────────────────────────────────────────────────
    def _make_stage_loader(self, oversample_rare):
        """Return the default train loader, or a class-weighted one when *oversample_rare*."""
        if not oversample_rare:
            return self.train_loader
        weights = self._compute_disease_sample_weights()
        if weights is None:
            return self.train_loader
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["species_batch_size"],
            sampler=sampler,
            drop_last=True,
            **self._loader_kwargs(),
        )
        print(" Using weighted sampler for rare-class boost")
        return loader

    def _make_scheduler(self, cosine_schedule, stage_epochs, steps_per_epoch, warmup_epochs):
        """Build the LR scheduler for one sub-stage (uses ``self.optimizer``).

        Cosine mode uses a linear warmup (from 0.1× LR) followed by cosine decay.
        Both are measured in optimizer steps over exactly this stage's
        ``stage_epochs`` epochs and must be stepped once per batch. Otherwise
        ReduceLROnPlateau is returned, to be stepped once per epoch on val loss.
        """
        if not cosine_schedule:
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.5, patience=3
            )
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

        total_steps = max(stage_epochs * steps_per_epoch, 1)
        # Leave at least one step for the cosine phase.
        warmup_steps = min(warmup_epochs * steps_per_epoch, total_steps - 1)
        cosine = CosineAnnealingLR(self.optimizer, T_max=max(total_steps - warmup_steps, 1))
        if warmup_steps <= 0:
            return cosine
        warmup = LinearLR(self.optimizer, start_factor=0.1, total_iters=warmup_steps)
        return SequentialLR(
            self.optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps]
        )

    def _run_training_stage(self, stage_name, param_groups, base_lr, num_epochs,
                            start_epoch=0, freeze_backbone=True,
                            cosine_schedule=False, warmup_epochs=0,
                            use_disease_loss=False, oversample_rare=False):
        """Run one training sub-stage over global epochs ``[start_epoch, num_epochs)``.

        Args:
            stage_name: Label for logs and the checkpoint sub-directory.
            param_groups: Optimizer groups; each group's ``lr`` is what AdamW uses.
            base_lr: Nominal stage LR. Not applied here (kept for the call
                signature); set LRs in ``param_groups``.
            num_epochs: Exclusive global end epoch (not an epoch count).
            start_epoch: Global epoch to start at.
            freeze_backbone: Set backbone ``requires_grad`` to ``not freeze_backbone``.
            cosine_schedule: Warmup + cosine LR stepped per batch; otherwise
                ReduceLROnPlateau on val loss stepped per epoch.
            warmup_epochs: Warmup length (epochs) for the cosine schedule.
            use_disease_loss: Add the weighted disease-head loss.
            oversample_rare: Sample batches with per-disease class weights.
        """
        # Freeze/unfreeze backbone
        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze_backbone

        self.optimizer = torch.optim.AdamW(
            param_groups,
            weight_decay=self.config["weight_decay"],
        )
        train_loader = self._make_stage_loader(oversample_rare)
        self.scheduler = self._make_scheduler(
            cosine_schedule, num_epochs - start_epoch, len(train_loader), warmup_epochs
        )
        plateau = isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
        batch_scheduler = None if plateau else self.scheduler

        scaler = torch.amp.GradScaler(enabled=(self.config["precision"] == "16-mixed"))
        for epoch in range(start_epoch, num_epochs):
            train_loss = self._train_epoch(
                train_loader, use_disease_loss, scaler, stage_name,
                scheduler=batch_scheduler,
            )
            val_metrics = self._validate(use_disease_loss)
            if plateau:
                self.scheduler.step(val_metrics["val_loss"])
            metrics = {"train_loss": train_loss, **val_metrics}
            self.log_metrics(stage_name, epoch, metrics)
            # Save checkpoint every 5 epochs and at the end of the stage
            if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
                ckpt_name = f"{stage_name}_epoch={epoch:02d}.pt"
                self.save_checkpoint(
                    CHECKPOINT_DIR / stage_name / ckpt_name,
                    stage_name, epoch, metrics
                )

    def _train_epoch(self, loader, use_disease_loss, scaler, stage_name, scheduler=None):
        """Train for one epoch and return the mean batch loss (0.0 for an empty loader).

        If *scheduler* is given it is stepped after every optimizer step.
        """
        self.model.train()
        amp_enabled = self.config["precision"] == "16-mixed"
        # Clip only what this optimizer updates: parameters outside it (e.g. the
        # species head during B1/B2) must not distort the clipping norm.
        trainable = [p for group in self.optimizer.param_groups for p in group["params"]]
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(loader, desc=f" {stage_name}", leave=False)
        for batch in pbar:
            images = batch["image"].to(self.device)
            species_labels = batch["species_idx"].to(self.device)
            with torch.amp.autocast(device_type=self.device.type, enabled=amp_enabled):
                features, species_logits = self.model(images)
                loss = self.species_criterion(species_logits, species_labels)
                if use_disease_loss:
                    disease_logits, _ = self.model.forward_disease(
                        features, species_labels=species_labels
                    )
                    disease_labels = batch["disease_idx"].to(self.device)
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    loss = loss + self.config["disease_loss_weight"] * loss_disease
            # Clear grads on the whole model, not just the optimizer's params;
            # otherwise params outside the optimizer accumulate grads forever.
            self.model.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(trainable, self.config["gradient_clip_val"])
            scaler.step(self.optimizer)
            scaler.update()
            if scheduler is not None:
                scheduler.step()
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})
        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def _validate(self, use_disease_loss=False):
        """Return val loss, species accuracy and (optionally) disease accuracy."""
        self.model.eval()
        disease_weight = self.config["disease_loss_weight"]
        total_loss = 0.0
        correct_species = 0
        correct_disease = 0
        total_samples = 0
        for batch in tqdm(self.val_loader, desc=" val", leave=False):
            images = batch["image"].to(self.device)
            species_labels = batch["species_idx"].to(self.device)
            features, species_logits = self.model(images)
            correct_species += (species_logits.argmax(dim=1) == species_labels).sum().item()
            total_loss += self.species_criterion(species_logits, species_labels).item()
            if use_disease_loss:
                disease_logits, _ = self.model.forward_disease(
                    features, species_labels=species_labels
                )
                disease_labels = batch["disease_idx"].to(self.device)
                num_classes = disease_logits.shape[1]
                if num_classes > 1:
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    total_loss += disease_weight * loss_disease.item()
                # Count accuracy for every batch. A width-1 batch (only
                # single-disease species) is trivially correct, but was
                # previously left out of the numerator while still counted
                # in the denominator.
                clipped = disease_labels.clamp(0, num_classes - 1)
                correct_disease += (disease_logits.argmax(dim=1) == clipped).sum().item()
            total_samples += images.size(0)
        metrics = {
            "val_loss": total_loss / max(len(self.val_loader), 1),
            "val_species_acc": correct_species / max(total_samples, 1),
        }
        if use_disease_loss:
            metrics["val_disease_acc"] = correct_disease / max(total_samples, 1)
        return metrics

    def _compute_disease_sample_weights(self):
        """Return one sampling weight per training sample, in ``samples`` order.

        Weights come from ``get_disease_class_weights`` (queried once per
        species). Samples whose species or disease has no weight get 1.0.
        Returns None when no species has any weights.
        """
        dataset = self.train_dataset
        per_species = {
            species: dataset.get_disease_class_weights(species)
            for species in dataset.species_list
        }
        if all(w is None for w in per_species.values()):
            return None
        # Single pass in sample order so weights[i] belongs to samples[i]
        # (WeightedRandomSampler draws dataset indices from these positions).
        weights = [
            (per_species.get(species) or {}).get(disease_name, 1.0)
            for _, _, disease_name, species in dataset.samples
        ]
        return torch.tensor(weights, dtype=torch.float) if weights else None
# ─── Main ────────────────────────────────────────────────────────────────────
def find_latest_checkpoint(stage_name, checkpoint_dir=None):
    """Return the highest-epoch checkpoint for *stage_name*, or None if there is none.

    Looks for ``<checkpoint_dir>/<stage>/<stage>_epoch=NN.pt``, the naming the
    training loop uses when saving. *checkpoint_dir* defaults to CHECKPOINT_DIR.
    """
    root = Path(checkpoint_dir) if checkpoint_dir is not None else CHECKPOINT_DIR
    candidates = list((root / stage_name).glob(f"{stage_name}_epoch=*.pt"))

    def epoch_of(path):
        # Parse NN numerically so epoch 100 sorts after epoch 99.
        try:
            return int(path.stem.rsplit("=", 1)[1])
        except (IndexError, ValueError):
            return -1

    return max(candidates, key=epoch_of, default=None)


def main():
    """Parse CLI arguments, build the trainer and run the requested stages."""
    parser = argparse.ArgumentParser(description="Hierarchical Swin-Tiny Training")
    parser.add_argument("--stage", choices=["species", "disease", "all"], default="all",
                        help="Training stage to run (default: all)")
    parser.add_argument("--resume", type=str, default=None,
                        help="Resume from checkpoint path")
    parser.add_argument("--batch-size", type=int, default=None,
                        help="Override batch size")
    parser.add_argument("--epochs", type=int, default=None,
                        help="Override number of epochs (overrides all stage epochs)")
    parser.add_argument("--device", type=str, default=None,
                        help="Override device (cuda, cpu)")
    parser.add_argument("--no-wandb", action="store_true",
                        help="Disable wandb logging")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load data and model, run 1 batch, then exit")
    args = parser.parse_args()

    # Setup config
    config = DEFAULT_CONFIG.copy()
    if args.batch_size:
        config["species_batch_size"] = args.batch_size
    if args.device:
        config["device"] = args.device
    if args.epochs:
        for key in ["species_epochs_warmup", "species_epochs_finetune",
                    "species_epochs_final", "disease_epochs_heads",
                    "disease_epochs_rare", "disease_epochs_e2e"]:
            config[key] = args.epochs

    # Report the selected GPU; on ROCm builds torch.version.hip is also set.
    if config["device"] == "cuda" and torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        if hasattr(torch.version, "hip") and torch.version.hip:
            print(f" ROCm version: {torch.version.hip}")

    trainer = HierarchicalTrainer(config)
    trainer.prepare_data()
    trainer.create_model()
    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Dry run: one forward + backward pass on a single training batch
    if args.dry_run:
        print("\n── Dry run: 1 batch ──")
        batch = next(iter(trainer.train_loader))
        images = batch["image"].to(trainer.device)
        species_labels = batch["species_idx"].to(trainer.device)
        features, species_logits = trainer.model(images)
        loss = trainer.species_criterion(species_logits, species_labels)
        loss.backward()
        print(f" Forward + backward: ✓ (loss={loss.item():.4f})")
        print(f" Features shape: {features.shape}")
        print(f" Species logits shape: {species_logits.shape}")
        print(" Dry run complete!")
        return

    # Initialize wandb
    if HAS_WANDB and not args.no_wandb:
        wandb.init(
            project="plant-disease-hierarchical",
            config=config,
        )

    # Run stages
    if args.stage in ("species", "all"):
        trainer.train_species_stage(resume_from=args.resume if args.stage == "species" else None)
    if args.stage in ("disease", "all"):
        disease_resume = None
        # With --stage all, the Stage A weights are already in memory. With
        # --stage disease and no --resume, start from the newest saved Stage A
        # checkpoint. The previously hard-coded "species_final_final.pt" is
        # never written, so Stage B used to start from an untrained model.
        if args.stage == "disease" and not args.resume:
            species_ckpt = find_latest_checkpoint("species_final")
            if species_ckpt is None:
                print(" [WARN] No Stage A checkpoint found in "
                      f"{CHECKPOINT_DIR / 'species_final'}; disease heads will "
                      "train on a model without Stage A fine-tuning")
            else:
                disease_resume = str(species_ckpt)
        trainer.train_disease_stage(resume_from=disease_resume)

    # Close wandb
    if HAS_WANDB and wandb.run is not None:
        wandb.finish()
    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()