Phase 3 — ONNX Export & Quantization
Blocked by: Phase 2 (trained model)
Blocks: Phase 4 (server inference)
Est. time: 1-2 days
Machine: Any (RTX 3090 recommended for ONNX GPU validation)
Objective
Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.
Deliverables
public/models/
├── swin-species.onnx # FP16 species model (3.2 MB)
├── swin-species-int8.onnx # INT8 quantized species model (1.1 MB)
├── disease-heads/ # One ONNX per species
│ ├── tomato-int8.onnx
│ ├── acorn-squash-int8.onnx
│ └── ...
├── disease-heads-list.json # Maps species ID → ONNX file path
├── ood-detector.pkl # Mahalanobis parameters for OOD
└── onnx-metadata.json # Input/output shapes, versions
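The species-to-head mapping in disease-heads-list.json could be generated at export time from the training class order. The following is a minimal sketch only: `species_names`, `build_head_manifest`, and the JSON layout (string index keys mapping to paths) are hypothetical and not the project's confirmed schema.

```python
import json
from pathlib import Path


def build_head_manifest(species_names, heads_dir="public/models/disease-heads"):
    """Map each species index (as a string key) to its disease-head file path."""
    return {
        str(idx): f"{heads_dir}/{name}-int8.onnx"
        for idx, name in enumerate(species_names)
    }


def write_head_manifest(manifest, out_path):
    """Write the manifest as JSON, creating parent directories if needed."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(manifest, indent=2))


# Hypothetical class order; the real list must come from the training run.
species_names = ["tomato", "acorn-squash"]
manifest = build_head_manifest(species_names)
```

Deriving the manifest from the same list that defined the training labels keeps the index-to-file mapping in step with the class ordering.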
Steps
3.1 Export backbone + species head as single ONNX
import torch
import onnx
from pathlib import Path

# load_model_from_checkpoint is assumed to be a project helper defined elsewhere
model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
model.eval()

# Export end-to-end species model (backbone + species head)
Path("public/models").mkdir(parents=True, exist_ok=True)
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,  # Combined forward: backbone + species_head
    dummy,
    "public/models/swin-species.onnx",
    input_names=["input"],
    output_names=["species_logits", "embedding"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "species_logits": {0: "batch_size"},
        "embedding": {0: "batch_size"},
    },
    opset_version=17,
    do_constant_folding=True,
)

# Structural sanity check on the exported graph
onnx.checker.check_model(onnx.load("public/models/swin-species.onnx"))
Key: Export the 768-dim embedding as a second output. The server selects a disease head from the predicted species (argmax of species_logits) and feeds this embedding to it.
3.2 Export disease heads individually
Each disease head is a simple nn.Linear(768, N_diseases). Export as a mini-ONNX that takes the embedding and returns disease logits:
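Path("public/models/disease-heads").mkdir(parents=True, exist_ok=True)
dummy_embed = torch.randn(1, 768)
for species_name, head in model.disease_heads.items():
    head.eval()
    torch.onnx.export(
        head,
        dummy_embed,
        f"public/models/disease-heads/{species_name}.onnx",
        input_names=["embedding"],
        output_names=["disease_logits"],
        # Mark the batch axis as dynamic on both input and output
        dynamic_axes={
            "embedding": {0: "batch_size"},
            "disease_logits": {0: "batch_size"},
        },
        opset_version=17,
    )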
Total: ~320 small ONNX files, each ~50-200 KB.
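As a rough cross-check of the per-file size estimate, the raw parameter size of a linear head can be computed directly. This sketch assumes FP32 weights (4 bytes per parameter) and ignores the small ONNX graph overhead; `linear_head_bytes` is a hypothetical helper name.

```python
def linear_head_bytes(n_diseases, embed_dim=768, bytes_per_param=4):
    """Raw parameter size of nn.Linear(embed_dim, n_diseases): weights plus biases."""
    n_params = embed_dim * n_diseases + n_diseases
    return n_params * bytes_per_param


# Print the FP32 size for a few class counts
for n in (16, 32, 64):
    print(n, "classes:", round(linear_head_bytes(n) / 1024, 1), "KiB")
```

Under these assumptions, the ~50-200 KB range corresponds to heads with roughly 16 to 64 disease classes.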
3.3 INT8 Quantization
Use ONNX Runtime's quantization tooling:
from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType

# Quantize species model
quantize_dynamic(
    "public/models/swin-species.onnx",
    "public/models/swin-species-int8.onnx",
    weight_type=QuantType.QInt8,
)

# Quantize disease heads (batch); skip files that are already quantized
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
    if onnx_path.stem.endswith("-int8"):
        continue
    quantize_dynamic(
        str(onnx_path),
        # with_suffix() requires a suffix starting with ".", so build the name explicitly
        str(onnx_path.with_name(f"{onnx_path.stem}-int8.onnx")),
        weight_type=QuantType.QInt8,
    )
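Accuracy impact: Dynamic quantization stores the weights of the linear (MatMul/Gemm) layers as INT8 and quantizes activations on the fly at inference time. The accuracy drop is typically <1%, but confirm it on the validation set. Transformer models such as Swin-Tiny, whose compute is dominated by linear layers, generally tolerate dynamic quantization better than convolution-heavy CNNs, for which static quantization is usually preferred.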
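To build intuition for why the drop tends to be small, the sketch below round-trips a random linear layer's weights through symmetric per-tensor INT8 and measures the resulting logit error. It is an illustration in plain NumPy, not ONNX Runtime's actual quantization scheme; the weight scale, shapes, and `quantize_int8` helper are assumptions.

```python
import numpy as np


def quantize_int8(w):
    """Symmetric per-tensor quantization: w is approximated by scale * q, q in [-127, 127]."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=(768, 10)).astype(np.float32)  # stand-in head weights
x = rng.normal(0.0, 1.0, size=(4, 768)).astype(np.float32)    # stand-in embeddings

q, scale = quantize_int8(w)
w_hat = q.astype(np.float32) * scale  # dequantized weights

ref = x @ w         # logits with original weights
approx = x @ w_hat  # logits with round-tripped weights
rel_err = np.abs(approx - ref).max() / np.abs(ref).max()
print(f"max relative logit error: {rel_err:.4f}")
```

Each weight moves by at most half a quantization step, so the logits shift only slightly. The real check remains accuracy on the validation set.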
3.4 OOD Detector
Train a Mahalanobis distance-based OOD detector on the training set embeddings:
import pickle
import numpy as np

# Collect embeddings from the training set
# (train_dataloader: the training-set loader; the name is assumed)
embeddings = []
for batch in train_dataloader:
    with torch.no_grad():
        _, emb = model(batch["image"])
    embeddings.append(emb.cpu().numpy())
embeddings = np.vstack(embeddings)

# Fit multivariate Gaussian
mean = np.mean(embeddings, axis=0)
cov = np.cov(embeddings, rowvar=False)
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))

# Threshold = 95th percentile of training-set Mahalanobis distances
centered = embeddings - mean
distances = np.sqrt(np.einsum("ij,jk,ik->i", centered, inv_cov, centered))
threshold = float(np.percentile(distances, 95))

# Save for inference
with open("public/models/ood-detector.pkl", "wb") as f:
    pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": threshold}, f)
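The threshold is the 95th percentile of training-set Mahalanobis distances, so roughly 5% of in-distribution training images also fall above it. At inference, compute the distance of the test image's embedding from the fitted Gaussian and reject the image as OOD (for example, a non-plant photo) if the distance exceeds the threshold.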
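The inference-side check can be sketched as follows. This is a self-contained illustration on synthetic 8-dim data standing in for the 768-dim embeddings; `mahalanobis_distance` and `is_ood` are hypothetical helper names, and the `params` dict mirrors the keys stored in the pickle.

```python
import numpy as np


def mahalanobis_distance(x, mean, inv_cov):
    """Mahalanobis distance of each row of x from the fitted Gaussian."""
    centered = np.atleast_2d(x) - mean
    sq = np.einsum("ij,jk,ik->i", centered, inv_cov, centered)
    return np.sqrt(np.maximum(sq, 0.0))


def is_ood(x, params):
    """Boolean array: True where the distance exceeds the stored threshold."""
    dist = mahalanobis_distance(x, params["mean"], params["inv_cov"])
    return dist > params["threshold"]


# Synthetic stand-in for training embeddings (assumption: 8 dims, 2000 samples)
rng = np.random.default_rng(0)
train = rng.normal(size=(2000, 8))
mean = train.mean(axis=0)
cov = np.cov(train, rowvar=False)
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
threshold = np.percentile(mahalanobis_distance(train, mean, inv_cov), 95)
params = {"mean": mean, "inv_cov": inv_cov, "threshold": threshold}

far_point = mean + 10.0  # far outside the training distribution
print("far point rejected:", bool(is_ood(far_point, params)[0]))
print("training rejection rate:", is_ood(train, params).mean())
```

The rejection rate on the training data itself sits at about 5% by construction, which is the false-rejection cost of a 95th-percentile threshold.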
3.5 Accuracy verification
Before committing to ONNX, verify against PyTorch:
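import numpy as np
import onnxruntime as ort

# Compare PyTorch vs ONNX outputs on the unquantized export.
# The INT8 model is judged by validation accuracy instead, since its
# logits are not expected to match PyTorch to within 0.01.
with torch.no_grad():
    pytorch_logits, _ = model(sample_image)  # forward returns (species_logits, embedding)

session = ort.InferenceSession(
    "public/models/swin-species.onnx", providers=["CPUExecutionProvider"]
)
ort_out = session.run(["species_logits"], {"input": sample_image.numpy()})

max_diff = np.max(np.abs(pytorch_logits.numpy() - ort_out[0]))
assert max_diff < 0.01, f"ONNX mismatch: {max_diff}"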
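A complementary check for the quantized model is top-1 accuracy on labeled validation data rather than raw logit differences. The sketch below shows the comparison on a tiny hand-made example; `top1_accuracy` and `accuracy_drop` are hypothetical helpers, and in practice the logits would come from running both models over the validation sample.

```python
import numpy as np


def top1_accuracy(logits, labels):
    """Fraction of rows whose argmax matches the label."""
    return float((np.argmax(logits, axis=1) == labels).mean())


def accuracy_drop(ref_logits, int8_logits, labels):
    """Reference accuracy minus quantized accuracy (0.01 means one point)."""
    return top1_accuracy(ref_logits, labels) - top1_accuracy(int8_logits, labels)


# Tiny hand-made example: 4 samples, 3 classes
labels = np.array([0, 1, 2, 1])
ref_logits = np.array([
    [2.0, 0.1, 0.0],
    [0.0, 3.0, 0.2],
    [0.1, 0.0, 1.5],
    [0.3, 0.9, 0.1],
])
int8_logits = ref_logits.copy()
int8_logits[3] = [1.0, 0.9, 0.1]  # simulated quantization error flips one prediction

drop = accuracy_drop(ref_logits, int8_logits, labels)
print(f"accuracy drop: {drop:.2f}")
```

A pass criterion such as `drop <= 0.01` would then express the "within 1%" requirement directly.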
Edge Cases & Gotchas
- ONNX opset compatibility: Some timm model ops (like roll in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19.
- Dynamic axes: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions.
- Disease head routing: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
- Strix Halo ROCm + ONNX: ONNX Runtime targets AMD GPUs through the ROCm or MIGraphX execution providers on Linux; DirectML is a separate, Windows-only backend and does not use ROCm. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both.
- Disease head file count: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference).
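The single-file alternative mentioned in the last item can be sketched in NumPy: pad every head to a common class count, stack the weights, and select a head per sample by species index. This only illustrates the idea; `stack_heads` and `routed_logits` are hypothetical names, and an ONNX version would need an equivalent gather-then-matmul graph.

```python
import numpy as np


def stack_heads(heads):
    """Pad per-species (W, b) pairs to a common class count and stack them."""
    max_n = max(w.shape[0] for w, _ in heads)
    dim = heads[0][0].shape[1]
    W = np.zeros((len(heads), max_n, dim), dtype=np.float32)
    # Padded classes get a bias of -inf so they can never be the argmax
    b = np.full((len(heads), max_n), -np.inf, dtype=np.float32)
    for i, (w, bias) in enumerate(heads):
        n = w.shape[0]
        W[i, :n] = w
        b[i, :n] = bias
    return W, b


def routed_logits(W, b, embeddings, species_idx):
    """Select each sample's head by species index, then apply it."""
    return np.einsum("bnd,bd->bn", W[species_idx], embeddings) + b[species_idx]


# Two toy heads with 2 and 3 classes over 4-dim embeddings
rng = np.random.default_rng(0)
heads = [
    (rng.normal(size=(2, 4)).astype(np.float32), np.zeros(2, dtype=np.float32)),
    (rng.normal(size=(3, 4)).astype(np.float32), np.zeros(3, dtype=np.float32)),
]
W, b = stack_heads(heads)
emb = rng.normal(size=(2, 4)).astype(np.float32)
logits = routed_logits(W, b, emb, np.array([0, 1]))
```

The trade-off is one larger file and padded compute in exchange for a single session load instead of 320+ small ones.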
Verification
- ONNX species model output matches PyTorch output (max diff < 0.01)
- INT8 accuracy within 1% of FP16 on val set (sample 10K images)
- ONNX model loads in ONNX Runtime without errors
- All 320+ disease heads export successfully
- OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
- ONNX model size < 5MB (INT8) for species, < 200KB per disease head
- Inference on CPU (Strix Halo) < 200ms for species + disease combined
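The latency target can be checked with a small timing harness. This is a sketch using only the standard library; `measure_latency_ms` is a hypothetical helper, and the lambda is a stand-in workload to be replaced by the combined species and disease-head inference call.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=3, runs=20):
    """Return (median, p95) wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm caches and lazy initialisation before timing
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    p95 = samples[min(len(samples) - 1, int(0.95 * len(samples)))]
    return statistics.median(samples), p95


# Stand-in workload; replace with species-session + disease-head inference
median_ms, p95_ms = measure_latency_ms(lambda: sum(range(10_000)))
print(f"median {median_ms:.3f} ms, p95 {p95_ms:.3f} ms")
```

Reporting the p95 alongside the median makes it clearer whether the 200 ms budget holds for slow requests as well as typical ones.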