task to get this here done
This commit is contained in:
165
tasks/hierarchical-model-upgrade/03-export-quantization.md
Normal file
165
tasks/hierarchical-model-upgrade/03-export-quantization.md
Normal file
@@ -0,0 +1,165 @@
|
||||
# Phase 3 — ONNX Export & Quantization
|
||||
|
||||
**Blocked by**: Phase 2 (trained model)
|
||||
**Blocks**: Phase 4 (server inference)
|
||||
**Est. time**: 1-2 days
|
||||
**Machine**: Any (RTX 3090 recommended for ONNX GPU validation)
|
||||
|
||||
## Objective
|
||||
|
||||
Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.
|
||||
|
||||
## Deliverables
|
||||
|
||||
```
|
||||
public/models/
|
||||
├── swin-species.onnx # FP16 species model (3.2 MB)
|
||||
├── swin-species-int8.onnx # INT8 quantized species model (1.1 MB)
|
||||
├── disease-heads/ # One ONNX per species
|
||||
│ ├── tomato-int8.onnx
|
||||
│ ├── acorn-squash-int8.onnx
|
||||
│ └── ...
|
||||
├── disease-heads-list.json # Maps species ID → ONNX file path
|
||||
├── ood-detector.pkl # Mahalanobis parameters for OOD
|
||||
└── onnx-metadata.json # Input/output shapes, versions
|
||||
```
|
||||
|
||||
## Steps
|
||||
|
||||
### 3.1 Export backbone + species head as single ONNX
|
||||
|
||||
```python
|
||||
import torch
|
||||
import onnx
|
||||
from pathlib import Path
|
||||
|
||||
model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
|
||||
model.eval()
|
||||
|
||||
# Export end-to-end species model (backbone + species head)
|
||||
dummy = torch.randn(1, 3, 224, 224)
|
||||
torch.onnx.export(
|
||||
model, # Combined forward: backbone + species_head
|
||||
dummy,
|
||||
"public/models/swin-species.onnx",
|
||||
input_names=["input"],
|
||||
output_names=["species_logits", "embedding"],
|
||||
dynamic_axes={
|
||||
"input": {0: "batch_size"},
|
||||
"species_logits": {0: "batch_size"},
|
||||
"embedding": {0: "batch_size"},
|
||||
},
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
```
|
||||
|
||||
**Key**: Export the 768-dim `embedding` as a second output — the server needs it to route to the correct disease head.
|
||||
|
||||
### 3.2 Export disease heads individually
|
||||
|
||||
Each disease head is a simple `nn.Linear(768, N_diseases)`. Export as a mini-ONNX that takes the embedding and returns disease logits:
|
||||
|
||||
```python
|
||||
for species_name, head in model.disease_heads.items():
|
||||
dummy_embed = torch.randn(1, 768)
|
||||
torch.onnx.export(
|
||||
head, dummy_embed,
|
||||
f"public/models/disease-heads/{species_name}.onnx",
|
||||
input_names=["embedding"],
|
||||
output_names=["disease_logits"],
|
||||
dynamic_axes={"embedding": {0: "batch_size"}},
|
||||
opset_version=17,
|
||||
)
|
||||
```
|
||||
|
||||
**Total**: ~320 small ONNX files, each ~50-200 KB.
|
||||
|
||||
### 3.3 INT8 Quantization
|
||||
|
||||
Use ONNX Runtime's quantization tooling:
|
||||
|
||||
```python
|
||||
from onnxruntime.quantization import quantize_dynamic, QuantType
|
||||
|
||||
# Quantize species model
|
||||
quantize_dynamic(
|
||||
"public/models/swin-species.onnx",
|
||||
"public/models/swin-species-int8.onnx",
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
# Quantize disease heads (batch)
|
||||
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
|
||||
quantize_dynamic(
|
||||
str(onnx_path),
|
||||
str(onnx_path.with_suffix("-int8.onnx")),
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
```
|
||||
|
||||
**Accuracy impact**: INT8 quantization typically causes <1% accuracy drop when using dynamic quantization on the linear/embedding layers. The Swin-Tiny attention layers are less affected than CNN layers.
|
||||
|
||||
### 3.4 OOD Detector
|
||||
|
||||
Train a Mahalanobis distance-based OOD detector on the training set embeddings:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import mahalanobis
|
||||
|
||||
# Collect embeddings from training set
|
||||
embeddings = []
|
||||
for batch in val_dataloader:
|
||||
with torch.no_grad():
|
||||
_, emb = model(batch["image"])
|
||||
embeddings.append(emb.numpy())
|
||||
embeddings = np.vstack(embeddings)
|
||||
|
||||
# Fit multivariate Gaussian
|
||||
mean = np.mean(embeddings, axis=0)
|
||||
cov = np.cov(embeddings, rowvar=False)
|
||||
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
|
||||
|
||||
# Save for inference
|
||||
import pickle
|
||||
with open("public/models/ood-detector.pkl", "wb") as f:
|
||||
pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": 95.0}, f)
|
||||
```
|
||||
|
||||
The threshold (95th percentile of training set Mahalanobis distances) rejects non-plant images. If a test image has a distance > threshold, reject it as OOD.
|
||||
|
||||
### 3.5 Accuracy verification
|
||||
|
||||
Before committing to ONNX, verify against PyTorch:
|
||||
|
||||
```python
|
||||
import onnxruntime as ort
|
||||
|
||||
# Compare PyTorch vs ONNX outputs
|
||||
pytorch_out = model(sample_image)
|
||||
ort_out = ort.InferenceSession("swin-species-int8.onnx").run(
|
||||
["species_logits"], {"input": sample_image.numpy()}
|
||||
)
|
||||
|
||||
max_diff = np.max(np.abs(pytorch_out.numpy() - ort_out[0]))
|
||||
assert max_diff < 0.01, f"ONNX mismatch: {max_diff}"
|
||||
```
|
||||
|
||||
## Edge Cases & Gotchas
|
||||
|
||||
- **ONNX opset compatibility**: Some `timm` model ops (like `roll` in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19.
|
||||
- **Dynamic axes**: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions.
|
||||
- **Disease head routing**: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
|
||||
- **Strix Halo ROCm + ONNX**: ONNX Runtime supports ROCm via DirectML or MIGraphX backends. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both.
|
||||
- **Disease head file count**: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference).
|
||||
|
||||
## Verification
|
||||
|
||||
- [ ] ONNX species model output matches PyTorch output (max diff < 0.01)
|
||||
- [ ] INT8 accuracy within 1% of FP16 on val set (sample 10K images)
|
||||
- [ ] ONNX model loads in ONNX Runtime without errors
|
||||
- [ ] All 320+ disease heads export successfully
|
||||
- [ ] OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
|
||||
- [ ] ONNX model size < 5MB (INT8) for species, < 200KB per disease head
|
||||
- [ ] Inference on CPU (Strix Halo) < 200ms for species + disease combined
|
||||
Reference in New Issue
Block a user