#!/usr/bin/env python3
"""
Phase 2 — Hierarchical Model Training (Swin-Tiny).

Trains a hierarchical plant disease classifier with:
  - Swin-Tiny backbone (timm, ImageNet-21K pretrained)
  - Species head (~530 classes — identifies the plant)
  - Disease heads (one per species, 30-300 classes each — identifies the disease)

Two-stage training protocol:
  Stage A (Species): Train species classifier with frozen backbone, then full fine-tune.
  Stage B (Disease): Train disease-specific heads, then end-to-end joint training.

Usage:
    python3 scripts/train_hierarchical.py                     # Full training
    python3 scripts/train_hierarchical.py --stage species     # Stage A only
    python3 scripts/train_hierarchical.py --stage disease     # Stage B only (requires Stage A checkpoint)
    python3 scripts/train_hierarchical.py --resume path.ckpt  # Resume from checkpoint

Requirements:
    pip install torch torchvision timm pytorch-lightning albumentations wandb
"""

import argparse
import json
import os
import random
import sys
from collections import defaultdict
from pathlib import Path

import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm

# Conditionally import pytorch-lightning (optional — for multi-GPU)
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    from pytorch_lightning.loggers import WandbLogger
    HAS_LIGHTNING = True
except ImportError:
    HAS_LIGHTNING = False

# Conditionally import wandb
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

# ─── Config ───────────────────────────────────────────────────────────────────

BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = BASE_DIR / "data" / "organized"
TRAIN_DIR = DATA_DIR / "train"
VAL_DIR = DATA_DIR / "val"
CHECKPOINT_DIR = BASE_DIR / "checkpoints"
LOG_DIR = BASE_DIR / "logs"

# ImageNet normalization constants
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Logit used for padded (non-existent) disease classes when per-species heads of
# different widths are stacked into one batch tensor. Large and negative so the
# padded classes get ~0 softmax probability, but finite and representable in
# float16 (max magnitude 65504) so mixed-precision training stays NaN-free.
DISEASE_LOGIT_PAD = -1e4

# Default training config
DEFAULT_CONFIG = {
    # Model
    "backbone": "swin_tiny_patch4_window7_224",
    "img_size": 224,
    "dropout": 0.1,

    # Stage A (Species)
    "species_epochs_warmup": 5,
    "species_epochs_finetune": 20,
    "species_epochs_final": 5,
    "species_lr_warmup": 1e-3,
    "species_lr_finetune": 1e-4,
    "species_lr_final": 5e-6,
    "species_batch_size": 512,

    # Stage B (Disease)
    "disease_epochs_heads": 15,
    "disease_epochs_rare": 5,
    "disease_epochs_e2e": 10,
    "disease_lr_heads": 1e-3,
    "disease_lr_rare": 5e-4,
    "disease_lr_e2e": 1e-5,
    "disease_loss_weight": 0.7,

    # Data
    "num_workers": 8,
    "prefetch_factor": 4,
    "val_split_pct": 0.15,

    # Augmentation
    "mixup_alpha": 0.2,
    "cutmix_alpha": 1.0,
    "label_smoothing": 0.1,

    # Training
    "gradient_clip_val": 1.0,
    "weight_decay": 0.05,
    "warmup_epochs": 3,

    # System
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "precision": "16-mixed",  # mixed precision training
    "accumulate_grad_batches": 1,
}


# ─── Dataset ──────────────────────────────────────────────────────────────────

class HierarchicalDataset(Dataset):
    """
    Reads organized dataset from data/organized/{train,val}/{species}/{disease}/.

    Loads species_index.json to build the species class mapping, then walks the
    split directory to build the per-species disease mapping and sample index.

    Args:
        data_dir: Root of the organized dataset; must contain species_index.json,
            dataset_stats.json and one sub-directory per split.
        split: Split sub-directory name (e.g. "train" or "val").
        transform: Optional albumentations pipeline, called as
            ``transform(image=np_array)["image"]``.
        config: Training config dict (defaults to DEFAULT_CONFIG).
        disease_to_idx: Optional pre-built ``{species: {disease: index}}``
            mapping. Pass the *training* set's mapping when building the
            validation set so that disease indices agree with the model heads.
            With a provided mapping, disease directories not in it are skipped.
    """

    def __init__(self, data_dir: Path, split: str, transform=None, config=None,
                 disease_to_idx=None):
        self.data_dir = data_dir / split
        self.split = split
        self.transform = transform
        self.config = config or DEFAULT_CONFIG

        # Load metadata
        self.species_index = self._load_json(data_dir / "species_index.json")
        self.dataset_stats = self._load_json(data_dir / "dataset_stats.json")

        # Species mapping: sorted keys give a stable index across runs/splits
        self.species_list = sorted(self.species_index.keys())
        self.species_to_idx = {s: i for i, s in enumerate(self.species_list)}
        self.num_species = len(self.species_list)

        # A provided disease mapping is frozen (it must match existing heads);
        # copy it so this dataset never mutates the caller's dict.
        self._fixed_disease_map = disease_to_idx is not None
        self.disease_to_idx = (
            {species: dict(mapping) for species, mapping in disease_to_idx.items()}
            if disease_to_idx is not None else {}
        )  # species → {disease_name → index}
        self.disease_counts = {}  # species → {disease_name → count}
        self.samples = []  # List of (image_path, species_idx, disease_name, species)
        self._build_index()

    def _load_json(self, path):
        with open(path) as f:
            return json.load(f)

    def _build_index(self):
        """Walk the split directory and build the disease mapping and sample index."""
        image_exts = (".jpg", ".jpeg", ".png", ".webp")

        species_dirs = sorted(
            d for d in os.listdir(self.data_dir)
            if os.path.isdir(self.data_dir / d) and not d.startswith(".")
        )

        # Only species listed in species_index get a class index
        for species in species_dirs:
            if species not in self.species_to_idx:
                print(f" [WARN] Species '{species}' not in species_index, skipping")
                continue
            species_idx = self.species_to_idx[species]
            species_dir = self.data_dir / species

            disease_dirs = sorted(
                d for d in os.listdir(species_dir)
                if os.path.isdir(species_dir / d) and not d.startswith(".")
            )
            for disease in disease_dirs:
                if self._fixed_disease_map:
                    # Unknown diseases would produce labels outside the head width
                    if disease not in self.disease_to_idx.get(species, {}):
                        print(f" [WARN] [{self.split}] Disease '{species}/{disease}' "
                              f"not in provided disease mapping, skipping")
                        continue
                else:
                    # Index assigned in sorted directory order: 0, 1, 2, ...
                    disease_map = self.disease_to_idx.setdefault(species, {})
                    disease_map.setdefault(disease, len(disease_map))

                disease_dir = species_dir / disease
                image_files = sorted(
                    f for f in os.listdir(disease_dir)
                    if os.path.isfile(disease_dir / f) and f.lower().endswith(image_exts)
                )
                counts = self.disease_counts.setdefault(species, {})
                counts[disease] = counts.get(disease, 0) + len(image_files)
                self.samples.extend(
                    (disease_dir / img_file, species_idx, disease, species)
                    for img_file in image_files
                )

        print(f" [{self.split}] {len(self.samples)} images, "
              f"{self.num_species} species")

    def get_species_class_distribution(self):
        """Return list of sample counts per species (indexed by species index)."""
        counts = [0] * self.num_species
        for _, sp_idx, _, _ in self.samples:
            counts[sp_idx] += 1
        return counts

    def get_disease_class_weights(self, species):
        """
        Return per-disease sampling weights for one species.

        Each disease gets weight ``sqrt(total / count)``, i.e. inversely
        proportional to sqrt(count): rare diseases are up-weighted, but less
        aggressively than full inverse-frequency balancing. Diseases with zero
        images get no entry (avoids division by zero).

        Returns:
            dict mapping disease name → float weight, or None when the species
            has no recorded disease directories.
        """
        counts = self.disease_counts.get(species, {})
        if not counts:
            return None
        total = sum(counts.values())
        return {
            disease: float(np.sqrt(total / count))
            for disease, count in counts.items()
            if count > 0
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, species_idx, disease_name, species = self.samples[idx]

        # Load image; the context manager closes the file handle promptly
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Apply transforms
        if self.transform:
            image_tensor = self.transform(image=np.array(image))["image"]
        else:
            # Default: resize to the configured size + ImageNet normalization
            import torchvision.transforms.functional as TF

            img_size = self.config.get("img_size", 224)
            image = image.resize((img_size, img_size), Image.BILINEAR)
            image_tensor = TF.to_tensor(image)
            image_tensor = TF.normalize(image_tensor, IMAGENET_MEAN, IMAGENET_STD)

        # Disease label: index within this species' own disease head
        species_diseases = self.disease_to_idx[species]
        disease_idx = species_diseases.get(disease_name, 0)
        num_diseases = len(species_diseases)

        return {
            "image": image_tensor,
            "species_idx": torch.tensor(species_idx, dtype=torch.long),
            "disease_idx": torch.tensor(disease_idx, dtype=torch.long),
            "disease_name": disease_name,
            "species": species,
            "num_diseases": num_diseases,
        }


# ─── Augmentations ────────────────────────────────────────────────────────────

def get_train_transform(config=None):
    """Get training augmentation pipeline (Tier 1 + 2)."""
    config = config or DEFAULT_CONFIG
    img_size = config.get("img_size", 224)

    return A.Compose([
        # Tier 1 — Core geometric & photometric
        A.RandomResizedCrop(size=(img_size, img_size), scale=(0.6, 1.0), ratio=(0.75, 1.33)),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=45, p=0.5),
        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.8),

        # Tier 2 — Degradation & quality simulation
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 7), p=1.0),
            A.GaussNoise(std_range=(0.01, 0.05), p=1.0),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=1.0),
        ], p=0.3),
        A.ImageCompression(quality_range=(60, 95), p=0.2),
        A.RandomShadow(p=0.15),
        A.ToGray(p=0.05),

        # Normalize
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])


def get_val_transform(config=None):
    """Get validation transform (minimal, deterministic)."""
    config = config or DEFAULT_CONFIG
    img_size = config.get("img_size", 224)

    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])


# ─── Model ────────────────────────────────────────────────────────────────────

class HierarchicalSwin(nn.Module):
    """
    Hierarchical Swin-Tiny model with species and disease heads.

    Architecture:
        Input (224×224×3)
          → Swin-Tiny backbone (timm, pooled features; 768-dim for Swin-Tiny)
          → Species head (Dropout + Linear: feature_dim → num_species)
          → Disease heads (one Linear: feature_dim → num_diseases per species,
                           only for species with more than one disease)

    Inference: one forward pass through the backbone, then route the pooled
    feature to the disease head of the predicted species.
    """

    def __init__(self, config, species_index, disease_to_idx):
        super().__init__()
        self.config = config
        self.species_list = sorted(species_index.keys())
        self.num_species = len(self.species_list)
        self.disease_to_idx = disease_to_idx
        known_diseases = disease_to_idx or {}

        # Backbone
        import timm
        self.backbone = timm.create_model(
            config["backbone"],
            pretrained=True,
            num_classes=0,  # No classification head — we use features only
            global_pool="avg",
        )
        self.feature_dim = self.backbone.num_features  # 768 for Swin-Tiny

        # Species head
        self.species_head = nn.Sequential(
            nn.Dropout(config.get("dropout", 0.1)),
            nn.Linear(self.feature_dim, self.num_species),
        )

        # Disease heads — one per species, built dynamically. The width must
        # cover every label the dataset can emit (from disease_to_idx), even if
        # species_index lists fewer diseases for that species.
        self.disease_heads = nn.ModuleDict()
        for species in self.species_list:
            num_diseases = max(
                len(species_index[species]),
                len(known_diseases.get(species, {})),
            )
            if num_diseases > 1:
                self.disease_heads[species] = nn.Linear(self.feature_dim, num_diseases)

        # Store num_diseases per species for routing
        self.num_diseases_per_species = {
            species: max(mapping.values()) + 1
            for species, mapping in known_diseases.items()
            if mapping
        }

        print(f" Model: Swin-Tiny backbone ({self.feature_dim}-dim features)")
        print(f" Species head: {self.num_species} classes")
        print(f" Disease heads: {len(self.disease_heads)} species-specific heads")

    def forward(self, x):
        """Return ``(features, species_logits)``; disease routing is in forward_disease."""
        features = self.backbone(x)  # [B, feature_dim]
        species_logits = self.species_head(features)  # [B, num_species]
        return features, species_logits

    def forward_disease(self, features, species_labels=None, species_logits=None):
        """
        Disease forward pass using species-specific heads.

        Each sample is routed to the head of its species (ground truth when
        ``species_labels`` is given, otherwise the argmax of ``species_logits``).
        Samples are grouped so each head runs once per batch. The outputs are
        stacked into one ``[B, W]`` tensor, where W is the widest head used in
        the batch. Columns beyond a sample's own head width are filled with
        DISEASE_LOGIT_PAD so they get ~0 probability under softmax.
        Species without a head (single disease) get logit 0 in column 0.

        Args:
            features: [B, feature_dim] backbone features
            species_labels: [B] ground truth species indices (for training)
            species_logits: [B, num_species] predicted species scores (for inference)

        Returns:
            (disease_logits [B, W], species_names list of length B)

        Raises:
            ValueError: if neither species_labels nor species_logits is given.
        """
        if species_labels is not None:
            # Training: use ground-truth species for routing
            species_ids = torch.as_tensor(species_labels).tolist()
        elif species_logits is not None:
            # Inference: use predicted species
            species_ids = species_logits.argmax(dim=1).tolist()
        else:
            raise ValueError("Need either species_labels (train) or species_logits (inference)")
        species_names = [self.species_list[i] for i in species_ids]

        # Group batch rows by species so each head runs once
        rows_by_species = defaultdict(list)
        for row, species_name in enumerate(species_names):
            rows_by_species[species_name].append(row)

        head_outputs = {}  # species → (row index tensor, logits)
        for species_name, rows in rows_by_species.items():
            if species_name in self.disease_heads:
                index = torch.tensor(rows, device=features.device, dtype=torch.long)
                logits = self.disease_heads[species_name](features[index])
                head_outputs[species_name] = (index, logits)

        width = max((logits.shape[1] for _, logits in head_outputs.values()), default=1)
        dtype = (next(iter(head_outputs.values()))[1].dtype
                 if head_outputs else features.dtype)
        disease_logits = features.new_full(
            (features.shape[0], width), DISEASE_LOGIT_PAD, dtype=dtype
        )

        for species_name, rows in rows_by_species.items():
            if species_name in head_outputs:
                index, logits = head_outputs[species_name]
                disease_logits[index, :logits.shape[1]] = logits
            else:
                # Species with only 1 disease (e.g. "healthy" only): class 0 is certain
                index = torch.tensor(rows, device=features.device, dtype=torch.long)
                disease_logits[index, 0] = 0.0

        return disease_logits, species_names

    def get_trainable_params(self, stage="species_head", base_lr=None):
        """
        Get optimizer parameter groups with appropriate learning rates.

        Args:
            stage: One of "species_head", "species_full", "disease_heads",
                "disease_rare", "e2e".
            base_lr: Optional head learning rate overriding the config value
                for this stage. The backbone (where trained) uses 0.1 × head LR.

        Raises:
            ValueError: for an unknown stage.
        """
        def head_lr(config_key):
            return self.config[config_key] if base_lr is None else base_lr

        backbone_params = list(self.backbone.parameters())
        species_params = list(self.species_head.parameters())
        disease_params = list(self.disease_heads.parameters())

        if stage == "species_head":
            # Only train species head, freeze backbone
            return [{"params": species_params, "lr": head_lr("species_lr_warmup")}]
        if stage == "species_full":
            # Train all species-related params with discriminative LR
            lr = head_lr("species_lr_finetune")
            return [
                {"params": backbone_params, "lr": lr * 0.1},
                {"params": species_params, "lr": lr},
            ]
        if stage == "disease_heads":
            # Freeze backbone, train disease heads
            return [{"params": disease_params, "lr": head_lr("disease_lr_heads")}]
        if stage == "disease_rare":
            # Freeze backbone, train disease heads with rare-class focus
            return [{"params": disease_params, "lr": head_lr("disease_lr_rare")}]
        if stage == "e2e":
            # End-to-end: unfreeze everything
            lr = head_lr("disease_lr_e2e")
            return [
                {"params": backbone_params, "lr": lr * 0.1},
                {"params": species_params, "lr": lr},
                {"params": disease_params, "lr": lr},
            ]
        raise ValueError(f"Unknown parameter-group stage: {stage!r}")


# ─── MixUp / CutMix ──────────────────────────────────────────────────────────

def mixup_data(x, y_species, y_disease, alpha=0.2):
    """
    Apply MixUp augmentation to a batch.

    Returns:
        (mixed_x, y_species_a, y_species_b, y_disease_a, y_disease_b, lam).
        With ``alpha <= 0`` MixUp is disabled. The same layout is returned with
        the "b" labels equal to the "a" labels and ``lam == 1.0``, so
        mixup_criterion still works unchanged.
    """
    if alpha <= 0:
        return x, y_species, y_species, y_disease, y_disease, 1.0

    lam = float(np.random.beta(alpha, alpha))
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_species_a, y_species_b = y_species, y_species[index]
    y_disease_a, y_disease_b = y_disease, y_disease[index]

    return mixed_x, y_species_a, y_species_b, y_disease_a, y_disease_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """MixUp loss: linear blend of losses between two label distributions."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


# ─── Training Loop ────────────────────────────────────────────────────────────

class HierarchicalTrainer:
    """Custom training loop for hierarchical model (no Lightning dependency)."""

    def __init__(self, config=None):
        self.config = config or DEFAULT_CONFIG.copy()
        self.device = torch.device(self.config["device"])
        self.seed_everything()

        # Setup logging
        self.log_file = LOG_DIR / "training_log.json"
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        self.metrics_history = []

    def seed_everything(self):
        random.seed(self.config["seed"])
        np.random.seed(self.config["seed"])
        torch.manual_seed(self.config["seed"])
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.config["seed"])

    def _loader_kwargs(self, prefetch=True):
        """DataLoader worker options; prefetch/persistence need num_workers > 0."""
        num_workers = self.config["num_workers"]
        kwargs = {"num_workers": num_workers, "pin_memory": True}
        if num_workers > 0:
            kwargs["persistent_workers"] = True
            if prefetch:
                kwargs["prefetch_factor"] = self.config["prefetch_factor"]
        return kwargs

    def prepare_data(self):
        """Load datasets and create dataloaders."""
        print("\n" + "=" * 60)
        print("Preparing data...")
        print("=" * 60)

        self.train_transform = get_train_transform(self.config)
        self.val_transform = get_val_transform(self.config)

        self.train_dataset = HierarchicalDataset(
            DATA_DIR, "train", transform=self.train_transform, config=self.config
        )
        # Validation must reuse the training disease indices, otherwise a
        # disease missing from val would shift every later label.
        self.val_dataset = HierarchicalDataset(
            DATA_DIR, "val", transform=self.val_transform, config=self.config,
            disease_to_idx=self.train_dataset.disease_to_idx,
        )

        self.species_index = self.train_dataset.species_index
        self.disease_to_idx = self.train_dataset.disease_to_idx

        # Standard shuffle; weighted sampling is added per-stage for rare diseases
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["species_batch_size"],
            shuffle=True,
            drop_last=True,
            **self._loader_kwargs(),
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config["species_batch_size"] * 2,
            shuffle=False,
            **self._loader_kwargs(prefetch=False),
        )

        print(f" Train: {len(self.train_dataset)} images, {len(self.train_loader)} batches")
        print(f" Val: {len(self.val_dataset)} images, {len(self.val_loader)} batches")

    def create_model(self):
        """Create the hierarchical model and loss functions."""
        print("\n" + "=" * 60)
        print("Building model...")
        print("=" * 60)

        self.model = HierarchicalSwin(
            self.config,
            self.species_index,
            self.disease_to_idx,
        ).to(self.device)

        self.species_criterion = nn.CrossEntropyLoss(
            label_smoothing=self.config["label_smoothing"]
        )
        self.disease_criterion = nn.CrossEntropyLoss()

    def save_checkpoint(self, path, stage, epoch, metrics):
        """Save model (and, if present, optimizer/scheduler) state to ``path``."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {
            "stage": stage,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict() if hasattr(self, 'optimizer') else None,
            "scheduler_state_dict": self.scheduler.state_dict() if hasattr(self, 'scheduler') else None,
            "metrics": metrics,
            "config": self.config,
            "species_index": self.species_index,
            "disease_to_idx": self.disease_to_idx,
        }
        torch.save(checkpoint, path)
        print(f" ✓ Checkpoint saved: {path}")

    def load_checkpoint(self, path):
        """
        Load model weights from a checkpoint and return the checkpoint dict.

        NOTE: uses ``weights_only=False`` (full unpickling) — only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.species_index = checkpoint.get("species_index", self.species_index)
        self.disease_to_idx = checkpoint.get("disease_to_idx", self.disease_to_idx)
        print(f" ✓ Loaded checkpoint: {path} (stage={checkpoint.get('stage')}, "
              f"epoch={checkpoint.get('epoch')})")
        return checkpoint

    def log_metrics(self, stage, epoch, metrics):
        """Log metrics to file, console and (if active) wandb."""
        metrics["stage"] = stage
        metrics["epoch"] = epoch
        self.metrics_history.append(metrics)

        # Write to log file
        with open(self.log_file, "w") as f:
            json.dump(self.metrics_history, f, indent=2)

        # Print
        print(f" [{stage}] Epoch {epoch}: "
              f"loss={metrics.get('train_loss', 0):.4f}, "
              f"val_loss={metrics.get('val_loss', 0):.4f}, "
              f"species_acc={metrics.get('val_species_acc', 0):.2%}")

        # Wandb logging. No explicit step: epochs restart at 0 in Stage B, and
        # wandb drops data logged with a step lower than the current one. The
        # epoch is still recorded as "{stage}/epoch".
        if HAS_WANDB and wandb.run is not None:
            wandb.log({f"{stage}/{k}": v for k, v in metrics.items()})

    # ─── Stage A: Species Classifier ───────────────────────────────────────

    def train_species_stage(self, resume_from=None):
        """
        Stage A: Train species classifier.

        Sub-stages (epochs are counted globally across A1..A3):
            1. Head warmup (backbone frozen)
            2. Full fine-tune (unfrozen backbone, warmup + cosine LR)
            3. Final fine-tune (discriminative LR at species_lr_final)
        """
        print("\n" + "=" * 60)
        print("Stage A — Species Classifier Training")
        print("=" * 60)

        if resume_from:
            checkpoint = self.load_checkpoint(resume_from)
            start_epoch = checkpoint.get("epoch", 0) + 1
        else:
            start_epoch = 0

        # ── Sub-stage A1: Head warmup ──
        a1_end = self.config["species_epochs_warmup"]
        if start_epoch < a1_end:
            print("\n── Sub-stage A1: Head warmup (backbone frozen) ──")
            params = self.model.get_trainable_params("species_head")
            self._run_training_stage(
                "species_warmup", params, self.config["species_lr_warmup"],
                a1_end, start_epoch,
                freeze_backbone=True,
            )

        # ── Sub-stage A2: Full fine-tune ──
        a2_start = a1_end
        a2_end = a2_start + self.config["species_epochs_finetune"]
        if start_epoch < a2_end:
            print("\n── Sub-stage A2: Full fine-tune (cosine LR) ──")
            params = self.model.get_trainable_params("species_full")
            self._run_training_stage(
                "species_finetune", params, self.config["species_lr_finetune"],
                a2_end, max(start_epoch, a2_start),
                freeze_backbone=False,
                cosine_schedule=True,
                warmup_epochs=self.config["warmup_epochs"],
            )

        # ── Sub-stage A3: Final fine-tune ──
        a3_start = a2_end
        a3_end = a3_start + self.config["species_epochs_final"]
        if start_epoch < a3_end:
            print("\n── Sub-stage A3: Final fine-tune (discriminative LR) ──")
            # Same groups as A2 but at the (lower) final learning rate
            params = self.model.get_trainable_params(
                "species_full", base_lr=self.config["species_lr_final"]
            )
            self._run_training_stage(
                "species_final", params, self.config["species_lr_final"],
                a3_end, max(start_epoch, a3_start),
                freeze_backbone=False,
            )

        print("\n ✓ Stage A complete!")

    # ─── Stage B: Disease Classifiers ──────────────────────────────────────

    def train_disease_stage(self, resume_from=None):
        """
        Stage B: Train disease-specific heads.

        Sub-stages (each counts epochs from 0):
            1. All disease heads (backbone frozen)
            2. Rare-class boost (oversampled rare classes)
            3. End-to-end (unfrozen backbone, joint loss)
        """
        print("\n" + "=" * 60)
        print("Stage B — Disease Classifier Training")
        print("=" * 60)

        if resume_from:
            self.load_checkpoint(resume_from)

        # ── Sub-stage B1: Disease heads ──
        print("\n── Sub-stage B1: Disease heads (backbone frozen) ──")
        params = self.model.get_trainable_params("disease_heads")
        self._run_training_stage(
            "disease_heads", params, self.config["disease_lr_heads"],
            self.config["disease_epochs_heads"],
            freeze_backbone=True,
            use_disease_loss=True,
        )

        # ── Sub-stage B2: Rare-class boost ──
        print("\n── Sub-stage B2: Rare-class boost ──")
        params = self.model.get_trainable_params("disease_rare")
        self._run_training_stage(
            "disease_rare", params, self.config["disease_lr_rare"],
            self.config["disease_epochs_rare"],
            freeze_backbone=True,
            use_disease_loss=True,
            oversample_rare=True,
        )

        # ── Sub-stage B3: End-to-end ──
        print("\n── Sub-stage B3: End-to-end joint training ──")
        params = self.model.get_trainable_params("e2e")
        self._run_training_stage(
            "disease_e2e", params, self.config["disease_lr_e2e"],
            self.config["disease_epochs_e2e"],
            freeze_backbone=False,
            use_disease_loss=True,
        )

        print("\n ✓ Stage B complete!")

    # ─── Core Training Loop ────────────────────────────────────────────────

    def _run_training_stage(self, stage_name, param_groups, base_lr, num_epochs,
                            start_epoch=0, freeze_backbone=True,
                            cosine_schedule=False, warmup_epochs=0,
                            use_disease_loss=False, oversample_rare=False):
        """
        Run one training sub-stage for epochs ``start_epoch .. num_epochs - 1``.

        Args:
            stage_name: Used for logging and as the checkpoint sub-directory.
            param_groups: Optimizer parameter groups; each group's "lr" is what
                AdamW uses.
            base_lr: Nominal stage learning rate (informational; the per-group
                "lr" values take precedence).
            num_epochs: Exclusive end of the (possibly global) epoch counter.
            start_epoch: First epoch to run (non-zero when resuming).
            freeze_backbone: Force backbone params to ``requires_grad=False``.
            cosine_schedule: Linear warmup + cosine decay, stepped once per
                epoch over this call's epochs. Otherwise ReduceLROnPlateau on
                val_loss.
            warmup_epochs: Warmup length for the cosine schedule.
            use_disease_loss: Add the weighted disease loss to the species loss.
            oversample_rare: Sample training images with rare-disease weights.
        """
        # Materialize params (generators would be consumed by the id() scan).
        param_groups = [dict(group, params=list(group["params"])) for group in param_groups]

        # Only parameters handed to the optimizer may accumulate gradients.
        # Otherwise frozen-but-grad-enabled modules (e.g. the species head
        # during disease stages) build up .grad tensors that are never zeroed
        # and inflate the clip_grad_norm_ total, shrinking the real updates.
        trainable_ids = {id(p) for group in param_groups for p in group["params"]}
        frozen_ids = (
            {id(p) for p in self.model.backbone.parameters()} if freeze_backbone else set()
        )
        for param in self.model.parameters():
            param.requires_grad = id(param) in trainable_ids and id(param) not in frozen_ids

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            param_groups,
            weight_decay=self.config["weight_decay"],
        )

        # Setup scheduler. The loop below calls scheduler.step() once per
        # epoch, so the schedule is expressed in epochs of *this* call.
        stage_epochs = max(num_epochs - start_epoch, 1)
        if cosine_schedule:
            from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

            warmup = min(max(warmup_epochs, 0), stage_epochs - 1)
            cosine = CosineAnnealingLR(self.optimizer, T_max=max(stage_epochs - warmup, 1))
            if warmup > 0:
                warmup_scheduler = LinearLR(
                    self.optimizer, start_factor=0.1, total_iters=warmup
                )
                self.scheduler = SequentialLR(
                    self.optimizer,
                    schedulers=[warmup_scheduler, cosine],
                    milestones=[warmup],
                )
            else:
                self.scheduler = cosine
        else:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.5, patience=3
            )

        # Default dataloader
        train_loader = self.train_loader

        # Create oversampled dataloader for rare classes if requested
        if oversample_rare:
            weights = self._compute_disease_sample_weights()
            if weights is not None:
                sampler = WeightedRandomSampler(
                    weights, len(weights), replacement=True
                )
                train_loader = DataLoader(
                    self.train_dataset,
                    batch_size=self.config["species_batch_size"],
                    sampler=sampler,
                    drop_last=True,
                    **self._loader_kwargs(),
                )
                print(f" Using weighted sampler for rare-class boost")

        # Training loop
        scaler = torch.amp.GradScaler(enabled=(self.config["precision"] == "16-mixed"))

        for epoch in range(start_epoch, num_epochs):
            # Train one epoch
            train_loss = self._train_epoch(
                train_loader, use_disease_loss, scaler, stage_name
            )

            # Validate
            val_metrics = self._validate(use_disease_loss)

            # Step scheduler
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step(val_metrics["val_loss"])
            else:
                self.scheduler.step()

            # Log
            metrics = {"train_loss": train_loss, **val_metrics}
            self.log_metrics(stage_name, epoch, metrics)

            # Save checkpoint every 5 epochs and at the end of the stage
            if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
                ckpt_name = f"{stage_name}_epoch={epoch:02d}.pt"
                self.save_checkpoint(
                    CHECKPOINT_DIR / stage_name / ckpt_name,
                    stage_name, epoch, metrics
                )

    def _train_epoch(self, loader, use_disease_loss, scaler, stage_name):
        """Train for one epoch and return the mean batch loss (0.0 if no batches)."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        amp_enabled = self.config["precision"] == "16-mixed"

        pbar = tqdm(loader, desc=f" {stage_name}", leave=False)
        for batch in pbar:
            images = batch["image"].to(self.device)
            species_labels = batch["species_idx"].to(self.device)

            with torch.amp.autocast(device_type=self.device.type, enabled=amp_enabled):
                # Forward
                features, species_logits = self.model(images)
                loss_species = self.species_criterion(species_logits, species_labels)

                if use_disease_loss:
                    disease_logits, _ = self.model.forward_disease(
                        features, species_labels=species_labels
                    )
                    disease_labels = batch["disease_idx"].to(self.device)
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    loss = loss_species + self.config["disease_loss_weight"] * loss_disease
                else:
                    loss = loss_species

            # Backward (unscale before clipping so the clip threshold is in real units)
            self.optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config["gradient_clip_val"]
            )
            scaler.step(self.optimizer)
            scaler.update()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def _validate(self, use_disease_loss=False):
        """Run validation; return val_loss, val_species_acc (and val_disease_acc)."""
        self.model.eval()
        total_loss = 0
        correct_species = 0
        total_samples = 0
        correct_disease = 0

        for batch in tqdm(self.val_loader, desc=" val", leave=False):
            images = batch["image"].to(self.device)
            species_labels = batch["species_idx"].to(self.device)

            features, species_logits = self.model(images)

            # Species accuracy
            species_preds = species_logits.argmax(dim=1)
            correct_species += (species_preds == species_labels).sum().item()

            # Species loss
            loss_species = self.species_criterion(species_logits, species_labels)
            total_loss += loss_species.item()

            # Disease accuracy (if using disease heads)
            if use_disease_loss:
                disease_logits, _ = self.model.forward_disease(
                    features, species_labels=species_labels
                )
                disease_labels = batch["disease_idx"].to(self.device)

                # Only meaningful when at least one routed head has >1 class
                if disease_logits.shape[1] > 1:
                    loss_disease = self.disease_criterion(disease_logits, disease_labels)
                    total_loss += self.config["disease_loss_weight"] * loss_disease.item()

                disease_preds = disease_logits.argmax(dim=1)
                disease_labels_clipped = disease_labels.clamp(0, disease_logits.shape[1] - 1)
                correct_disease += (disease_preds == disease_labels_clipped).sum().item()

            total_samples += images.size(0)

        metrics = {
            "val_loss": total_loss / max(len(self.val_loader), 1),
            "val_species_acc": correct_species / max(total_samples, 1),
        }
        if use_disease_loss:
            metrics["val_disease_acc"] = correct_disease / max(total_samples, 1)

        return metrics

    def _compute_disease_sample_weights(self):
        """
        Compute one sampling weight per training sample (in dataset order).

        The weight is the sample's disease weight from
        ``get_disease_class_weights`` (1.0 if unavailable), so the result lines
        up index-for-index with ``self.train_dataset`` as WeightedRandomSampler
        requires. Single pass over the samples, with per-species weights cached.

        Returns:
            float tensor of length len(train_dataset), or None if it is empty.
        """
        dataset = self.train_dataset
        weights_by_species = {}
        weights = []
        for _, _, disease_name, species in dataset.samples:
            if species not in weights_by_species:
                weights_by_species[species] = dataset.get_disease_class_weights(species) or {}
            weights.append(weights_by_species[species].get(disease_name, 1.0))
        return torch.tensor(weights, dtype=torch.float) if weights else None
# ─── Main ────────────────────────────────────────────────────────────────────


def find_latest_checkpoint(stage_dir, stage_name):
    """
    Return the most recently written ``{stage_name}_epoch=NN.pt`` in *stage_dir*.

    The training loop saves checkpoints as
    ``CHECKPOINT_DIR / stage_name / f"{stage_name}_epoch={epoch:02d}.pt"``.
    Modification time is used, not the epoch number, so a newer but shorter
    run (e.g. with ``--epochs``) wins over an older, longer one.

    Args:
        stage_dir: Directory holding the stage's checkpoints (str or Path).
        stage_name: Stage name used as the filename prefix.

    Returns:
        Path of the newest matching checkpoint, or None if there is none.
    """
    stage_dir = Path(stage_dir)
    if not stage_dir.is_dir():
        return None
    candidates = [p for p in stage_dir.glob(f"{stage_name}_epoch=*.pt") if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: (p.stat().st_mtime, p.name))


def main():
    """Parse CLI arguments and run the requested training stage(s)."""
    parser = argparse.ArgumentParser(description="Hierarchical Swin-Tiny Training")
    parser.add_argument("--stage", choices=["species", "disease", "all"], default="all",
                        help="Training stage to run (default: all)")
    parser.add_argument("--resume", type=str, default=None,
                        help="Resume from checkpoint path")
    parser.add_argument("--batch-size", type=int, default=None,
                        help="Override batch size")
    parser.add_argument("--epochs", type=int, default=None,
                        help="Override number of epochs (overrides all stage epochs)")
    parser.add_argument("--device", type=str, default=None,
                        help="Override device (cuda, cpu)")
    parser.add_argument("--no-wandb", action="store_true",
                        help="Disable wandb logging")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load data and model, run 1 batch, then exit")
    args = parser.parse_args()

    # Setup config
    config = DEFAULT_CONFIG.copy()
    if args.batch_size:
        config["species_batch_size"] = args.batch_size
    if args.device:
        config["device"] = args.device
    if args.epochs:
        for key in ("species_epochs_warmup", "species_epochs_finetune",
                    "species_epochs_final", "disease_epochs_heads",
                    "disease_epochs_rare", "disease_epochs_e2e"):
            config[key] = args.epochs

    # Print GPU info (plus the HIP version on ROCm builds of PyTorch)
    if config["device"] == "cuda" and torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        if hasattr(torch.version, "hip") and torch.version.hip:
            print(f" ROCm version: {torch.version.hip}")

    # Create trainer, data and model (optionally restoring weights)
    trainer = HierarchicalTrainer(config)
    trainer.prepare_data()
    trainer.create_model()
    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Dry run
    if args.dry_run:
        print("\n── Dry run: 1 batch ──")
        batch = next(iter(trainer.train_loader))
        images = batch["image"].to(trainer.device)
        species_labels = batch["species_idx"].to(trainer.device)
        features, species_logits = trainer.model(images)
        loss = trainer.species_criterion(species_logits, species_labels)
        loss.backward()
        print(f" Forward + backward: ✓ (loss={loss.item():.4f})")
        print(f" Features shape: {features.shape}")
        print(f" Species logits shape: {species_logits.shape}")
        print(" Dry run complete!")
        return

    # Initialize wandb
    if HAS_WANDB and not args.no_wandb:
        wandb.init(
            project="plant-disease-hierarchical",
            config=config,
        )

    # Run stages
    if args.stage in ("species", "all"):
        trainer.train_species_stage(resume_from=args.resume if args.stage == "species" else None)

    if args.stage in ("disease", "all"):
        # 'all': the in-memory model already holds the finished Stage A weights.
        # 'disease' + --resume: those weights were loaded above.
        # 'disease' alone: start from the newest saved Stage A checkpoint.
        disease_resume = None
        if args.stage == "disease" and not args.resume:
            species_dir = CHECKPOINT_DIR / "species_final"
            species_checkpoint = find_latest_checkpoint(species_dir, "species_final")
            if species_checkpoint is None:
                print(f" [WARN] No Stage A checkpoint found in {species_dir}; "
                      f"disease training starts from untrained species weights")
            else:
                disease_resume = str(species_checkpoint)
        trainer.train_disease_stage(resume_from=disease_resume)

    # Close wandb
    if HAS_WANDB and wandb.run is not None:
        wandb.finish()

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()