# Phase 3 — ONNX Export & Quantization
**Blocked by**: Phase 2 (trained model)

**Blocks**: Phase 4 (server inference)

**Est. time**: 1-2 days

**Machine**: Any (RTX 3090 recommended for ONNX GPU validation)
## Objective
Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.
## Deliverables
```
public/models/
├── swin-species.onnx # FP32 species model as exported in 3.1 (~110 MB for Swin-Tiny's ~28M params)
├── swin-species-int8.onnx # INT8 quantized species model (~4× smaller, ~30 MB)
├── disease-heads/ # One ONNX per species
│ ├── tomato-int8.onnx
│ ├── acorn-squash-int8.onnx
│ └── ...
├── disease-heads-list.json # Maps species ID → ONNX file path
├── ood-detector.pkl # Mahalanobis parameters for OOD
└── onnx-metadata.json # Input/output shapes, versions
```
## Steps
### 3.1 Export backbone + species head as single ONNX
```python
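import onnx
import torch
from pathlib import Path

# load_model_from_checkpoint is the Phase 2 loader. Its forward must return
# (species_logits, embedding) so the output names below line up.
model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
model.eval()

out_path = Path("public/models/swin-species.onnx")
out_path.parent.mkdir(parents=True, exist_ok=True)

# Export end-to-end species model (backbone + species head).
# Weights stay float32 in this export; INT8 conversion happens in 3.3.
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch.onnx.export(
        model,  # Combined forward: backbone + species_head
        dummy,
        str(out_path),
        input_names=["input"],
        output_names=["species_logits", "embedding"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "species_logits": {0: "batch_size"},
            "embedding": {0: "batch_size"},
        },
        opset_version=17,
        do_constant_folding=True,
    )

# Structural sanity check of the exported graph
onnx.checker.check_model(onnx.load(str(out_path)))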
```
**Key**: Export the 768-dim `embedding` as a second output. The server feeds it into the disease head chosen by the predicted species, and it is also the vector the OOD detector (3.4) scores.
### 3.2 Export disease heads individually
Each disease head is a simple `nn.Linear(768, N_diseases)`. Export as a mini-ONNX that takes the embedding and returns disease logits:
```python
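from pathlib import Path
import torch

heads_dir = Path("public/models/disease-heads")
heads_dir.mkdir(parents=True, exist_ok=True)

for species_name, head in model.disease_heads.items():
    dummy_embed = torch.randn(1, 768)
    torch.onnx.export(
        head,
        dummy_embed,
        str(heads_dir / f"{species_name}.onnx"),
        input_names=["embedding"],
        output_names=["disease_logits"],
        # Mark the output batch axis dynamic too, not just the input
        dynamic_axes={"embedding": {0: "batch_size"}, "disease_logits": {0: "batch_size"}},
        opset_version=17,
    )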
```
**Total**: ~320 small ONNX files, each ~50-200 KB.
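
The `disease-heads-list.json` deliverable can be written once the INT8 heads from 3.3 exist. A minimal sketch, assuming a hypothetical `species_names` list ordered exactly like the species classifier's output indices (e.g. saved from the training dataset's class list) and the `{species}-int8.onnx` naming shown in the deliverables tree:

```python
import json
from pathlib import Path


def write_disease_head_index(species_names, heads_dir, out_path, suffix="-int8.onnx"):
    """Write {species_index: {"species", "path"}} in training class order.

    species_names[i] must be the species behind species logit i. A missing
    head raises here, at export time, instead of failing during inference.
    """
    heads_dir = Path(heads_dir)
    mapping, missing = {}, []
    for idx, name in enumerate(species_names):
        path = heads_dir / f"{name}{suffix}"
        if not path.is_file():
            missing.append(name)
        # JSON object keys are strings, so the index is stored as "0", "1", ...
        mapping[str(idx)] = {"species": name, "path": path.as_posix()}
    if missing:
        raise FileNotFoundError(f"No disease head file for: {missing}")
    Path(out_path).write_text(json.dumps(mapping, indent=2))
    return mapping
```

Keying by index rather than by name means the server can go straight from `argmax(species_logits)` to a file path without a second lookup table.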
### 3.3 INT8 Quantization
Use ONNX Runtime's quantization tooling:
```python
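from pathlib import Path

from onnxruntime.quantization import QuantType, quantize_dynamic

# Quantize species model
quantize_dynamic(
    "public/models/swin-species.onnx",
    "public/models/swin-species-int8.onnx",
    weight_type=QuantType.QInt8,
)

# Quantize disease heads (batch). Skip existing INT8 outputs so that
# re-running this step does not produce "*-int8-int8.onnx".
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
    if onnx_path.stem.endswith("-int8"):
        continue
    quantize_dynamic(
        str(onnx_path),
        # Path.with_suffix() rejects "-int8.onnx" (a suffix must start with "."),
        # so build the name explicitly: tomato.onnx -> tomato-int8.onnx
        str(onnx_path.with_name(f"{onnx_path.stem}-int8.onnx")),
        weight_type=QuantType.QInt8,
    )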
```
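**Accuracy impact**: Dynamic quantization stores the weights of matrix-multiply layers as INT8 and quantizes activations on the fly at inference time. That suits Swin-Tiny, since nearly all of its parameters sit in Linear layers (attention projections and MLPs); convolution-heavy CNNs usually need static, calibrated quantization to hold accuracy. Expect an accuracy drop under 1%, but treat that as a target to confirm on the validation set, not a guarantee.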
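
To see where the ~4× size reduction comes from and why per-weight error stays small, here is a minimal NumPy sketch of symmetric per-tensor INT8 weight quantization. It illustrates the idea only: ONNX Runtime's quantizer has its own options (`per_channel`, `reduce_range`, `extra_options`) and also quantizes activations at runtime, so this is not its exact algorithm. The 38×768 weight stands in for a hypothetical disease head.

```python
import numpy as np


def quantize_symmetric_int8(w: np.ndarray):
    """Per-tensor symmetric INT8: q = round(w / scale), scale = max|w| / 127."""
    scale = float(np.max(np.abs(w))) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale reproduces it exactly
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=(38, 768)).astype(np.float32)  # hypothetical head weight
q, scale = quantize_symmetric_int8(w)
max_err = float(np.max(np.abs(w - dequantize(q, scale))))

# 1 byte per weight instead of 4; rounding error is at most half a quantization step
print(f"size ratio {q.nbytes / w.nbytes:.2f}, max abs error {max_err:.2e}, bound {scale / 2:.2e}")
```

The error bound is per weight; how much it moves the final logits depends on the whole network, which is why the plan still checks accuracy on the validation set.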
### 3.4 OOD Detector
Train a Mahalanobis distance-based OOD detector on the training set embeddings:
```python
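import pickle

import numpy as np
import torch

# Collect embeddings from the training set (use a loader without augmentation)
model.eval()
embeddings = []
with torch.no_grad():
    for batch in train_dataloader:
        _, emb = model(batch["image"])
        embeddings.append(emb.cpu().numpy())
embeddings = np.vstack(embeddings)

# Fit a multivariate Gaussian; the small ridge keeps the covariance invertible
mean = np.mean(embeddings, axis=0)
cov = np.cov(embeddings, rowvar=False)
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))

# Threshold = 95th percentile of the training-set Mahalanobis distances
diff = embeddings - mean
dists = np.sqrt(np.maximum(np.sum((diff @ inv_cov) * diff, axis=1), 0.0))
threshold = float(np.percentile(dists, 95))

# Save for inference
with open("public/models/ood-detector.pkl", "wb") as f:
    pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": threshold}, f)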
```
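The threshold is the 95th percentile of the training-set Mahalanobis distances. At inference, an image whose embedding lies farther than the threshold from the training mean is rejected as OOD (e.g. a non-plant photo). By construction, this threshold also rejects about 5% of the training images themselves, so raise the percentile if that false-reject rate is too high for in-distribution inputs.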
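
On the server side, the check is the same distance computation applied to the `embedding` output of the species model. A minimal sketch, assuming the pickle holds the `mean`, `inv_cov` and `threshold` keys written above:

```python
import pickle

import numpy as np


def load_ood_detector(path: str) -> dict:
    with open(path, "rb") as f:
        return pickle.load(f)  # only unpickle files you produced yourself


def mahalanobis_distance(emb: np.ndarray, mean: np.ndarray, inv_cov: np.ndarray) -> np.ndarray:
    """sqrt((x - mean)^T inv_cov (x - mean)) for each row of a (batch, dim) array."""
    diff = np.atleast_2d(emb) - mean
    sq = np.sum((diff @ inv_cov) * diff, axis=1)
    return np.sqrt(np.maximum(sq, 0.0))  # clamp tiny negatives from round-off


def is_ood(emb: np.ndarray, detector: dict) -> np.ndarray:
    """Boolean per row: True means reject the image as out-of-distribution."""
    d = mahalanobis_distance(emb, detector["mean"], detector["inv_cov"])
    return d > detector["threshold"]
```

Using `diff @ inv_cov` keeps the work in one BLAS matrix multiply, which matters when the same function is reused to score a whole batch.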
### 3.5 Accuracy verification
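Before committing to ONNX, verify the unquantized export against PyTorch. Compare against `swin-species.onnx`, not the INT8 file: quantization is expected to move logits by more than a 0.01 tolerance, so the INT8 model is judged by validation accuracy instead (see Verification).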
```python
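import numpy as np
import onnxruntime as ort
import torch

sample_image = torch.randn(1, 3, 224, 224)  # or a real preprocessed val image
with torch.no_grad():  # PyTorch reference: forward returns (species_logits, embedding)
    pt_logits, pt_emb = model(sample_image)
session = ort.InferenceSession("public/models/swin-species.onnx", providers=["CPUExecutionProvider"])
ort_logits, ort_emb = session.run(["species_logits", "embedding"], {"input": sample_image.numpy()})

# Compare PyTorch vs ONNX outputs
for name, pt, onx in [("species_logits", pt_logits, ort_logits), ("embedding", pt_emb, ort_emb)]:
    max_diff = np.max(np.abs(pt.numpy() - onx))
    assert max_diff < 0.01, f"ONNX mismatch on {name}: {max_diff}"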
```
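
For the INT8 accuracy item in the checklist, run the unquantized and INT8 species models over the same validation images and compare top-1 accuracy. A minimal sketch: `run_species` accepts any object with ONNX Runtime's `InferenceSession.run(output_names, input_feed)` interface and an iterable of preprocessed batches, so 10K images never need to sit in memory at once. Reading "within 1%" as one percentage point (`max_drop=0.01`) is an assumption.

```python
import numpy as np


def run_species(session, batches) -> np.ndarray:
    """Concatenate species logits over an iterable of (B, 3, 224, 224) float32 batches."""
    outs = [session.run(["species_logits"], {"input": b})[0] for b in batches]
    return np.concatenate(outs, axis=0)


def top1_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    return float(np.mean(np.argmax(logits, axis=1) == labels))


def compare_int8(ref_logits, int8_logits, labels, max_drop: float = 0.01) -> dict:
    """max_drop is absolute accuracy (0.01 = one percentage point)."""
    ref_acc = top1_accuracy(ref_logits, labels)
    int8_acc = top1_accuracy(int8_logits, labels)
    # How often both models pick the same species, whether or not it is correct
    agreement = float(np.mean(np.argmax(ref_logits, axis=1) == np.argmax(int8_logits, axis=1)))
    return {
        "ref_acc": ref_acc,
        "int8_acc": int8_acc,
        "top1_agreement": agreement,
        "passed": ref_acc - int8_acc <= max_drop,
    }
```

Top-1 agreement is worth logging alongside accuracy: two models can score the same accuracy while disagreeing on many individual images, which would show up as inconsistent disease routing.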
## Edge Cases & Gotchas
- **ONNX opset compatibility**: Some `timm` model ops (like `roll` in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19.
- **Dynamic axes**: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions.
- **Disease head routing**: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
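- **Strix Halo ROCm + ONNX**: On AMD GPUs, ONNX Runtime runs through the ROCm or MIGraphX execution providers (Linux) or DirectML (Windows, any DirectX 12 GPU). The default CPU execution provider may still be faster for the INT8 models if the integer ops inserted by dynamic quantization lack GPU kernels and fall back to CPU. Test both.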
- **Disease head file count**: 320+ small files mean 320+ `InferenceSession`s to create, which slows cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex, but one session to load and one call per request).
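
The single-ONNX option can be sketched as a PyTorch module that gathers one head's weights by species index. This is one possible design, not the project's: it assumes every head is an `nn.Linear(768, N_s)` with a bias, that `heads` is ordered by species index, and it pads all heads to the largest `N_s`, so the packed weight grows as 320 × max `N_s` × 768 floats.

```python
from typing import Sequence

import torch
import torch.nn as nn


class StackedDiseaseHeads(nn.Module):
    """All per-species Linear heads packed into one padded weight tensor.

    Row s holds species s's head, zero-padded to the largest disease count;
    padded logits are overwritten with mask_value so they never win argmax.
    """

    def __init__(self, heads: Sequence[nn.Linear], mask_value: float = -1e9):
        super().__init__()
        max_d = max(h.out_features for h in heads)
        dim = heads[0].in_features
        weight = torch.zeros(len(heads), max_d, dim)
        bias = torch.zeros(len(heads), max_d)
        valid = torch.zeros(len(heads), max_d, dtype=torch.bool)
        for s, h in enumerate(heads):
            n = h.out_features
            weight[s, :n] = h.weight.detach()
            bias[s, :n] = h.bias.detach()
            valid[s, :n] = True
        # Buffers (not Parameters): frozen at export time, saved with the module
        self.register_buffer("weight", weight)
        self.register_buffer("bias", bias)
        self.register_buffer("valid", valid)
        self.mask_value = mask_value

    def forward(self, embedding: torch.Tensor, species_idx: torch.Tensor) -> torch.Tensor:
        w = self.weight[species_idx]  # (B, max_d, dim): one head per row
        b = self.bias[species_idx]  # (B, max_d)
        logits = torch.bmm(w, embedding.unsqueeze(-1)).squeeze(-1) + b
        return logits.masked_fill(~self.valid[species_idx], self.mask_value)
```

Exported with two inputs (`embedding` and an int64 `species_idx`), the whole routing step becomes one session call; the server then reads only the first `N_s` logits for the predicted species.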
## Verification
- [ ] ONNX species model output matches PyTorch output (max diff < 0.01)
- [ ] INT8 accuracy within 1% of the unquantized `swin-species.onnx` on val set (sample 10K images)
- [ ] ONNX model loads in ONNX Runtime without errors
- [ ] All 320+ disease heads export successfully
- [ ] OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
- [ ] INT8 species model roughly 4× smaller than the unquantized export (Swin-Tiny has ~28M parameters, so expect about 30 MB), < 200KB per disease head
- [ ] Inference on CPU (Strix Halo) < 200ms for species + disease combined
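
The latency item can be checked with a small timing harness around the two-stage pipeline. A minimal sketch, assuming a hypothetical `head_sessions` dict that maps species index to a loaded disease-head session (e.g. built from `disease-heads-list.json`); it times inference only, not preprocessing or session creation.

```python
import statistics
import time
from typing import Callable, Dict, Tuple

import numpy as np


def classify(species_sess, head_sessions: Dict[int, object], image: np.ndarray) -> Tuple[int, int]:
    """Species pass, then the disease head for the predicted species (batch of 1)."""
    logits, emb = species_sess.run(["species_logits", "embedding"], {"input": image})
    species = int(np.argmax(logits[0]))
    (disease_logits,) = head_sessions[species].run(["disease_logits"], {"embedding": emb})
    return species, int(np.argmax(disease_logits[0]))


def benchmark_ms(fn: Callable[[], object], warmup: int = 10, runs: int = 100) -> Dict[str, float]:
    """Wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):  # discard early calls, which may include one-time setup costs
        fn()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1000.0)
    times.sort()
    return {
        "p50": statistics.median(times),
        "p95": times[min(len(times) - 1, int(0.95 * len(times)))],
        "max": times[-1],
    }
```

With real sessions created via `ort.InferenceSession(path, providers=["CPUExecutionProvider"])` and a preprocessed `(1, 3, 224, 224)` float32 image, `benchmark_ms(lambda: classify(species_sess, head_sessions, image))` gives p50/p95 figures to hold against the 200 ms budget; p95 is the stricter reading.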