1038 lines
40 KiB
Python
1038 lines
40 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Phase 2 — Hierarchical Model Training (Swin-Tiny).
|
||
|
||
Trains a hierarchical plant disease classifier with:
|
||
- Swin-Tiny backbone (timm, ImageNet-21K pretrained)
|
||
- Species head (~530 classes — identifies the plant)
|
||
- Disease heads (one per species, 30-300 classes each — identifies the disease)
|
||
|
||
Two-stage training protocol:
|
||
Stage A (Species): Train species classifier with frozen backbone, then full fine-tune.
|
||
Stage B (Disease): Train disease-specific heads, run a rare-class boost, then end-to-end joint training.
|
||
|
||
Usage:
|
||
python3 scripts/train_hierarchical.py # Full training
|
||
python3 scripts/train_hierarchical.py --stage species # Stage A only
|
||
python3 scripts/train_hierarchical.py --stage disease # Stage B only (requires Stage A checkpoint)
|
||
python3 scripts/train_hierarchical.py --resume path.pt # Resume from a saved .pt checkpoint
|
||
|
||
Requirements:
|
||
pip install torch torchvision timm pytorch-lightning albumentations wandb
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import random
|
||
import sys
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
|
||
import albumentations as A
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from albumentations.pytorch import ToTensorV2
|
||
from PIL import Image
|
||
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
|
||
from tqdm import tqdm
|
||
|
||
# Conditionally import pytorch-lightning (optional — for multi-GPU)
|
||
try:
|
||
import pytorch_lightning as pl
|
||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||
from pytorch_lightning.loggers import WandbLogger
|
||
HAS_LIGHTNING = True
|
||
except ImportError:
|
||
HAS_LIGHTNING = False
|
||
|
||
# Conditionally import wandb
|
||
try:
|
||
import wandb
|
||
HAS_WANDB = True
|
||
except ImportError:
|
||
HAS_WANDB = False
|
||
|
||
# ─── Config ───────────────────────────────────────────────────────────────────
|
||
|
||
# Project layout: this script lives one level below the project root.
BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = BASE_DIR.joinpath("data", "organized")
TRAIN_DIR = DATA_DIR.joinpath("train")
VAL_DIR = DATA_DIR.joinpath("val")
CHECKPOINT_DIR = BASE_DIR.joinpath("checkpoints")
LOG_DIR = BASE_DIR.joinpath("logs")

# Per-channel RGB mean / std used to normalize inputs for ImageNet-pretrained backbones
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default training configuration (CLI flags in main() override a copy of this)
DEFAULT_CONFIG = dict(
    # Model
    backbone="swin_tiny_patch4_window7_224",
    img_size=224,
    dropout=0.1,
    # Stage A (species): head warmup → full fine-tune → final fine-tune
    species_epochs_warmup=5,
    species_epochs_finetune=20,
    species_epochs_final=5,
    species_lr_warmup=1e-3,
    species_lr_finetune=1e-4,
    species_lr_final=5e-6,
    species_batch_size=512,
    # Stage B (disease): heads → rare-class boost → end-to-end
    disease_epochs_heads=15,
    disease_epochs_rare=5,
    disease_epochs_e2e=10,
    disease_lr_heads=1e-3,
    disease_lr_rare=5e-4,
    disease_lr_e2e=1e-5,
    disease_loss_weight=0.7,
    # Data loading
    num_workers=8,
    prefetch_factor=4,
    val_split_pct=0.15,
    # Augmentation
    mixup_alpha=0.2,
    cutmix_alpha=1.0,
    label_smoothing=0.1,
    # Optimization
    gradient_clip_val=1.0,
    weight_decay=0.05,
    warmup_epochs=3,
    # System
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    precision="16-mixed",  # mixed precision training
    accumulate_grad_batches=1,
)
|
||
|
||
|
||
# ─── Dataset ──────────────────────────────────────────────────────────────────
|
||
|
||
class HierarchicalDataset(Dataset):
    """
    Reads organized dataset from data/organized/{train,val}/{species}/{disease}/.

    Loads species_index.json (and dataset_stats.json) from ``data_dir``, maps
    species names to indices in sorted order, then walks the split directory
    collecting one sample per image file.

    Disease indices are assigned per species in sorted directory order. So that
    a disease gets the same index in every split, the mapping is seeded from
    ``disease_to_idx`` when given, otherwise (for non-train splits) from the
    ``train`` split's directories; diseases that only appear in this split are
    appended after the seeded ones.
    """

    # Lower-case file extensions treated as images
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, data_dir: Path, split: str, transform=None, config=None,
                 disease_to_idx=None):
        self.data_dir = data_dir / split
        self.split = split
        self.transform = transform
        self.config = config or DEFAULT_CONFIG

        # Load metadata
        self.species_index = self._load_json(data_dir / "species_index.json")
        self.dataset_stats = self._load_json(data_dir / "dataset_stats.json")

        # Build species mappings
        self.species_list = sorted(self.species_index.keys())
        self.species_to_idx = {s: i for i, s in enumerate(self.species_list)}
        self.num_species = len(self.species_list)

        # species → {disease_name → index}; copied so a caller's mapping is never mutated
        self.disease_to_idx = {
            species: dict(mapping)
            for species, mapping in (disease_to_idx or {}).items()
        }
        # Split whose disease folders define the canonical index order
        # (None when an explicit mapping was given or this *is* the train split)
        self._reference_dir = (
            data_dir / "train"
            if disease_to_idx is None and split != "train"
            else None
        )
        self.disease_counts = {}  # species → {disease_name → image count}
        self.samples = []  # List of (image_path, species_idx, disease_name, species)

        self._build_index()

    def _load_json(self, path):
        """Parse and return the JSON document at ``path``."""
        with open(path) as f:
            return json.load(f)

    @staticmethod
    def _list_subdirs(path):
        """Return sorted names of the non-hidden subdirectories of ``path``."""
        return sorted(
            name for name in os.listdir(path)
            if not name.startswith(".") and os.path.isdir(os.path.join(path, name))
        )

    def _build_index(self):
        """Walk the split directory and fill samples, disease_to_idx and disease_counts."""
        for species in self._list_subdirs(self.data_dir):
            if species not in self.species_to_idx:
                print(f" [WARN] Species '{species}' not in species_index, skipping")
                continue

            species_idx = self.species_to_idx[species]
            mapping = self.disease_to_idx.get(species, {})

            # Seed from the reference (train) split so indices agree across splits
            if not mapping and self._reference_dir is not None:
                reference = self._reference_dir / species
                if reference.is_dir():
                    for disease in self._list_subdirs(reference):
                        mapping.setdefault(disease, len(mapping))

            for disease in self._list_subdirs(self.data_dir / species):
                # setdefault evaluates len() before inserting → next free index
                mapping.setdefault(disease, len(mapping))

                disease_dir = self.data_dir / species / disease
                image_files = sorted(
                    f for f in os.listdir(disease_dir)
                    if f.lower().endswith(self.IMAGE_EXTENSIONS)
                    and os.path.isfile(disease_dir / f)
                )

                counts = self.disease_counts.setdefault(species, {})
                counts[disease] = counts.get(disease, 0) + len(image_files)

                self.samples.extend(
                    (disease_dir / img_file, species_idx, disease, species)
                    for img_file in image_files
                )

            # Species without any disease folders get no (empty) mapping entry
            if mapping:
                self.disease_to_idx[species] = mapping

        print(f" [{self.split}] {len(self.samples)} images, "
              f"{self.num_species} species")

    def get_species_class_distribution(self):
        """Return list of sample counts per species (indexed by species index)."""
        counts = [0] * self.num_species
        for _, sp_idx, _, _ in self.samples:
            counts[sp_idx] += 1
        return counts

    def get_disease_class_weights(self, species):
        """Return per-disease sampling weights for one species, or None if it has no data.

        Each disease gets a weight inversely proportional to sqrt(count) —
        moderate upweighting of rare diseases. Weights are normalized so that
        summed over all of the species' samples they equal the species' sample
        count, leaving the species' overall sampling share unchanged. Diseases
        whose folders contain no images get weight 0.0.
        """
        counts = self.disease_counts.get(species, {})
        if not counts:
            return None

        total = sum(counts.values())
        norm = sum(np.sqrt(count) for count in counts.values())
        if norm == 0:  # every disease folder was empty
            return {disease: 0.0 for disease in counts}

        return {
            disease: float(total / (norm * np.sqrt(count))) if count > 0 else 0.0
            for disease, count in counts.items()
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one image and return it with its species / disease labels."""
        img_path, species_idx, disease_name, species = self.samples[idx]

        # Load image; the context manager releases the file handle promptly
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            augmented = self.transform(image=np.array(image))
            image_tensor = augmented["image"]
        else:
            # Default: resize to the configured size + ImageNet normalization
            import torchvision.transforms.functional as TF

            size = self.config.get("img_size", 224)
            image = image.resize((size, size), Image.BILINEAR)
            image_tensor = TF.normalize(TF.to_tensor(image), IMAGENET_MEAN, IMAGENET_STD)

        # Disease label (index within this species' disease mapping)
        disease_idx = self.disease_to_idx[species].get(disease_name, 0)
        num_diseases = len(self.disease_to_idx[species])

        return {
            "image": image_tensor,
            "species_idx": torch.tensor(species_idx, dtype=torch.long),
            "disease_idx": torch.tensor(disease_idx, dtype=torch.long),
            "disease_name": disease_name,
            "species": species,
            "num_diseases": num_diseases,
        }
|
||
|
||
|
||
# ─── Augmentations ────────────────────────────────────────────────────────────
|
||
|
||
def get_train_transform(config=None):
    """Build the training augmentation pipeline (Tier 1 + Tier 2).

    Args:
        config: training config dict; only ``img_size`` (default 224) is read.
            A falsy value falls back to ``DEFAULT_CONFIG``.

    Returns:
        An ``A.Compose`` that applies random crop / flip / rotation / color
        jitter, optional blur-noise-compression-shadow-grayscale degradations,
        then ImageNet normalization and conversion to a torch tensor.
    """
    cfg = config or DEFAULT_CONFIG
    side = cfg.get("img_size", 224)

    # Tier 1 — core geometric & photometric
    geometric_and_color = [
        A.RandomResizedCrop(size=(side, side), scale=(0.6, 1.0), ratio=(0.75, 1.33)),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=45, p=0.5),
        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),
    ]

    # Tier 2 — degradation & capture-quality simulation
    noise_or_blur = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 7), p=1.0),
            A.GaussNoise(std_range=(0.01, 0.05), p=1.0),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
        ],
        p=0.3,
    )
    degradations = [
        noise_or_blur,
        A.ImageCompression(quality_range=(60, 95), p=0.2),
        A.RandomShadow(p=0.15),
        A.ToGray(p=0.05),
    ]

    # Normalize and hand off to PyTorch
    finalize = [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]

    return A.Compose(geometric_and_color + degradations + finalize)
|
||
|
||
|
||
def get_val_transform(config=None):
    """Build the deterministic validation transform.

    Resizes to ``img_size`` x ``img_size`` (default 224; a falsy ``config``
    falls back to ``DEFAULT_CONFIG``), normalizes with ImageNet statistics and
    converts to a torch tensor. No randomness is involved.
    """
    side = (config or DEFAULT_CONFIG).get("img_size", 224)
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
||
|
||
|
||
# ─── Model ────────────────────────────────────────────────────────────────────
|
||
|
||
class HierarchicalSwin(nn.Module):
    """
    Hierarchical Swin-Tiny model with species and disease heads.

    Architecture:
        Input (224×224×3)
          → Swin-Tiny backbone (timm, 768-dim features)
          → Species head (Dropout + Linear: 768 → num_species)
          → Disease heads (one Linear: 768 → num_diseases per species,
            created only for species with more than one disease)

    Inference: one forward pass through backbone, then route the feature
    vector to the disease head of the predicted species.
    """

    def __init__(self, config, species_index, disease_to_idx):
        super().__init__()
        self.config = config
        self.species_list = sorted(species_index.keys())
        self.num_species = len(self.species_list)
        self.disease_to_idx = disease_to_idx
        known_diseases = disease_to_idx or {}

        # Backbone
        import timm
        self.backbone = timm.create_model(
            config["backbone"],
            pretrained=True,
            num_classes=0,  # No classification head — we use features only
            global_pool="avg",
        )
        self.feature_dim = self.backbone.num_features  # 768 for Swin-Tiny

        # Species head
        self.species_head = nn.Sequential(
            nn.Dropout(config.get("dropout", 0.1)),
            nn.Linear(self.feature_dim, self.num_species),
        )

        # Disease heads — one per species. Sized to cover both the metadata and
        # the label indices actually produced by the dataset, so a disease
        # folder missing from species_index can never exceed the head width.
        self.disease_heads = nn.ModuleDict()
        for species in self.species_list:
            num_diseases = max(
                len(species_index[species]),
                len(known_diseases.get(species, {})),
            )
            if num_diseases > 1:
                self.disease_heads[species] = nn.Linear(self.feature_dim, num_diseases)

        # Store num_diseases per species for routing (empty mappings skipped)
        self.num_diseases_per_species = {
            species: max(mapping.values()) + 1
            for species, mapping in known_diseases.items()
            if mapping
        }

        print(f" Model: Swin-Tiny backbone ({self.feature_dim}-dim features)")
        print(f" Species head: {self.num_species} classes")
        print(f" Disease heads: {len(self.disease_heads)} species-specific heads")

    def forward(self, x):
        """Forward pass: return (features [B, D], species_logits [B, num_species])."""
        features = self.backbone(x)
        species_logits = self.species_head(features)
        return features, species_logits

    def forward_disease(self, features, species_labels=None, species_logits=None):
        """
        Disease forward pass using species-specific heads.

        Args:
            features: [B, D] backbone features
            species_labels: [B] ground truth species indices (for training)
            species_logits: [B, num_species] predicted species scores (for inference)

        Returns:
            (disease_logits, species_names):
            disease_logits is [B, W] where W is the widest disease head used in
            the batch (1 if none). Columns beyond a row's own head are filled
            with the dtype's most negative value so they receive ~zero softmax
            probability and can never win argmax. Rows whose species has no
            head (a single disease) are [0, -big, ...], i.e. class 0 with
            certainty. species_names lists the species each row was routed to.

        Raises:
            ValueError: if neither species_labels nor species_logits is given.
        """
        if species_labels is not None:
            # Training: use ground-truth species for routing
            species_ids = (
                species_labels.tolist()
                if torch.is_tensor(species_labels)
                else [int(i) for i in species_labels]
            )
        elif species_logits is not None:
            # Inference: use predicted species
            species_ids = species_logits.argmax(dim=1).tolist()
        else:
            raise ValueError("Need either species_labels (train) or species_logits (inference)")

        species_names = [self.species_list[i] for i in species_ids]

        # Group rows by species so each head runs once per batch
        rows_by_species = {}
        for row, name in enumerate(species_names):
            if name in self.disease_heads:
                rows_by_species.setdefault(name, []).append(row)

        head_outputs = []
        for name, rows in rows_by_species.items():
            index = torch.tensor(rows, dtype=torch.long, device=features.device)
            logits = self.disease_heads[name](features.index_select(0, index))
            head_outputs.append((index, logits))

        width = max((logits.shape[1] for _, logits in head_outputs), default=1)
        dtype = head_outputs[0][1].dtype if head_outputs else features.dtype

        disease_logits = torch.full(
            (len(species_names), width),
            torch.finfo(dtype).min,
            dtype=dtype,
            device=features.device,
        )
        disease_logits[:, 0] = 0.0  # headless (single-disease) species → class 0
        for index, logits in head_outputs:
            disease_logits[index, : logits.shape[1]] = logits

        return disease_logits, species_names

    def get_trainable_params(self, stage="species_head"):
        """Get optimizer parameter groups with per-group learning rates.

        Stages: "species_head", "species_full", "disease_heads",
        "disease_rare", "e2e". Backbone groups use 0.1x the stage LR
        (discriminative fine-tuning). Freezing is done by the caller.

        Raises:
            ValueError: for an unknown stage name.
        """
        cfg = self.config
        if stage == "species_head":
            # Only train species head (backbone frozen by the caller)
            return [
                {"params": self.species_head.parameters(), "lr": cfg["species_lr_warmup"]},
            ]
        if stage == "species_full":
            lr = cfg["species_lr_finetune"]
            return [
                {"params": list(self.backbone.parameters()), "lr": lr * 0.1},
                {"params": list(self.species_head.parameters()), "lr": lr},
            ]
        if stage == "disease_heads":
            return [
                {"params": self.disease_heads.parameters(), "lr": cfg["disease_lr_heads"]},
            ]
        if stage == "disease_rare":
            return [
                {"params": self.disease_heads.parameters(), "lr": cfg["disease_lr_rare"]},
            ]
        if stage == "e2e":
            # End-to-end: everything trainable
            lr = cfg["disease_lr_e2e"]
            return [
                {"params": list(self.backbone.parameters()), "lr": lr * 0.1},
                {"params": list(self.species_head.parameters()), "lr": lr},
                {"params": list(self.disease_heads.parameters()), "lr": lr},
            ]
        raise ValueError(f"Unknown training stage: {stage!r}")
||
|
||
|
||
# ─── MixUp / CutMix ──────────────────────────────────────────────────────────
|
||
|
||
def mixup_data(x, y_species, y_disease, alpha=0.2):
    """Apply MixUp augmentation to a batch.

    Args:
        x: input batch (first dim is batch size).
        y_species, y_disease: per-sample label tensors aligned with ``x``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables MixUp.

    Returns:
        (mixed_x, y_species_a, y_species_b, y_disease_a, y_disease_b, lam).
        The tuple layout is the same whether or not MixUp is applied. When
        disabled it is ``(x, y_species, y_species, y_disease, y_disease, 1.0)``,
        so ``mixup_criterion`` reduces to the plain loss.
    """
    if alpha <= 0:
        return x, y_species, y_species, y_disease, y_disease, 1.0

    lam = float(np.random.beta(alpha, alpha))
    # Permutation created directly on x's device (no host→device copy)
    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y_species, y_species[index], y_disease, y_disease[index], lam
|
||
|
||
|
||
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against two label sets using the MixUp weight ``lam``.

    Evaluates ``criterion(pred, y_a)`` then ``criterion(pred, y_b)`` and
    returns ``lam * loss_a + (1 - lam) * loss_b``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
|
||
|
||
|
||
# ─── Training Loop ────────────────────────────────────────────────────────────
|
||
|
||
class HierarchicalTrainer:
    """Custom training loop for hierarchical model (no Lightning dependency).

    Drives Stage A (species) and Stage B (disease) training of a
    HierarchicalSwin model: data loading, optimizer / LR-schedule setup,
    mixed-precision training, validation, metric logging and checkpointing.
    """

    def __init__(self, config=None):
        self.config = config or DEFAULT_CONFIG.copy()
        self.device = torch.device(self.config["device"])
        self.seed_everything()

        # Setup logging
        self.log_file = LOG_DIR / "training_log.json"
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        self.metrics_history = []

    def seed_everything(self):
        """Seed Python, NumPy and torch (CPU and all CUDA devices) from config["seed"]."""
        seed = self.config["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _loader_kwargs(self):
        """DataLoader worker/pinning kwargs valid for the configured worker count.

        prefetch_factor and persistent_workers are rejected by DataLoader when
        num_workers == 0, so they are only passed for multi-process loading.
        Pinned memory only helps host→CUDA copies.
        """
        workers = self.config["num_workers"]
        kwargs = {"num_workers": workers, "pin_memory": self.device.type == "cuda"}
        if workers > 0:
            kwargs["prefetch_factor"] = self.config["prefetch_factor"]
            kwargs["persistent_workers"] = True
        return kwargs

    def prepare_data(self):
        """Load datasets and create dataloaders."""
        print("\n" + "=" * 60)
        print("Preparing data...")
        print("=" * 60)

        self.train_transform = get_train_transform(self.config)
        self.val_transform = get_val_transform(self.config)

        self.train_dataset = HierarchicalDataset(
            DATA_DIR, "train", transform=self.train_transform, config=self.config
        )
        self.val_dataset = HierarchicalDataset(
            DATA_DIR, "val", transform=self.val_transform, config=self.config
        )

        self.species_index = self.train_dataset.species_index
        self.disease_to_idx = self.train_dataset.disease_to_idx

        # Plain shuffling here; class-balanced sampling is only used by the
        # rare-class boost sub-stage (see _run_training_stage).
        loader_kwargs = self._loader_kwargs()
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["species_batch_size"],
            shuffle=True,
            drop_last=True,
            **loader_kwargs,
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config["species_batch_size"] * 2,
            shuffle=False,
            **loader_kwargs,
        )

        print(f" Train: {len(self.train_dataset)} images, {len(self.train_loader)} batches")
        print(f" Val: {len(self.val_dataset)} images, {len(self.val_loader)} batches")

    def create_model(self):
        """Create the hierarchical model and loss functions."""
        print("\n" + "=" * 60)
        print("Building model...")
        print("=" * 60)

        self.model = HierarchicalSwin(
            self.config,
            self.species_index,
            self.disease_to_idx,
        ).to(self.device)

        self.species_criterion = nn.CrossEntropyLoss(
            label_smoothing=self.config["label_smoothing"]
        )
        self.disease_criterion = nn.CrossEntropyLoss()

    def save_checkpoint(self, path, stage, epoch, metrics):
        """Save model (and optimizer/scheduler, if created) state plus metadata."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {
            "stage": stage,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict() if hasattr(self, 'optimizer') else None,
            "scheduler_state_dict": self.scheduler.state_dict() if hasattr(self, 'scheduler') else None,
            "metrics": metrics,
            "config": self.config,
            "species_index": self.species_index,
            "disease_to_idx": self.disease_to_idx,
        }
        torch.save(checkpoint, path)
        print(f" ✓ Checkpoint saved: {path}")

    def load_checkpoint(self, path):
        """Load model weights and class mappings from a checkpoint; return the raw dict."""
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.species_index = checkpoint.get("species_index", self.species_index)
        self.disease_to_idx = checkpoint.get("disease_to_idx", self.disease_to_idx)
        print(f" ✓ Loaded checkpoint: {path} (stage={checkpoint.get('stage')}, "
              f"epoch={checkpoint.get('epoch')})")
        return checkpoint

    def log_metrics(self, stage, epoch, metrics):
        """Log metrics to file, console and (if active) wandb."""
        metrics["stage"] = stage
        metrics["epoch"] = epoch
        self.metrics_history.append(metrics)

        # Write to log file
        with open(self.log_file, "w") as f:
            json.dump(self.metrics_history, f, indent=2)

        # Print
        print(f" [{stage}] Epoch {epoch}: "
              f"loss={metrics.get('train_loss', 0):.4f}, "
              f"val_loss={metrics.get('val_loss', 0):.4f}, "
              f"species_acc={metrics.get('val_species_acc', 0):.2%}")

        # Wandb logging. No explicit step: epoch numbers restart in every
        # Stage B sub-stage, and wandb drops rows whose step goes backwards.
        # The epoch is logged as a regular metric instead.
        if HAS_WANDB and wandb.run is not None:
            wandb.log({
                f"{stage}/{k}": v for k, v in metrics.items()
                if isinstance(v, (int, float))
            })

    # ─── Stage A: Species Classifier ───────────────────────────────────────

    def train_species_stage(self, resume_from=None):
        """
        Stage A: Train species classifier.

        Sub-stages (epoch numbers are global across the three, so a resumed
        checkpoint's epoch selects where to continue):
            1. Head warmup (backbone frozen)
            2. Full fine-tune (unfrozen backbone, warmup + cosine LR)
            3. Final fine-tune (discriminative LR peaking at species_lr_final)
        """
        print("\n" + "=" * 60)
        print("Stage A — Species Classifier Training")
        print("=" * 60)

        start_epoch = 0
        if resume_from:
            checkpoint = self.load_checkpoint(resume_from)
            start_epoch = checkpoint.get("epoch", 0) + 1

        a1_end = self.config["species_epochs_warmup"]
        a2_end = a1_end + self.config["species_epochs_finetune"]
        a3_end = a2_end + self.config["species_epochs_final"]

        # ── Sub-stage A1: Head warmup ──
        if start_epoch < a1_end:
            print("\n── Sub-stage A1: Head warmup (backbone frozen) ──")
            self._run_training_stage(
                "species_warmup",
                self.model.get_trainable_params("species_head"),
                self.config["species_lr_warmup"],
                a1_end,
                start_epoch,
                freeze_backbone=True,
            )

        # ── Sub-stage A2: Full fine-tune ──
        if start_epoch < a2_end:
            print("\n── Sub-stage A2: Full fine-tune (cosine LR) ──")
            self._run_training_stage(
                "species_finetune",
                self.model.get_trainable_params("species_full"),
                self.config["species_lr_finetune"],
                a2_end,
                max(start_epoch, a1_end),
                freeze_backbone=False,
                cosine_schedule=True,
                warmup_epochs=self.config["warmup_epochs"],
            )

        # ── Sub-stage A3: Final fine-tune ──
        if start_epoch < a3_end:
            print("\n── Sub-stage A3: Final fine-tune (discriminative LR) ──")
            self._run_training_stage(
                "species_final",
                self.model.get_trainable_params("species_full"),
                self.config["species_lr_final"],
                a3_end,
                max(start_epoch, a2_end),
                freeze_backbone=False,
            )

        print("\n ✓ Stage A complete!")

    # ─── Stage B: Disease Classifiers ──────────────────────────────────────

    def train_disease_stage(self, resume_from=None):
        """
        Stage B: Train disease-specific heads.

        Sub-stages (epoch numbers restart at 0 in each):
            1. All disease heads (backbone frozen)
            2. Rare-class boost (weighted sampler oversampling rare diseases)
            3. End-to-end (unfrozen backbone, joint loss)
        """
        print("\n" + "=" * 60)
        print("Stage B — Disease Classifier Training")
        print("=" * 60)

        if resume_from:
            self.load_checkpoint(resume_from)

        # ── Sub-stage B1: Disease heads ──
        print("\n── Sub-stage B1: Disease heads (backbone frozen) ──")
        self._run_training_stage(
            "disease_heads",
            self.model.get_trainable_params("disease_heads"),
            self.config["disease_lr_heads"],
            self.config["disease_epochs_heads"],
            freeze_backbone=True,
            use_disease_loss=True,
        )

        # ── Sub-stage B2: Rare-class boost ──
        print("\n── Sub-stage B2: Rare-class boost ──")
        self._run_training_stage(
            "disease_rare",
            self.model.get_trainable_params("disease_rare"),
            self.config["disease_lr_rare"],
            self.config["disease_epochs_rare"],
            freeze_backbone=True,
            use_disease_loss=True,
            oversample_rare=True,
        )

        # ── Sub-stage B3: End-to-end ──
        print("\n── Sub-stage B3: End-to-end joint training ──")
        self._run_training_stage(
            "disease_e2e",
            self.model.get_trainable_params("e2e"),
            self.config["disease_lr_e2e"],
            self.config["disease_epochs_e2e"],
            freeze_backbone=False,
            use_disease_loss=True,
        )

        print("\n ✓ Stage B complete!")

    # ─── Core Training Loop ────────────────────────────────────────────────

    @staticmethod
    def _scale_param_groups(param_groups, base_lr):
        """Materialize param groups and rescale their LRs so the largest equals base_lr.

        The ratios between groups (discriminative LR) are preserved. Groups are
        returned unchanged (apart from materializing params) when base_lr is
        None or no group has a positive LR.
        """
        groups = [dict(group, params=list(group["params"])) for group in param_groups]
        peak = max((group.get("lr", 0.0) for group in groups), default=0.0)
        if base_lr is None or peak <= 0:
            return groups
        scale = base_lr / peak
        for group in groups:
            if "lr" in group:
                group["lr"] = group["lr"] * scale
        return groups

    def _build_cosine_schedule(self, total_steps, warmup_steps):
        """Per-batch schedule: linear warmup (from 0.1x) then cosine decay over the rest."""
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

        total_steps = max(int(total_steps), 1)
        warmup_steps = min(max(int(warmup_steps), 0), total_steps - 1)
        cosine = CosineAnnealingLR(self.optimizer, T_max=max(total_steps - warmup_steps, 1))
        if warmup_steps == 0:
            return cosine
        warmup = LinearLR(self.optimizer, start_factor=0.1, total_iters=warmup_steps)
        return SequentialLR(
            self.optimizer,
            schedulers=[warmup, cosine],
            milestones=[warmup_steps],
        )

    def _run_training_stage(self, stage_name, param_groups, base_lr, num_epochs,
                            start_epoch=0, freeze_backbone=True,
                            cosine_schedule=False, warmup_epochs=0,
                            use_disease_loss=False, oversample_rare=False):
        """Run epochs [start_epoch, num_epochs) of one training sub-stage.

        Args:
            stage_name: name used for logging and the checkpoint directory.
            param_groups: optimizer groups (from get_trainable_params).
            base_lr: peak LR for the stage; group LRs are rescaled so the
                largest equals it, preserving inter-group ratios.
            num_epochs: exclusive end epoch.
            start_epoch: first epoch to run.
            freeze_backbone: set backbone requires_grad accordingly.
            cosine_schedule: per-batch warmup + cosine LR; otherwise
                ReduceLROnPlateau on validation loss, stepped per epoch.
            warmup_epochs: warmup length (epochs) for the cosine schedule.
            use_disease_loss: add the weighted disease loss.
            oversample_rare: train with a weighted sampler favoring rare diseases.
        """
        # Freeze/unfreeze backbone
        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze_backbone

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            self._scale_param_groups(param_groups, base_lr),
            weight_decay=self.config["weight_decay"],
        )

        # Default dataloader, or an oversampled one for rare classes
        train_loader = self.train_loader
        if oversample_rare:
            weights = self._compute_disease_sample_weights()
            if weights is not None:
                sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
                train_loader = DataLoader(
                    self.train_dataset,
                    batch_size=self.config["species_batch_size"],
                    sampler=sampler,
                    drop_last=True,
                    **self._loader_kwargs(),
                )
                print(" Using weighted sampler for rare-class boost")

        # Setup scheduler (sized for the epochs this call will actually run)
        steps_per_epoch = len(train_loader)
        if cosine_schedule:
            epochs_to_run = max(num_epochs - start_epoch, 0)
            self.scheduler = self._build_cosine_schedule(
                epochs_to_run * steps_per_epoch,
                warmup_epochs * steps_per_epoch,
            )
            batch_scheduler = self.scheduler
        else:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.5, patience=3
            )
            batch_scheduler = None

        # Loss scaling is only meaningful for CUDA fp16 autocast
        use_amp = self.config["precision"] == "16-mixed"
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp and self.device.type == "cuda")

        for epoch in range(start_epoch, num_epochs):
            train_loss = self._train_epoch(
                train_loader, use_disease_loss, scaler, stage_name,
                scheduler=batch_scheduler,
            )
            val_metrics = self._validate(use_disease_loss)

            # Plateau scheduler steps once per epoch on validation loss
            if batch_scheduler is None:
                self.scheduler.step(val_metrics["val_loss"])

            metrics = {"train_loss": train_loss, **val_metrics}
            self.log_metrics(stage_name, epoch, metrics)

            # Save checkpoint every 5 epochs and at the end of the stage
            if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
                ckpt_name = f"{stage_name}_epoch={epoch:02d}.pt"
                self.save_checkpoint(
                    CHECKPOINT_DIR / stage_name / ckpt_name,
                    stage_name, epoch, metrics
                )

    def _train_epoch(self, loader, use_disease_loss, scaler, stage_name, scheduler=None):
        """Train for one epoch; return the mean batch loss (0.0 for an empty loader).

        If ``scheduler`` is given it is stepped after every optimizer step.
        """
        self.model.train()
        use_amp = self.config["precision"] == "16-mixed"
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(loader, desc=f" {stage_name}", leave=False)
        for batch in pbar:
            images = batch["image"].to(self.device, non_blocking=True)
            species_labels = batch["species_idx"].to(self.device, non_blocking=True)

            with torch.amp.autocast(device_type=self.device.type, enabled=use_amp):
                features, species_logits = self.model(images)
                loss = self.species_criterion(species_logits, species_labels)

                if use_disease_loss:
                    disease_logits, _ = self.model.forward_disease(
                        features, species_labels=species_labels
                    )
                    disease_labels = batch["disease_idx"].to(self.device, non_blocking=True)
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    loss = loss + self.config["disease_loss_weight"] * loss_disease

            # Backward (unscale before clipping so the threshold applies to true grads)
            self.optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config["gradient_clip_val"]
            )
            scaler.step(self.optimizer)
            scaler.update()
            if scheduler is not None:
                scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def _validate(self, use_disease_loss=False):
        """Run validation; return val_loss, val_species_acc (+ val_disease_acc)."""
        self.model.eval()
        total_loss = 0.0
        correct_species = 0
        correct_disease = 0
        total_samples = 0

        for batch in tqdm(self.val_loader, desc=" val", leave=False):
            images = batch["image"].to(self.device, non_blocking=True)
            species_labels = batch["species_idx"].to(self.device, non_blocking=True)

            features, species_logits = self.model(images)

            # Species accuracy and loss
            correct_species += (species_logits.argmax(dim=1) == species_labels).sum().item()
            total_loss += self.species_criterion(species_logits, species_labels).item()

            # Disease accuracy (routed by ground-truth species)
            if use_disease_loss:
                disease_logits, _ = self.model.forward_disease(
                    features, species_labels=species_labels
                )
                disease_labels = batch["disease_idx"].to(self.device, non_blocking=True)

                # A width-1 batch holds only single-disease species: loss carries no signal
                if disease_logits.shape[1] > 1:
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    total_loss += self.config["disease_loss_weight"] * loss_disease.item()

                disease_preds = disease_logits.argmax(dim=1)
                disease_labels_clipped = disease_labels.clamp(0, disease_logits.shape[1] - 1)
                correct_disease += (disease_preds == disease_labels_clipped).sum().item()

            total_samples += images.size(0)

        metrics = {
            "val_loss": total_loss / max(len(self.val_loader), 1),
            "val_species_acc": correct_species / max(total_samples, 1),
        }
        if use_disease_loss:
            metrics["val_disease_acc"] = correct_disease / max(total_samples, 1)
        return metrics

    def _compute_disease_sample_weights(self):
        """Per-sample weights for rare-class oversampling, aligned with train_dataset.samples.

        Single pass over the samples, so weights[i] always belongs to
        samples[i]. Class weights are fetched once per species. Samples whose
        species has no class weights get weight 1.0. Returns None when there
        are no samples.
        """
        dataset = self.train_dataset
        class_weights_cache = {}
        weights = []
        for _, _, disease_name, species in dataset.samples:
            if species not in class_weights_cache:
                class_weights_cache[species] = dataset.get_disease_class_weights(species)
            class_weights = class_weights_cache[species]
            weights.append(1.0 if class_weights is None else class_weights.get(disease_name, 1.0))

        return torch.tensor(weights, dtype=torch.float) if weights else None
|
||
|
||
|
||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="Hierarchical Swin-Tiny Training")
|
||
parser.add_argument("--stage", choices=["species", "disease", "all"], default="all",
|
||
help="Training stage to run (default: all)")
|
||
parser.add_argument("--resume", type=str, default=None,
|
||
help="Resume from checkpoint path")
|
||
parser.add_argument("--batch-size", type=int, default=None,
|
||
help="Override batch size")
|
||
parser.add_argument("--epochs", type=int, default=None,
|
||
help="Override number of epochs (overrides all stage epochs)")
|
||
parser.add_argument("--device", type=str, default=None,
|
||
help="Override device (cuda, cpu)")
|
||
parser.add_argument("--no-wandb", action="store_true",
|
||
help="Disable wandb logging")
|
||
parser.add_argument("--dry-run", action="store_true",
|
||
help="Load data and model, run 1 batch, then exit")
|
||
args = parser.parse_args()
|
||
|
||
# Setup config
|
||
config = DEFAULT_CONFIG.copy()
|
||
if args.batch_size:
|
||
config["species_batch_size"] = args.batch_size
|
||
if args.device:
|
||
config["device"] = args.device
|
||
if args.epochs:
|
||
for key in ["species_epochs_warmup", "species_epochs_finetune",
|
||
"species_epochs_final", "disease_epochs_heads",
|
||
"disease_epochs_rare", "disease_epochs_e2e"]:
|
||
config[key] = args.epochs
|
||
|
||
# Fix device name for ROCm
|
||
if config["device"] == "cuda" and torch.cuda.is_available():
|
||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
||
if hasattr(torch.version, "hip") and torch.version.hip:
|
||
print(f" ROCm version: {torch.version.hip}")
|
||
|
||
# Create trainer
|
||
trainer = HierarchicalTrainer(config)
|
||
|
||
# Prepare data
|
||
trainer.prepare_data()
|
||
|
||
# Create or load model
|
||
if args.resume:
|
||
trainer.create_model()
|
||
trainer.load_checkpoint(args.resume)
|
||
else:
|
||
trainer.create_model()
|
||
|
||
# Dry run
|
||
if args.dry_run:
|
||
print("\n── Dry run: 1 batch ──")
|
||
batch = next(iter(trainer.train_loader))
|
||
images = batch["image"].to(trainer.device)
|
||
species_labels = batch["species_idx"].to(trainer.device)
|
||
features, species_logits = trainer.model(images)
|
||
loss = trainer.species_criterion(species_logits, species_labels)
|
||
loss.backward()
|
||
print(f" Forward + backward: ✓ (loss={loss.item():.4f})")
|
||
print(f" Features shape: {features.shape}")
|
||
print(f" Species logits shape: {species_logits.shape}")
|
||
print(" Dry run complete!")
|
||
return
|
||
|
||
# Initialize wandb
|
||
if HAS_WANDB and not args.no_wandb:
|
||
wandb.init(
|
||
project="plant-disease-hierarchical",
|
||
config=config,
|
||
)
|
||
|
||
# Run stages
|
||
if args.stage in ("species", "all"):
|
||
trainer.train_species_stage(resume_from=args.resume if args.stage == "species" else None)
|
||
|
||
if args.stage in ("disease", "all"):
|
||
# For 'all', continue from species checkpoint
|
||
species_checkpoint = CHECKPOINT_DIR / "species_final" / "species_final_final.pt"
|
||
trainer.train_disease_stage(
|
||
resume_from=str(species_checkpoint) if species_checkpoint.exists() else None
|
||
)
|
||
|
||
# Close wandb
|
||
if HAS_WANDB and wandb.run is not None:
|
||
wandb.finish()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("Training complete!")
|
||
print("=" * 60)
|
||
|
||
|
||
# Script entry point: run the training CLI only when executed directly,
# not when this module is imported (e.g. for its dataset/model classes).
if __name__ == "__main__":
    main()
|