# Phase 3 — ONNX Export & Quantization

**Blocked by**: Phase 2 (trained model)
**Blocks**: Phase 4 (server inference)
**Est. time**: 1-2 days
**Machine**: Any (RTX 3090 recommended for ONNX GPU validation)

## Objective

Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.

## Deliverables

```
public/models/
├── swin-species.onnx          # FP16 species model (3.2 MB)
├── swin-species-int8.onnx     # INT8 quantized species model (1.1 MB)
├── disease-heads/             # One ONNX per species
│   ├── tomato-int8.onnx
│   ├── acorn-squash-int8.onnx
│   └── ...
├── disease-heads-list.json    # Maps species ID → ONNX file path
├── ood-detector.pkl           # Mahalanobis parameters for OOD
└── onnx-metadata.json         # Input/output shapes, versions
```

## Steps

### 3.1 Export backbone + species head as single ONNX

```python
import torch
import onnx
from pathlib import Path

# load_model_from_checkpoint is the project's own loader (not shown here)
model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
model.eval()

# torch.onnx.export does not create missing directories
Path("public/models").mkdir(parents=True, exist_ok=True)

# Export end-to-end species model (backbone + species head).
# Weights are written in the model's current dtype (FP32 unless cast beforehand).
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,  # Combined forward: backbone + species_head
    dummy,
    "public/models/swin-species.onnx",
    input_names=["input"],
    output_names=["species_logits", "embedding"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "species_logits": {0: "batch_size"},
        "embedding": {0: "batch_size"},
    },
    opset_version=17,
    do_constant_folding=True,
)

# Structural sanity check of the exported graph
onnx.checker.check_model(onnx.load("public/models/swin-species.onnx"))
```

**Key**: Export the 768-dim `embedding` as a second output. The server passes it to the disease head selected from the predicted species.

### 3.2 Export disease heads individually

Each disease head is a simple `nn.Linear(768, N_diseases)`.
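Export as a mini-ONNX that takes the embedding and returns disease logits:

```python
heads_dir = Path("public/models/disease-heads")
heads_dir.mkdir(parents=True, exist_ok=True)

for species_name, head in model.disease_heads.items():
    dummy_embed = torch.randn(1, 768)
    torch.onnx.export(
        head,
        dummy_embed,
        str(heads_dir / f"{species_name}.onnx"),
        input_names=["embedding"],
        output_names=["disease_logits"],
        # Mark the batch axis as dynamic on both the input and the output
        dynamic_axes={
            "embedding": {0: "batch_size"},
            "disease_logits": {0: "batch_size"},
        },
        opset_version=17,
    )
```

**Total**: ~320 small ONNX files, each ~50-200 KB.

### 3.3 INT8 Quantization

Use ONNX Runtime's quantization tooling:

```python
from pathlib import Path

from onnxruntime.quantization import quantize_dynamic, QuantType

# Quantize species model
quantize_dynamic(
    "public/models/swin-species.onnx",
    "public/models/swin-species-int8.onnx",
    weight_type=QuantType.QInt8,
)

# Quantize disease heads (batch)
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
    if onnx_path.stem.endswith("-int8"):
        continue  # Skip outputs left over from a previous run
    quantize_dynamic(
        str(onnx_path),
        # with_suffix("-int8.onnx") raises ValueError (a suffix must start
        # with "."), so build the output file name explicitly
        str(onnx_path.with_name(f"{onnx_path.stem}-int8.onnx")),
        weight_type=QuantType.QInt8,
    )
```

**Accuracy impact**: Dynamic INT8 quantization typically causes a <1% accuracy drop. It quantizes the weights of the `MatMul`/`Gemm` ops, which make up most of Swin-Tiny's attention and MLP layers, and ONNX Runtime's documentation generally recommends dynamic quantization for transformer models and static quantization for CNNs. Confirm the actual drop on the val set (see Verification).

### 3.4 OOD Detector

Fit a Mahalanobis distance-based OOD detector on the training set embeddings:

```python
import pickle

import numpy as np
import torch

# Collect embeddings from the training set.
# train_dataloader is an assumed name: the training-split loader, yielding
# the same batch dicts as val_dataloader.
embeddings = []
for batch in train_dataloader:
    with torch.no_grad():
        _, emb = model(batch["image"])
    embeddings.append(emb.cpu().numpy())
embeddings = np.vstack(embeddings)

# Fit multivariate Gaussian
mean = np.mean(embeddings, axis=0)
cov = np.cov(embeddings, rowvar=False)
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))

# Threshold = 95th percentile of the training-set Mahalanobis distances
# (a distance value, not the literal number 95.0)
diff = embeddings - mean
distances = np.sqrt(np.einsum("ij,jk,ik->i", diff, inv_cov, diff))
threshold = float(np.percentile(distances, 95))

# Save for inference
with open("public/models/ood-detector.pkl", "wb") as f:
    pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": threshold}, f)
```

The threshold is the 95th percentile of training set Mahalanobis distances and is meant to reject non-plant images. By construction, about 5% of training images also exceed it.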
If a test image has a distance > threshold, reject it as OOD.

### 3.5 Accuracy verification

Before committing to ONNX, verify against PyTorch:

```python
import numpy as np
import onnxruntime as ort
import torch

# Compare PyTorch vs ONNX outputs on the unquantized export. The INT8 model
# is checked by val-set accuracy instead (see Verification), because
# quantized logits can differ from PyTorch by more than 0.01.
with torch.no_grad():
    pytorch_logits, _ = model(sample_image)  # forward returns (logits, embedding)

session = ort.InferenceSession(
    "public/models/swin-species.onnx",
    providers=["CPUExecutionProvider"],
)
ort_out = session.run(["species_logits"], {"input": sample_image.numpy()})

max_diff = np.max(np.abs(pytorch_logits.numpy() - ort_out[0]))
assert max_diff < 0.01, f"ONNX mismatch: {max_diff}"
```

## Edge Cases & Gotchas

- **ONNX opset compatibility**: Some `timm` model ops (like `roll` in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19, provided the installed PyTorch version supports it.
- **Dynamic axes**: Resize input to 224×224 on the client; the ONNX models accept variable batch sizes but require fixed spatial dimensions.
- **Disease head routing**: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
- **Strix Halo ROCm + ONNX**: ONNX Runtime targets AMD GPUs through the ROCm or MIGraphX execution providers on Linux; DirectML is a separate, Windows-only execution provider. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both.
- **Disease head file count**: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference).

## Verification

- [ ] ONNX species model output matches PyTorch output (max diff < 0.01)
- [ ] INT8 accuracy within 1% of FP16 on val set (sample 10K images)
- [ ] ONNX model loads in ONNX Runtime without errors
- [ ] All 320+ disease heads export successfully
- [ ] OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
- [ ] ONNX model size < 5 MB (INT8) for species, < 200 KB per disease head
- [ ] Inference on CPU (Strix Halo) < 200 ms for species + disease combined
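
The disease head routing gotcha above depends on `disease-heads-list.json` agreeing with the training class ordering. The sketch below shows one way that manifest could be built and queried. The JSON layout (string species index mapped to a relative file path), the helper names, and the two-entry `species_names` list are all assumptions made for illustration; the real ordering should come from the training configuration.

```python
import json
from pathlib import Path


def build_head_manifest(species_names, heads_dir="disease-heads"):
    """Map each species index (as a string key) to its INT8 head file.

    species_names must be in the exact class order used during training,
    so that index i here is the same class as logit i of species_logits.
    """
    return {
        str(idx): f"{heads_dir}/{name}-int8.onnx"
        for idx, name in enumerate(species_names)
    }


def save_manifest(manifest, path):
    """Write the manifest as JSON (keys are strings because JSON requires it)."""
    Path(path).write_text(json.dumps(manifest, indent=2), encoding="utf-8")


def resolve_head(manifest, species_idx):
    """Return the head file for a predicted species index (KeyError if unknown)."""
    return manifest[str(int(species_idx))]


# Hypothetical class ordering, for illustration only.
species_names = ["acorn-squash", "tomato"]
manifest = build_head_manifest(species_names)
print(resolve_head(manifest, 1))  # prints "disease-heads/tomato-int8.onnx"
```

Generating the manifest from the same class list the trainer used, rather than from a directory listing, keeps the index-to-file mapping independent of file system sort order.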