#!/usr/bin/env python3
"""
Phase 1 — Dataset Reorganization for Hierarchical Model Training.

Reorganizes flat data/dataset/plant-disease-name/ directories into:

    data/organized/
        train/{species}/{disease}/
        val/{species}/{disease}/
        species_index.json
        class_hierarchy.json
        dataset_stats.json

Usage:
    python3 scripts/organize-dataset.py
"""

import json
import os
import random
from collections import Counter, defaultdict
from pathlib import Path

from PIL import Image
from joblib import Parallel, delayed
from tqdm import tqdm

# ─── Config ───────────────────────────────────────────────────────────────────

# Repository root: the parent of the directory that contains this script.
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
DATASET_DIR = BASE_DIR / "data" / "dataset"
ORGANIZED_DIR = BASE_DIR / "data" / "organized"
TRAIN_DIR = ORGANIZED_DIR / "train"
VAL_DIR = ORGANIZED_DIR / "val"

RANDOM_SEED = 42
TRAIN_RATIO = 0.85
VAL_RATIO = 1.0 - TRAIN_RATIO
MAX_DIM = 512          # longest output edge, in pixels
JPEG_QUALITY = 90
N_JOBS = 16            # joblib worker threads used per class

random.seed(RANDOM_SEED)

# Known disease-prefix words — words that start disease names but should NOT
# be part of a plant name. If a plant part ends with one of these, we know
# the split point is wrong.
DISEASE_PREFIX_WORDS = {
    "bacterial", "fungal", "viral", "downy", "powdery", "alternaria",
    "phytophthora", "phoma", "phymatotrichum", "pythium", "rhizoctonia",
    "sclerotinia", "fusarium", "verticillium", "cercospora", "septoria",
    "anthracnose", "black", "white", "gray", "brown", "green", "pink",
    "blue", "soft", "hard", "sour", "bitter", "southern", "northern",
    "common", "false", "true", "european", "american", "aspen",
    "bacterial-blight", "cercospora-leaf", "septoria-leaf", "alternaria-leaf",
}

# Valid multi-word plant suffixes (these CAN follow a hyphen in plant names).
# NOTE(review): this set is not referenced anywhere else in this module.
VALID_MULTI_WORD_PLANTS = {
    "squash", "bean", "berry", "apple", "fern", "tree", "vine", "cactus",
    "grass", "weed", "mint", "root", "seed", "leaf", "flower", "fruit",
    "bark", "wood", "nut", "pea", "lily", "rose", "moss", "palm", "fern",
    "orchid", "fig", "cress", "plant", "sage", "thyme", "leaf-fig",
    "nest-fern", "tongue", "tail", "ear", "eye", "nut-tree", "bean-tree",
}

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"}


# ─── Load KB Data ─────────────────────────────────────────────────────────────

def load_kb():
    """Load the knowledge-base JSON files from ``src/data``.

    Returns:
        A ``(plants, diseases)`` tuple holding the parsed contents of
        ``plants.json`` and ``diseases.json``.

    Raises:
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if either file is not valid JSON.
    """
    data_dir = BASE_DIR / "src" / "data"
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(data_dir / "plants.json", encoding="utf-8") as f:
        plants = json.load(f)
    with open(data_dir / "diseases.json", encoding="utf-8") as f:
        diseases = json.load(f)
    return plants, diseases


PLANTS, DISEASES = load_kb()
# Every KB plant entry must expose an "id" key; these ids seed the plant list.
KB_PLANT_IDS = {p["id"] for p in PLANTS}


def get_dataset_dirs():
    """Get all non-hidden subdirectories in the dataset folder, sorted by name."""
    return sorted(
        d for d in os.listdir(DATASET_DIR)
        if os.path.isdir(DATASET_DIR / d) and not d.startswith(".")
    )


def _list_images(path):
    """Return the sorted names of image files directly inside *path*.

    A file counts as an image when its lower-cased extension is in
    ``IMAGE_EXTS``. Returns an empty list when *path* is not a directory.
    """
    if not path.is_dir():
        return []
    return sorted(
        name for name in os.listdir(path)
        if os.path.isfile(path / name)
        and os.path.splitext(name)[1].lower() in IMAGE_EXTS
    )


def count_images(path):
    """Count image files in a directory (0 if the directory does not exist)."""
    return len(_list_images(path))


# ─── Phase 1: Parse directory names ────────────────────────────────────────────

def build_plant_and_disease_dictionaries(dirs):
    """
    Build verified plant names and disease suffixes from the dataset.

    Args:
        dirs: names of the class directories found under ``DATASET_DIR``.

    Returns:
        ``(parsed_dict, unmatched_list)`` where ``parsed_dict`` maps a source
        key to a ``(species, disease)`` tuple and ``unmatched_list`` holds the
        directory names that could not be parsed. A source key is either a
        directory name from *dirs* or ``"healthy/<plant>"`` for a subdirectory
        of the special ``healthy`` directory.
    """
    # Phase 1: Verify plant names from prefixes that appear with >=3 diseases
    plant_candidates = defaultdict(set)
    for d in dirs:
        parts = d.split("-")
        if len(parts) < 2:
            continue
        # Try every split point, allowing plant names of up to 5 words.
        for split in range(1, min(len(parts), 6)):
            plant = "-".join(parts[:split])
            disease = "-".join(parts[split:])
            if plant and disease and len(disease) > 2:
                plant_candidates[plant].add(disease)

    verified_plants = set(KB_PLANT_IDS)
    for plant, diseases in plant_candidates.items():
        if len(diseases) >= 3:
            verified_plants.add(plant)
    print(f" Verified plants: {len(verified_plants)} ({len(verified_plants & KB_PLANT_IDS)} from KB)")

    # Phase 2: Match dirs by plant prefix (longest plant first)
    sorted_plants = sorted(verified_plants, key=len, reverse=True)
    plant_matched = {}
    not_matched = []
    for d in dirs:
        matched = False
        for plant in sorted_plants:
            prefix = plant + "-"
            if d.startswith(prefix):
                disease = d[len(prefix):]
                # An empty remainder ("plant-") is not a match here; such
                # trailing-hyphen dirs are picked up in Phase 4.
                if disease:
                    plant_matched[d] = (plant, disease)
                    matched = True
                    break
        if matched:
            continue
        # Fallback: "<anything>-healthy". Require a non-empty plant so that a
        # directory literally named "-healthy" does not yield an empty species.
        plant = d[:-len("-healthy")] if d.endswith("-healthy") else ""
        if plant:
            plant_matched[d] = (plant, "healthy")
        else:
            not_matched.append(d)

    # Collect disease suffixes from Phase 2 matches
    disease_suffixes = {disease for _, disease in plant_matched.values()}
    print(f" Plant-matched dirs: {len(plant_matched)}, disease suffixes: {len(disease_suffixes)}")

    # Phase 3: Match remaining dirs by disease suffix (longest suffix first)
    sorted_disease_suffixes = sorted(disease_suffixes, key=len, reverse=True)
    still_not_matched = []
    for d in not_matched:
        matched = False
        for suffix in sorted_disease_suffixes:
            tail = "-" + suffix
            if d.endswith(tail):
                plant_part = d[:-len(tail)]
                if plant_part and not plant_part.endswith("-"):
                    plant_matched[d] = (plant_part, suffix)
                    matched = True
                    break
        if not matched:
            still_not_matched.append(d)
    print(f" Phase 3 matched: {len(not_matched) - len(still_not_matched)}")
    print(f" Phase 3 remaining: {len(still_not_matched)}")

    # Phase 4: Handle trailing-hyphen dirs and healthy parent dir
    final_unmatched = []
    for d in still_not_matched:
        if d.endswith("-"):
            plant = d[:-1]
            if plant:
                plant_matched[d] = (plant, "unlabeled")
            else:
                # A directory named just "-" carries no plant name; report it
                # instead of silently dropping it.
                final_unmatched.append(d)
        elif d == "healthy":
            healthy_dir = DATASET_DIR / "healthy"
            if healthy_dir.exists():
                plant_subdirs = [
                    s for s in os.listdir(healthy_dir)
                    if os.path.isdir(healthy_dir / s) and not s.startswith(".")
                ]
                for sub_plant in plant_subdirs:
                    # Use healthy/{sub_plant} as key so we know where to find the images
                    plant_matched[f"healthy/{sub_plant}"] = (sub_plant, "healthy")
                print(f" Healthy dir: {len(plant_subdirs)} per-plant healthy classes")
        else:
            final_unmatched.append(d)
    print(f" Phase 4 handled {len(still_not_matched) - len(final_unmatched)} edge cases")
    print(f" Final unmatched: {len(final_unmatched)}")
    if final_unmatched:
        print(f" E.g.: {final_unmatched[:10]}")

    # Phase 5: Post-processing — fix species names that ate disease-prefix words
    fix_count = 0
    for d in list(plant_matched):
        if d.startswith("healthy/"):
            continue  # Skip healthy subdirs — these are correct
        species, disease = plant_matched[d]
        parts = species.split("-")
        if len(parts) >= 2 and parts[-1] in DISEASE_PREFIX_WORDS:
            # Move the last word from species to disease
            new_species = "-".join(parts[:-1])
            new_disease = parts[-1] + "-" + disease
            plant_matched[d] = (new_species, new_disease)
            fix_count += 1
    print(f" Post-process fixes (species ending with disease-prefix): {fix_count}")

    return plant_matched, final_unmatched


# ─── Image Processing ────────────────────────────────────────────────────────

def process_image(args):
    """Resize and convert a single image to 512px max JPEG q90.

    Args:
        args: a ``(src_path, dst_path)`` pair (packed into one argument so the
            function can be mapped over a list of pairs).

    Returns:
        ``(src_path, True, None)`` on success, or
        ``(src_path, False, error_message)`` if anything went wrong. Errors
        are returned rather than raised so one bad image does not abort the
        whole batch.
    """
    src_path, dst_path = args
    try:
        # The context manager closes the source file handle deterministically;
        # the image must stay open until save() because loading is lazy.
        with Image.open(src_path) as src_img:
            img = src_img if src_img.mode == "RGB" else src_img.convert("RGB")
            w, h = img.size
            longest = max(w, h)
            if longest > MAX_DIM:
                ratio = MAX_DIM / longest
                # Clamp to 1px: int() truncation would otherwise give a zero
                # dimension for extremely elongated images.
                new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
                img = img.resize(new_size, Image.LANCZOS)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            img.save(dst_path, "JPEG", quality=JPEG_QUALITY, optimize=True)
        return (src_path, True, None)
    except Exception as e:  # worker boundary: report, never raise
        return (src_path, False, str(e))


def copy_and_split_class(src_dir, dst_train_dir, dst_val_dir,
                         train_ratio=TRAIN_RATIO, name_prefix="img"):
    """
    Copy images from src_dir to train/val dirs, splitting at the IMAGE level.

    Args:
        src_dir: ``Path`` of the source directory (regular dir or healthy subdir).
        dst_train_dir: ``Path`` receiving the training share.
        dst_val_dir: ``Path`` receiving the validation share.
        train_ratio: fraction of images assigned to training; at least one
            image always goes to training.
        name_prefix: output files are named ``{name_prefix}_{index:04d}.jpg``.
            Pass distinct prefixes when several sources feed the same
            destination so they do not overwrite each other.

    Returns:
        ``(train_processed, train_failed, val_processed, val_failed)``.
    """
    src_files = _list_images(src_dir)
    if not src_files:
        return (0, 0, 0, 0)

    # Shuffle with an RNG seeded per source directory rather than the global
    # one: the split then does not depend on processing order or on how many
    # classes were skipped earlier in a resumed run.
    rng = random.Random(f"{RANDOM_SEED}:{src_dir.parent.name}/{src_dir.name}")
    rng.shuffle(src_files)

    # Split files at IMAGE level
    split_idx = max(1, int(len(src_files) * train_ratio))
    train_files = src_files[:split_idx]
    val_files = src_files[split_idx:]

    train_pairs = [
        (str(src_dir / f), str(dst_train_dir / f"{name_prefix}_{i:04d}.jpg"))
        for i, f in enumerate(train_files)
    ]
    val_pairs = [
        (str(src_dir / f), str(dst_val_dir / f"{name_prefix}_{i:04d}.jpg"))
        for i, f in enumerate(val_files)
    ]

    results = Parallel(n_jobs=N_JOBS, prefer="threads")(
        delayed(process_image)(pair) for pair in train_pairs + val_pairs
    )

    # Results come back in submission order: train pairs first, then val.
    train_results = results[:len(train_pairs)]
    val_results = results[len(train_pairs):]
    train_ok = sum(1 for _, ok, _ in train_results if ok)
    val_ok = sum(1 for _, ok, _ in val_results if ok)
    return (
        train_ok,
        len(train_results) - train_ok,
        val_ok,
        len(val_results) - val_ok,
    )


# ─── Build Metadata ──────────────────────────────────────────────────────────

def build_metadata(parsed, train_counts, val_counts, unmatched):
    """Build species_index.json, class_hierarchy.json, dataset_stats.json.

    Args:
        parsed: mapping of source key -> ``(species, disease)``.
        train_counts: list of ``(species, disease, count)`` for training images.
        val_counts: list of ``(species, disease, count)`` for validation images.
        unmatched: list of directory names that could not be parsed.

    Returns:
        ``(species_index, class_hierarchy, stats)`` — three JSON-serialisable
        dicts.
    """
    species_disease_map = defaultdict(set)
    for species, disease in parsed.values():
        species_disease_map[species].add(disease)

    species_index = {
        species: sorted(diseases)
        for species, diseases in sorted(species_disease_map.items())
    }
    # Several source dirs may map to the same (species, disease) pair, so
    # count unique pairs instead of len(parsed).
    num_classes = sum(len(diseases) for diseases in species_index.values())

    class_hierarchy = {
        "version": "1.0",
        "description": "Hierarchical plant disease classification dataset",
        "num_species": len(species_index),
        "num_classes": num_classes,
        "species": {
            species: sorted(diseases)
            for species, diseases in species_index.items()
        },
    }

    # Aggregate counts
    total_train = sum(cnt for _, _, cnt in train_counts)
    total_val = sum(cnt for _, _, cnt in val_counts)
    total_all = total_train + total_val

    species_disease_counts = defaultdict(lambda: defaultdict(int))
    for sp, di, cnt in list(train_counts) + list(val_counts):
        species_disease_counts[sp][di] += cnt

    # Per-class totals (train + val merged). Using the raw train/val entries
    # would treat each class as two half-sized classes and skew min/mean/median.
    per_class = sorted(
        cnt
        for diseases in species_disease_counts.values()
        for cnt in diseases.values()
    )

    stats = {
        "total_images": total_all,
        "total_species": len(species_index),
        "total_classes": num_classes,
        "train_images": total_train,
        "val_images": total_val,
        "images_per_class": {
            "min": per_class[0] if per_class else 0,
            "max": per_class[-1] if per_class else 0,
            "mean": round(sum(per_class) / len(per_class)) if per_class else 0,
            # Upper-middle element for even-length lists.
            "median": per_class[len(per_class) // 2] if per_class else 0,
        },
        "train_pct": round(total_train / total_all * 100, 1) if total_all else 0,
        "val_pct": round(total_val / total_all * 100, 1) if total_all else 0,
        "unmatched_dirs": len(unmatched),
        "unmatched_dir_names": unmatched[:100] if unmatched else [],
        "species_disease_counts": {
            species: dict(diseases)
            for species, diseases in species_disease_counts.items()
        },
    }
    return species_index, class_hierarchy, stats


# ─── Main Pipeline ───────────────────────────────────────────────────────────

def main():
    """Run the full pipeline: scan, parse, resize/split, write metadata.

    Returns:
        The ``stats`` dict that is also written to ``dataset_stats.json``.
    """
    print("=" * 60)
    print("Phase 1 — Dataset Reorganization")
    print("=" * 60)
    print(f"Dataset: {DATASET_DIR}")
    print(f"Output: {ORGANIZED_DIR}")
    print()

    # Step 1: Scan
    print("─" * 40)
    print("Step 1: Scanning dataset directories...")
    print("─" * 40)
    dirs = get_dataset_dirs()
    print(f" Found {len(dirs)} class directories")

    # Step 2: Parse directory names into (species, disease) pairs
    print()
    print("─" * 40)
    print("Step 2: Parsing directory names...")
    print("─" * 40)
    parsed, unmatched = build_plant_and_disease_dictionaries(dirs)
    species_set = {s for s, _ in parsed.values()}
    disease_set = {d for _, d in parsed.values()}
    raw_classes = len(parsed)
    unique_classes = len(set(parsed.values()))
    print(f"\n Parsed: {raw_classes} entries")
    print(f" Unique species: {len(species_set)}")
    print(f" Unique disease labels: {len(disease_set)}")
    print(f" Unique (species, disease) pairs: {unique_classes}")

    # Step 3: Process images with image-level train/val split
    print()
    print("─" * 40)
    print("Step 3: Processing images (resize + train/val split)...")
    print(f" Max dimension: {MAX_DIM}px, JPEG q{JPEG_QUALITY}")
    print(f" Workers: {N_JOBS}")
    print(f" Split: {TRAIN_RATIO*100:.0f}/{VAL_RATIO*100:.0f} (image-level)")
    print("─" * 40)

    # Group source directories by destination class. Several sources can map
    # to the same (species, disease) — e.g. "x-healthy" and "healthy/x" — and
    # must be merged instead of overwriting or skipping one another.
    dir_set = set(dirs)
    sources_by_class = defaultdict(list)
    regular_count = 0
    healthy_count = 0
    for key, (species, disease) in parsed.items():
        if key.startswith("healthy/"):
            healthy_count += 1      # key is like "healthy/tomato"
        elif key in dir_set:
            regular_count += 1
        else:
            continue
        sources_by_class[(species, disease)].append(DATASET_DIR / key)

    print(f"\n Processing {regular_count} disease + {healthy_count} healthy classes...")

    train_counts = []  # (species, disease, count)
    val_counts = []
    total_skipped = 0
    total_failed = 0

    for (species, disease), src_dirs in tqdm(sources_by_class.items(), desc=" Classes"):
        dst_train = TRAIN_DIR / species / disease
        dst_val = VAL_DIR / species / disease
        expected = sum(count_images(src) for src in src_dirs)

        # Skip if already done, but still record the existing output so that
        # the stats written at the end cover the whole dataset on a re-run.
        if dst_train.exists() and dst_val.exists():
            done_train = count_images(dst_train)
            done_val = count_images(dst_val)
            if done_train + done_val >= expected:
                total_skipped += expected
                train_counts.append((species, disease, done_train))
                val_counts.append((species, disease, done_val))
                continue

        class_train = 0
        class_val = 0
        for idx, src_dir in enumerate(src_dirs):
            # First source keeps the plain "img" prefix; additional sources
            # get their own prefix so file names cannot collide.
            prefix = "img" if idx == 0 else f"img_s{idx}"
            tr_ok, tr_fail, va_ok, va_fail = copy_and_split_class(
                src_dir, dst_train, dst_val, name_prefix=prefix
            )
            class_train += tr_ok
            class_val += va_ok
            total_failed += tr_fail + va_fail
        train_counts.append((species, disease, class_train))
        val_counts.append((species, disease, class_val))

    total_train = sum(c for _, _, c in train_counts)
    total_val = sum(c for _, _, c in val_counts)
    print(f"\n Train images: {total_train:,}")
    print(f" Val images: {total_val:,}")
    print(f" Skipped previously processed: {total_skipped:,}")
    print(f" Failed images: {total_failed:,}")

    # Step 4: Build metadata
    print()
    print("─" * 40)
    print("Step 4: Building metadata files...")
    print("─" * 40)
    ORGANIZED_DIR.mkdir(parents=True, exist_ok=True)
    species_index, class_hierarchy, stats = build_metadata(
        parsed, train_counts, val_counts, unmatched
    )

    with open(ORGANIZED_DIR / "species_index.json", "w", encoding="utf-8") as f:
        json.dump(species_index, f, indent=2)
    print(f" ✓ species_index.json ({len(species_index)} species)")

    with open(ORGANIZED_DIR / "class_hierarchy.json", "w", encoding="utf-8") as f:
        json.dump(class_hierarchy, f, indent=2)
    print(" ✓ class_hierarchy.json")

    with open(ORGANIZED_DIR / "dataset_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print(" ✓ dataset_stats.json")

    # Summary
    print()
    print("=" * 60)
    print("Done!")
    print("=" * 60)
    print(f" Total images: {stats['total_images']:,}")
    print(f" Species: {stats['total_species']}")
    print(f" Classes: {stats['total_classes']}")
    print(f" Train: {stats['train_images']:,} ({stats['train_pct']}%)")
    print(f" Val: {stats['val_images']:,} ({stats['val_pct']}%)")
    print(f" Unmatched dirs: {stats['unmatched_dirs']}")
    print(f" Train dir: {TRAIN_DIR}")
    print(f" Val dir: {VAL_DIR}")

    if stats['unmatched_dirs'] > 0:
        print(f"\n ⚠ Manual review needed for {stats['unmatched_dirs']} dirs:")
        for u in stats['unmatched_dir_names'][:20]:
            print(f" {u}")

    return stats


if __name__ == "__main__":
    main()