Checkpoint made before torch.compile
This commit is contained in:
471
scripts/organize-dataset.py
Normal file
471
scripts/organize-dataset.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Phase 1 — Dataset Reorganization for Hierarchical Model Training.
|
||||
|
||||
Reorganizes flat data/dataset/plant-disease-name/ directories into:
|
||||
data/organized/
|
||||
train/{species}/{disease}/
|
||||
val/{species}/{disease}/
|
||||
species_index.json
|
||||
class_hierarchy.json
|
||||
dataset_stats.json
|
||||
|
||||
Usage: python3 scripts/organize-dataset.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from joblib import Parallel, delayed
|
||||
from tqdm import tqdm
|
||||
|
||||
# ─── Config ───────────────────────────────────────────────────────────────────

# Repository root: two directory levels above this script file.
BASE_DIR = Path(os.path.abspath(__file__)).parent.parent

# Input (flat class folders) and output (train/val tree) locations.
DATASET_DIR = BASE_DIR.joinpath("data", "dataset")
ORGANIZED_DIR = BASE_DIR.joinpath("data", "organized")
TRAIN_DIR = ORGANIZED_DIR.joinpath("train")
VAL_DIR = ORGANIZED_DIR.joinpath("val")

# Train/val split settings; the module-level RNG is seeded below.
RANDOM_SEED = 42
TRAIN_RATIO = 0.85
VAL_RATIO = 1.0 - TRAIN_RATIO

# Image conversion settings: longest side cap, JPEG quality, worker count.
MAX_DIM = 512
JPEG_QUALITY = 90
N_JOBS = 16

random.seed(RANDOM_SEED)

# Words that begin a disease name and therefore must not be the last word of
# a plant name. When a parsed plant name ends with one of them, the
# plant/disease split point was chosen one word too late.
# NOTE(review): the hyphenated entries can never equal a single hyphen-split
# token, which is how this set is consulted in this file — confirm whether
# they are still needed.
DISEASE_PREFIX_WORDS = {
    "alternaria",
    "alternaria-leaf",
    "american",
    "anthracnose",
    "aspen",
    "bacterial",
    "bacterial-blight",
    "bitter",
    "black",
    "blue",
    "brown",
    "cercospora",
    "cercospora-leaf",
    "common",
    "downy",
    "european",
    "false",
    "fungal",
    "fusarium",
    "gray",
    "green",
    "hard",
    "northern",
    "phoma",
    "phymatotrichum",
    "phytophthora",
    "pink",
    "powdery",
    "pythium",
    "rhizoctonia",
    "sclerotinia",
    "septoria",
    "septoria-leaf",
    "soft",
    "sour",
    "southern",
    "true",
    "verticillium",
    "viral",
    "white",
}

# Words that are allowed to follow a hyphen inside a multi-word plant name.
# NOTE(review): not referenced anywhere else in the visible part of this file.
VALID_MULTI_WORD_PLANTS = {
    "apple",
    "bark",
    "bean",
    "bean-tree",
    "berry",
    "cactus",
    "cress",
    "ear",
    "eye",
    "fern",
    "fig",
    "flower",
    "fruit",
    "grass",
    "leaf",
    "leaf-fig",
    "lily",
    "mint",
    "moss",
    "nest-fern",
    "nut",
    "nut-tree",
    "orchid",
    "palm",
    "pea",
    "plant",
    "root",
    "rose",
    "sage",
    "seed",
    "squash",
    "tail",
    "thyme",
    "tongue",
    "tree",
    "vine",
    "weed",
    "wood",
}

# Lower-cased file extensions that are treated as images.
IMAGE_EXTS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
    ".webp",
}
|
||||
|
||||
# ─── Load KB Data ─────────────────────────────────────────────────────────────
|
||||
|
||||
def load_kb():
    """Load the plant and disease knowledge-base JSON files.

    Reads ``plants.json`` and ``diseases.json`` from ``BASE_DIR/src/data``.

    Returns:
        A ``(plants, diseases)`` tuple holding whatever ``json.load`` decodes
        from the two files.

    Raises:
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if either file is not valid JSON.
    """
    kb_dir = BASE_DIR / "src" / "data"

    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so non-ASCII names load the same way on every machine.
    with open(kb_dir / "plants.json", encoding="utf-8") as f:
        plants = json.load(f)
    with open(kb_dir / "diseases.json", encoding="utf-8") as f:
        diseases = json.load(f)

    return plants, diseases
|
||||
|
||||
# Knowledge-base tables, read once when the module is loaded.
PLANTS, DISEASES = load_kb()

# The "id" value of every plant record in the knowledge base.
KB_PLANT_IDS = {record["id"] for record in PLANTS}
|
||||
|
||||
def get_dataset_dirs():
    """Return the sorted names of every non-hidden folder inside DATASET_DIR."""
    names = []
    for entry in os.listdir(DATASET_DIR):
        # Dot-prefixed entries are treated as hidden and ignored.
        if entry.startswith("."):
            continue
        if os.path.isdir(DATASET_DIR / entry):
            names.append(entry)
    names.sort()
    return names
|
||||
|
||||
def count_images(path):
    """Return how many regular files in ``path`` have an image extension.

    A path that does not exist counts as zero images. Extensions are compared
    case-insensitively against IMAGE_EXTS.
    """
    if not path.exists():
        return 0

    total = 0
    for name in os.listdir(path):
        _, ext = os.path.splitext(name)
        if os.path.isfile(path / name) and ext.lower() in IMAGE_EXTS:
            total += 1
    return total
|
||||
|
||||
# ─── Phase 1: Parse directory names ────────────────────────────────────────────
|
||||
|
||||
def build_plant_and_disease_dictionaries(dirs):
    """
    Split dataset directory names into (species, disease) pairs.

    Directory names are hyphen-joined "plant-disease" strings with no marker
    for where the plant name ends, so the split point is inferred in phases:

      1. Collect candidate plant prefixes; a prefix seen with at least three
         different remainders is accepted as a plant, alongside all KB ids.
      2. Match each directory against the accepted plants, longest first.
         Leftover "<plant>-healthy" names are split on the final "-healthy".
      3. Match what is still left by known disease suffix, longest first.
      4. Handle trailing-hyphen names ("<plant>-" -> "unlabeled") and the
         special "healthy" parent folder, whose subfolders are plants.
      5. If a species name ends with a word from DISEASE_PREFIX_WORDS, move
         that word to the front of the disease name.

    Args:
        dirs: iterable of directory names found in DATASET_DIR.

    Returns:
        (parsed_dict, unmatched_list). ``parsed_dict`` maps a directory key
        to ``(species, disease)``; keys are the directory name itself, or
        ``"healthy/<plant>"`` for subfolders of the "healthy" parent.
        ``unmatched_list`` holds every directory that produced no mapping.
    """
    # Phase 1: Verify plant names from prefixes that appear with >=3 diseases
    plant_candidates = defaultdict(set)
    for d in dirs:
        parts = d.split("-")
        if len(parts) < 2:
            continue
        # Try every split point, but never treat more than 5 words as a plant.
        for split in range(1, min(len(parts), 6)):
            plant = "-".join(parts[:split])
            disease = "-".join(parts[split:])
            if plant and disease and len(disease) > 2:
                plant_candidates[plant].add(disease)

    verified_plants = set(KB_PLANT_IDS)
    verified_plants.update(
        plant for plant, suffixes in plant_candidates.items()
        if len(suffixes) >= 3
    )

    print(f" Verified plants: {len(verified_plants)} ({len(verified_plants & KB_PLANT_IDS)} from KB)")

    # Phase 2: Match dirs by plant prefix (longest plant first)
    sorted_plants = sorted(verified_plants, key=len, reverse=True)
    plant_matched = {}
    not_matched = []

    for d in dirs:
        matched = False
        for plant in sorted_plants:
            prefix = plant + "-"
            if d.startswith(prefix):
                disease = d[len(prefix):]
                if disease:
                    plant_matched[d] = (plant, disease)
                    matched = True
                    break
        if not matched:
            if d.endswith("-healthy"):
                plant = d[:-len("-healthy")]
                plant_matched[d] = (plant, "healthy")
            else:
                not_matched.append(d)

    # Collect disease suffixes from Phase 2 matches
    disease_suffixes = {disease for _, disease in plant_matched.values()}
    print(f" Plant-matched dirs: {len(plant_matched)}, disease suffixes: {len(disease_suffixes)}")

    # Phase 3: Match remaining dirs by disease suffix (longest suffix first)
    sorted_disease_suffixes = sorted(disease_suffixes, key=len, reverse=True)
    still_not_matched = []

    for d in not_matched:
        matched = False
        for suffix in sorted_disease_suffixes:
            tail = "-" + suffix
            if d.endswith(tail):
                plant_part = d[:-len(tail)]
                if plant_part and not plant_part.endswith("-"):
                    plant_matched[d] = (plant_part, suffix)
                    matched = True
                    break
        if not matched:
            still_not_matched.append(d)

    print(f" Phase 3 matched: {len(not_matched) - len(still_not_matched)}")
    print(f" Phase 3 remaining: {len(still_not_matched)}")

    # Phase 4: Handle trailing-hyphen dirs and healthy parent dir.
    # Anything that yields no mapping here is reported as unmatched, so no
    # directory disappears from both the parsed dict and the unmatched list.
    final_unmatched = []
    for d in still_not_matched:
        if d.endswith("-"):
            plant = d[:-1]
            if plant:
                plant_matched[d] = (plant, "unlabeled")
            else:
                # The name is just "-": there is no plant to attach it to.
                final_unmatched.append(d)
        elif d == "healthy":
            healthy_dir = DATASET_DIR / "healthy"
            plant_subdirs = []
            if healthy_dir.exists():
                plant_subdirs = [
                    s for s in os.listdir(healthy_dir)
                    if os.path.isdir(healthy_dir / s) and not s.startswith(".")
                ]
                for sub_plant in plant_subdirs:
                    # Use healthy/{sub_plant} as key so we know where to find the images
                    plant_matched[f"healthy/{sub_plant}"] = (sub_plant, "healthy")
                print(f" Healthy dir: {len(plant_subdirs)} per-plant healthy classes")
            if not plant_subdirs:
                # No per-plant subfolders, so nothing was mapped.
                final_unmatched.append(d)
        else:
            final_unmatched.append(d)

    print(f" Phase 4 handled {len(still_not_matched) - len(final_unmatched)} edge cases")
    print(f" Final unmatched: {len(final_unmatched)}")
    if final_unmatched:
        print(f" E.g.: {final_unmatched[:10]}")

    # Phase 5: Post-processing — fix species names that ate disease-prefix words
    fix_count = 0
    # Iterate over a snapshot because values are rewritten inside the loop.
    for d, (species, disease) in list(plant_matched.items()):
        if d.startswith("healthy/"):
            continue  # Skip healthy subdirs — these are correct
        parts = species.split("-")
        if len(parts) >= 2 and parts[-1] in DISEASE_PREFIX_WORDS:
            # Move the last word from species to disease
            new_species = "-".join(parts[:-1])
            new_disease = parts[-1] + "-" + disease
            plant_matched[d] = (new_species, new_disease)
            fix_count += 1

    print(f" Post-process fixes (species ending with disease-prefix): {fix_count}")

    return plant_matched, final_unmatched
|
||||
|
||||
# ─── Image Processing ────────────────────────────────────────────────────────
|
||||
|
||||
def process_image(args):
    """Convert one image to an RGB JPEG whose longest side is at most MAX_DIM.

    Designed to run as a parallel worker: it never raises for a bad image and
    reports the failure in its return value instead.

    Args:
        args: a ``(src_path, dst_path)`` pair of path strings.

    Returns:
        ``(src_path, True, None)`` on success, or
        ``(src_path, False, error_message)`` if anything went wrong.
    """
    src_path, dst_path = args
    try:
        # The context manager closes the source file handle promptly; without
        # it, handles stay open until garbage collection.
        with Image.open(src_path) as source:
            img = source if source.mode == "RGB" else source.convert("RGB")

            w, h = img.size
            longest = max(w, h)
            if longest > MAX_DIM:
                ratio = MAX_DIM / longest
                # Clamp to 1px so extreme aspect ratios never yield a
                # zero-sized (invalid) target dimension.
                new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
                img = img.resize(new_size, Image.LANCZOS)

            # A bare filename has an empty dirname, which makedirs rejects.
            dst_dir = os.path.dirname(dst_path)
            if dst_dir:
                os.makedirs(dst_dir, exist_ok=True)
            img.save(dst_path, "JPEG", quality=JPEG_QUALITY, optimize=True)
        return (src_path, True, None)
    except Exception as e:
        # Deliberate best-effort boundary: one unreadable image must not
        # abort the whole batch, so the error is returned to the caller.
        return (src_path, False, str(e))
|
||||
|
||||
def copy_and_split_class(src_dir, dst_train_dir, dst_val_dir, train_ratio=TRAIN_RATIO):
    """
    Convert one class folder's images into its train and val output folders.

    The image files of ``src_dir`` are shuffled and divided per image (not per
    folder). The first ``train_ratio`` share, and always at least one image,
    goes to train; the remainder goes to val. Outputs are renamed
    ``img_NNNN.jpg``, numbered separately within each split.

    Returns (train_processed, train_failed, val_processed, val_failed).
    """
    nothing_done = (0, 0, 0, 0)

    if not src_dir.exists():
        return nothing_done

    image_names = [
        name for name in os.listdir(src_dir)
        if os.path.isfile(src_dir / name)
        and os.path.splitext(name)[1].lower() in IMAGE_EXTS
    ]
    # Sort before shuffling so the shuffle input does not depend on the
    # order in which the filesystem lists the files.
    image_names.sort()
    if not image_names:
        return nothing_done

    random.shuffle(image_names)
    cut = max(1, int(len(image_names) * train_ratio))

    def _jobs(names, dst_dir):
        # One (source path, destination path) pair per image.
        return [
            (str(src_dir / name), str(dst_dir / f"img_{idx:04d}.jpg"))
            for idx, name in enumerate(names)
        ]

    train_jobs = _jobs(image_names[:cut], dst_train_dir)
    val_jobs = _jobs(image_names[cut:], dst_val_dir)

    outcomes = Parallel(n_jobs=N_JOBS, prefer="threads")(
        delayed(process_image)(job) for job in train_jobs + val_jobs
    )

    # Results come back in submission order: train jobs first, then val jobs.
    n_train_jobs = len(train_jobs)
    tallies = []
    for chunk in (outcomes[:n_train_jobs], outcomes[n_train_jobs:]):
        succeeded = sum(1 for _, ok, _ in chunk if ok)
        failed = sum(1 for _, ok, _ in chunk if not ok)
        tallies.append((succeeded, failed))

    (train_ok, train_fail), (val_ok, val_fail) = tallies
    return (train_ok, train_fail, val_ok, val_fail)
|
||||
|
||||
# ─── Build Metadata ──────────────────────────────────────────────────────────
|
||||
|
||||
def build_metadata(parsed, train_counts, val_counts, unmatched):
    """Build the contents of species_index.json, class_hierarchy.json and
    dataset_stats.json.

    Args:
        parsed: dict mapping a directory key to ``(species, disease)``.
        train_counts: list of ``(species, disease, image_count)`` rows for
            the train split.
        val_counts: the same kind of rows for the val split.
        unmatched: list of directory names that could not be parsed.

    Returns:
        ``(species_index, class_hierarchy, stats)``, three JSON-serialisable
        dicts. A "class" is a unique (species, disease) pair throughout, and
        the ``images_per_class`` figures use each class's train + val total.
    """
    import statistics

    species_disease_map = defaultdict(set)
    for species, disease in parsed.values():
        species_disease_map[species].add(disease)
    species_index = {k: sorted(v) for k, v in sorted(species_disease_map.items())}

    # Several directory keys can map to the same pair, so classes are counted
    # by unique pair rather than by number of parsed entries.
    parsed_classes = {(sp, di) for sp, di in parsed.values()}

    class_hierarchy = {
        "version": "1.0",
        "description": "Hierarchical plant disease classification dataset",
        "num_species": len(species_index),
        "num_classes": len(parsed_classes),
        "species": {species: list(diseases) for species, diseases in species_index.items()},
    }

    # Aggregate counts
    total_train = sum(cnt for _, _, cnt in train_counts)
    total_val = sum(cnt for _, _, cnt in val_counts)
    total_all = total_train + total_val

    species_disease_counts = defaultdict(lambda: defaultdict(int))
    for sp, di, cnt in list(train_counts) + list(val_counts):
        species_disease_counts[sp][di] += cnt

    # One number per class (train + val combined). Using the raw per-split
    # rows here would report a class's val share as if it were a whole class.
    per_class = [
        cnt
        for diseases in species_disease_counts.values()
        for cnt in diseases.values()
    ]

    stats = {
        "total_images": total_all,
        "total_species": len(species_index),
        "total_classes": len(parsed_classes),
        "train_images": total_train,
        "val_images": total_val,
        "images_per_class": {
            "min": min(per_class) if per_class else 0,
            "max": max(per_class) if per_class else 0,
            "mean": round(sum(per_class) / len(per_class)) if per_class else 0,
            # True median: averages the two middle values for an even count.
            "median": statistics.median(per_class) if per_class else 0,
        },
        "train_pct": round(total_train / total_all * 100, 1) if total_all else 0,
        "val_pct": round(total_val / total_all * 100, 1) if total_all else 0,
        "unmatched_dirs": len(unmatched),
        # Cap the stored name list so the stats file stays small.
        "unmatched_dir_names": unmatched[:100] if unmatched else [],
        "species_disease_counts": {
            species: dict(diseases) for species, diseases in species_disease_counts.items()
        },
    }

    return species_index, class_hierarchy, stats
|
||||
|
||||
# ─── Main Pipeline ───────────────────────────────────────────────────────────
|
||||
|
||||
def _organize_class(src_dir, species, disease):
    """Convert and split one source class folder, or reuse earlier output.

    Returns ``(reused, train_count, val_count, skipped_count, failed_count)``.
    When the output folders already hold at least as many files as the source
    has images, nothing is reprocessed: ``reused`` is True and the train/val
    counts are read back from the existing output folders.
    """
    dst_train = TRAIN_DIR / species / disease
    dst_val = VAL_DIR / species / disease
    n_source = count_images(src_dir)

    # Skip if already done (compare output file count with source image count)
    if dst_train.exists() and dst_val.exists() and \
            len(os.listdir(dst_train)) + len(os.listdir(dst_val)) >= n_source:
        return True, count_images(dst_train), count_images(dst_val), n_source, 0

    tr_ok, tr_fail, va_ok, va_fail = copy_and_split_class(src_dir, dst_train, dst_val)
    return False, tr_ok, va_ok, 0, tr_fail + va_fail


def main():
    """Run the full reorganisation pipeline and return the stats dict.

    Steps: scan DATASET_DIR, parse directory names into (species, disease)
    pairs, convert and split the images into TRAIN_DIR / VAL_DIR, then write
    the three metadata JSON files into ORGANIZED_DIR. Progress is printed to
    stdout.
    """
    print("=" * 60)
    print("Phase 1 — Dataset Reorganization")
    print("=" * 60)
    print(f"Dataset: {DATASET_DIR}")
    print(f"Output: {ORGANIZED_DIR}")
    print()

    # Step 1: Scan
    print("─" * 40)
    print("Step 1: Scanning dataset directories...")
    print("─" * 40)
    dirs = get_dataset_dirs()
    print(f" Found {len(dirs)} class directories")

    # Step 2: Parse directory names into (species, disease) pairs
    print()
    print("─" * 40)
    print("Step 2: Parsing directory names...")
    print("─" * 40)
    parsed, unmatched = build_plant_and_disease_dictionaries(dirs)

    species_set = {s for s, _ in parsed.values()}
    disease_set = {d for _, d in parsed.values()}
    raw_classes = len(parsed)
    unique_classes = len({(s, d) for s, d in parsed.values()})
    print(f"\n Parsed: {raw_classes} entries")
    print(f" Unique species: {len(species_set)}")
    print(f" Unique disease labels: {len(disease_set)}")
    print(f" Unique (species, disease) pairs: {unique_classes}")

    # Step 3: Process images with image-level train/val split
    print()
    print("─" * 40)
    print("Step 3: Processing images (resize + train/val split)...")
    print(f" Max dimension: {MAX_DIM}px, JPEG q{JPEG_QUALITY}")
    print(f" Workers: {N_JOBS}")
    print(f" Split: {TRAIN_RATIO*100:.0f}/{VAL_RATIO*100:.0f} (image-level)")
    print("─" * 40)

    train_counts = []  # (species, disease, count)
    val_counts = []
    total_skipped = 0
    total_failed = 0

    # A set makes the membership test O(1) instead of scanning the list.
    known_dirs = set(dirs)
    regular_items = [(d, sp, di) for d, (sp, di) in parsed.items()
                     if not d.startswith("healthy/") and d in known_dirs]
    healthy_items = [(d, sp, di) for d, (sp, di) in parsed.items()
                     if d.startswith("healthy/")]

    # Organize healthy items by plant
    healthy_by_plant = {}
    for d, sp, di in healthy_items:
        healthy_by_plant[sp] = d  # d is like "healthy/tomato"

    print(f"\n Processing {len(regular_items)} disease + {len(healthy_items)} healthy classes...")

    # Both kinds of class go through the same convert-or-reuse step; only the
    # source folder and the disease label differ.
    batches = (
        (" Disease classes",
         [(DATASET_DIR / d, sp, di) for d, sp, di in regular_items]),
        (" Healthy classes",
         [(DATASET_DIR / hkey, sp, "healthy") for sp, hkey in healthy_by_plant.items()]),
    )

    # NOTE(review): two source dirs that map to the same (species, disease)
    # pair write the same img_NNNN.jpg names into the same output folders, so
    # the later one overwrites or is skipped — confirm whether the dataset
    # contains such duplicates.
    counted = set()
    for label, items in batches:
        for src_dir, species, disease in tqdm(items, desc=label):
            reused, tr_n, va_n, skipped, failed = _organize_class(src_dir, species, disease)
            total_skipped += skipped
            total_failed += failed

            # Reused output still belongs in the statistics; leaving it out
            # would make a resumed run under-report the dataset. It is added
            # once per class so a shared output folder is not double counted.
            if reused and (species, disease) in counted:
                continue
            counted.add((species, disease))
            train_counts.append((species, disease, tr_n))
            val_counts.append((species, disease, va_n))

    total_train = sum(c for _, _, c in train_counts)
    total_val = sum(c for _, _, c in val_counts)
    print(f"\n Train images: {total_train:,}")
    print(f" Val images: {total_val:,}")
    print(f" Skipped previously processed: {total_skipped:,}")
    if total_failed:
        # Surface conversion failures instead of discarding them silently.
        print(f" Failed images: {total_failed:,}")

    # Step 4: Build metadata
    print()
    print("─" * 40)
    print("Step 4: Building metadata files...")
    print("─" * 40)
    ORGANIZED_DIR.mkdir(parents=True, exist_ok=True)

    species_index, class_hierarchy, stats = build_metadata(
        parsed, train_counts, val_counts, unmatched
    )

    with open(ORGANIZED_DIR / "species_index.json", "w", encoding="utf-8") as f:
        json.dump(species_index, f, indent=2)
    print(f" ✓ species_index.json ({len(species_index)} species)")

    with open(ORGANIZED_DIR / "class_hierarchy.json", "w", encoding="utf-8") as f:
        json.dump(class_hierarchy, f, indent=2)
    print(" ✓ class_hierarchy.json")

    with open(ORGANIZED_DIR / "dataset_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print(" ✓ dataset_stats.json")

    # Summary
    print()
    print("=" * 60)
    print("Done!")
    print("=" * 60)
    print(f" Total images: {stats['total_images']:,}")
    print(f" Species: {stats['total_species']}")
    print(f" Classes: {stats['total_classes']}")
    print(f" Train: {stats['train_images']:,} ({stats['train_pct']}%)")
    print(f" Val: {stats['val_images']:,} ({stats['val_pct']}%)")
    print(f" Unmatched dirs: {stats['unmatched_dirs']}")
    print(f" Train dir: {TRAIN_DIR}")
    print(f" Val dir: {VAL_DIR}")

    if stats['unmatched_dirs'] > 0:
        print(f"\n ⚠ Manual review needed for {stats['unmatched_dirs']} dirs:")
        for u in stats['unmatched_dir_names'][:20]:
            print(f" {u}")

    return stats


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user