plant-disease-id/tasks/hierarchical-model-upgrade/03-export-quantization.md

Phase 3 — ONNX Export & Quantization

Blocked by: Phase 2 (trained model)
Blocks: Phase 4 (server inference)
Est. time: 1-2 days
Machine: Any (RTX 3090 recommended for ONNX GPU validation)

Objective

Export the trained PyTorch model to ONNX format, apply INT8 quantization, and verify accuracy before deployment.

Deliverables

public/models/
├── swin-species.onnx              # FP16 species model (3.2 MB)
├── swin-species-int8.onnx         # INT8 quantized species model (1.1 MB)
├── disease-heads/                  # One ONNX per species
│   ├── tomato-int8.onnx
│   ├── acorn-squash-int8.onnx
│   └── ...
├── disease-heads-list.json        # Maps species ID → ONNX file path
├── ood-detector.pkl               # Mahalanobis parameters for OOD
└── onnx-metadata.json             # Input/output shapes, versions

Steps

3.1 Export backbone + species head as single ONNX

import torch
import onnx
from pathlib import Path

model = load_model_from_checkpoint("checkpoints/hierarchical_full/epoch=10-best.ckpt")
model.eval()

# Export end-to-end species model (backbone + species head)
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,                          # Combined forward: backbone + species_head
    dummy,
    "public/models/swin-species.onnx",
    input_names=["input"],
    output_names=["species_logits", "embedding"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "species_logits": {0: "batch_size"},
        "embedding": {0: "batch_size"},
    },
    opset_version=17,
    do_constant_folding=True,
)

Key: Export the 768-dim embedding as a second output. The server selects a disease head from the argmax of species_logits and feeds the embedding to that head as its input.

3.2 Export disease heads individually

Each disease head is a simple nn.Linear(768, N_diseases). Export as a mini-ONNX that takes the embedding and returns disease logits:

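# Create the output directory first; torch.onnx.export does not create it
heads_dir = Path("public/models/disease-heads")
heads_dir.mkdir(parents=True, exist_ok=True)

for species_name, head in model.disease_heads.items():
    dummy_embed = torch.randn(1, 768)
    torch.onnx.export(
        head, dummy_embed,
        str(heads_dir / f"{species_name}.onnx"),
        input_names=["embedding"],
        output_names=["disease_logits"],
        dynamic_axes={
            "embedding": {0: "batch_size"},
            "disease_logits": {0: "batch_size"},
        },
        opset_version=17,
    )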

Total: ~320 small ONNX files, each ~50-200 KB.
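The species-to-file mapping listed in the deliverables (disease-heads-list.json) has to be generated from the same class ordering used in training. A minimal sketch of one way to build it, assuming the training class order is available as a list of species names and that each quantized head is stored as `<species>-int8.onnx`; the helper names `build_head_index` and `write_head_index` are hypothetical and not part of the project:

```python
import json
from pathlib import Path


def build_head_index(species_names, heads_dir="disease-heads"):
    """Map species index (training class order) -> relative ONNX path.

    Assumption: species_names is in the exact order of the species head's
    output logits, and files are named "<species>-int8.onnx".
    """
    return {
        str(idx): f"{heads_dir}/{name}-int8.onnx"
        for idx, name in enumerate(species_names)
    }


def write_head_index(species_names, out_path):
    """Write the mapping as JSON; keys are strings because JSON object keys must be."""
    index = build_head_index(species_names)
    Path(out_path).write_text(json.dumps(index, indent=2))
    return index
```

Deriving the file from the training class list, rather than from a directory listing, keeps the index aligned with the species logits even if files are renamed or sorted differently on disk.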

3.3 INT8 Quantization

Use ONNX Runtime's quantization tooling:

from onnxruntime.quantization import quantize_dynamic, QuantType

# Quantize species model
quantize_dynamic(
    "public/models/swin-species.onnx",
    "public/models/swin-species-int8.onnx",
    weight_type=QuantType.QInt8,
)

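# Quantize disease heads (batch); skip files that are already quantized
for onnx_path in sorted(Path("public/models/disease-heads").glob("*.onnx")):
    if onnx_path.stem.endswith("-int8"):
        continue
    quantize_dynamic(
        str(onnx_path),
        # with_suffix() rejects suffixes that do not start with ".", so build the name directly
        str(onnx_path.with_name(f"{onnx_path.stem}-int8.onnx")),
        weight_type=QuantType.QInt8,
    )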

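Accuracy impact: Dynamic quantization stores the MatMul/Gemm weights as INT8 and computes activation scales at runtime, so no calibration set is needed. ONNX Runtime's documentation recommends dynamic quantization for transformer-based models such as Swin-Tiny and static quantization for CNNs. The expected accuracy drop is under 1%, but treat that as a target to confirm on the val set (see Verification), not a guarantee.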
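To put a number on the accuracy drop, the logits from both models can be compared against the labels. The following is a small sketch, assuming the logits have already been collected into NumPy arrays of shape [N, num_species] and the labels into an integer array of shape [N]; the function names are illustrative only:

```python
import numpy as np


def top1_accuracy(logits, labels):
    """Fraction of rows whose argmax matches the integer label."""
    return float(np.mean(np.argmax(logits, axis=1) == labels))


def accuracy_drop(fp_logits, int8_logits, labels):
    """Return (fp_acc, int8_acc, drop), where drop = fp_acc - int8_acc."""
    fp_acc = top1_accuracy(fp_logits, labels)
    int8_acc = top1_accuracy(int8_logits, labels)
    return fp_acc, int8_acc, fp_acc - int8_acc
```

A drop above 0.01 on the sampled validation images would fail the 1% criterion in the Verification list.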

3.4 OOD Detector

Train a Mahalanobis distance-based OOD detector on the training set embeddings:

import pickle

import numpy as np
import torch

# Collect embeddings from the training set
embeddings = []
model.eval()
with torch.no_grad():
    for batch in train_dataloader:
        _, emb = model(batch["image"])
        embeddings.append(emb.cpu().numpy())
embeddings = np.vstack(embeddings)

# Fit a multivariate Gaussian (regularize the covariance so it is invertible)
mean = np.mean(embeddings, axis=0)
cov = np.cov(embeddings, rowvar=False)
inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))

# Threshold = 95th percentile of the training-set Mahalanobis distances
centered = embeddings - mean
distances = np.sqrt(np.einsum("ij,jk,ik->i", centered, inv_cov, centered))
threshold = float(np.percentile(distances, 95))

# Save for inference
with open("public/models/ood-detector.pkl", "wb") as f:
    pickle.dump({"mean": mean, "inv_cov": inv_cov, "threshold": threshold}, f)

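The threshold is the 95th percentile of the training-set Mahalanobis distances, so about 5% of in-distribution training images fall above it by construction. At inference, reject an image as OOD (for example, a non-plant photo) when its distance exceeds the threshold.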
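On the inference side, the check reduces to one distance computation per image. A minimal sketch, assuming the pickle holds the `mean`, `inv_cov`, and `threshold` keys described above and that `threshold` is a distance value on the same scale; `mahalanobis_distance` and `is_ood` are hypothetical helper names:

```python
import numpy as np


def mahalanobis_distance(embedding, mean, inv_cov):
    """Distance of one embedding (shape [D]) from the fitted Gaussian."""
    delta = np.asarray(embedding, dtype=np.float64) - mean
    return float(np.sqrt(delta @ inv_cov @ delta))


def is_ood(embedding, params):
    """True when the distance exceeds the stored threshold."""
    d = mahalanobis_distance(embedding, params["mean"], params["inv_cov"])
    return d > params["threshold"]
```

With an identity covariance the distance is plain Euclidean distance, which makes for an easy sanity check before loading the real parameters.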

3.5 Accuracy verification

Before committing to ONNX, verify against PyTorch:

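import numpy as np
import onnxruntime as ort
import torch

# Compare PyTorch vs ONNX outputs on the unquantized export;
# the INT8 model is checked by val-set accuracy instead (see Verification)
with torch.no_grad():
    pytorch_logits, _ = model(sample_image)  # forward returns (species_logits, embedding)

session = ort.InferenceSession(
    "public/models/swin-species.onnx", providers=["CPUExecutionProvider"]
)
ort_out = session.run(["species_logits"], {"input": sample_image.numpy()})

max_diff = np.max(np.abs(pytorch_logits.numpy() - ort_out[0]))
assert max_diff < 0.01, f"ONNX mismatch: {max_diff}"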

Edge Cases & Gotchas

  • ONNX opset compatibility: Some timm model ops (like roll in Swin attention) may need opset ≥17. If export fails, try opset 18 or 19.
  • Dynamic axes: Resize input to 224×224 on the client; ONNX models should accept variable batch sizes but fixed spatial dimensions.
  • Disease head routing: The server must map the predicted species index to a disease head ONNX file. This mapping must match the training class ordering exactly.
  • Strix Halo ROCm + ONNX: ONNX Runtime reaches AMD GPUs through the ROCm or MIGraphX execution providers on Linux and through DirectML on Windows; DirectML is not a ROCm backend. The default CPU path may be faster for INT8 models if GPU kernels are missing. Test both.
  • Disease head file count: 320+ small files may be slow to enumerate on cold start. Consider batching all disease heads into a single ONNX with a species index input for routing (more complex but faster at inference).
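The single-model alternative in the last bullet can be illustrated in NumPy before committing to an ONNX graph. This is only a sketch of the idea, assuming each head is a plain linear layer with weight shape [N_i, 768] and bias shape [N_i]: heads are padded to the largest class count, stacked, and selected by species index, with padded classes masked out. `pack_heads` and `routed_logits` are hypothetical names, and the equivalent ONNX graph (for instance a Gather on the species index) is not shown:

```python
import numpy as np


def pack_heads(weights, biases):
    """Stack per-species linear heads into padded arrays.

    weights[i] has shape [n_i, D]; biases[i] has shape [n_i].
    Returns W [S, N_max, D], b [S, N_max], and a validity mask [S, N_max].
    """
    n_max = max(w.shape[0] for w in weights)
    dim = weights[0].shape[1]
    W = np.zeros((len(weights), n_max, dim), dtype=np.float32)
    b = np.zeros((len(weights), n_max), dtype=np.float32)
    mask = np.zeros((len(weights), n_max), dtype=bool)
    for i, (w, bias) in enumerate(zip(weights, biases)):
        W[i, : w.shape[0]] = w
        b[i, : w.shape[0]] = bias
        mask[i, : w.shape[0]] = True
    return W, b, mask


def routed_logits(embedding, species_idx, W, b, mask):
    """Disease logits for one embedding [D]; padded classes become -inf."""
    logits = W[species_idx] @ embedding + b[species_idx]
    return np.where(mask[species_idx], logits, -np.inf)
```

Masking padded slots with -inf keeps them from winning an argmax or receiving probability mass under softmax, at the cost of storing N_max rows for every species.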

Verification

  • ONNX species model output matches PyTorch output (max diff < 0.01)
  • INT8 accuracy within 1% of FP16 on val set (sample 10K images)
  • ONNX model loads in ONNX Runtime without errors
  • All 320+ disease heads export successfully
  • OOD detector rejects obvious non-plant images (rocks, buildings, people) with ≥99% precision
  • ONNX model size < 5MB (INT8) for species, < 200KB per disease head
  • Inference on CPU (Strix Halo) < 200ms for species + disease combined
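For the latency criterion, a small timing harness is enough to get comparable numbers across machines. The sketch below assumes the combined species + disease inference is wrapped in a zero-argument callable; `measure_latency_ms` is a hypothetical helper, and the warmup and run counts are arbitrary defaults:

```python
import statistics
import time


def measure_latency_ms(fn, warmup=5, runs=50):
    """Median and approximate p95 wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm caches and lazy initialization before timing
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    p95 = samples[min(len(samples) - 1, int(0.95 * len(samples)))]
    return statistics.median(samples), p95
```

Reporting the median alongside a tail value avoids passing the 200ms check on an average that hides slow outliers.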