-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexport_onnx.py
More file actions
110 lines (85 loc) · 3.33 KB
/
export_onnx.py
File metadata and controls
110 lines (85 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Export the trained Traffic Light Classifier to ONNX.
Uses **fixed batch size 1** (input ``(1, 3, H, W)``, output ``(1, num_classes)``) so
NVIDIA DriveWorks ``tensorRT_optimization`` can build without dynamic-shape profiles.
Usage
-----
python export_onnx.py \
--checkpoint outputs/tl_classifier_lisa/checkpoints/best.pth \
--output tl_classifier.onnx
Verify with:
python export_onnx.py \
--checkpoint outputs/tl_classifier_lisa/checkpoints/best.pth \
--output tl_classifier.onnx \
--verify
Build TensorRT engine (on target device):
trtexec --onnx=tl_classifier.onnx \
--saveEngine=tl_classifier.engine \
--fp16 \
--useDLACore=0 --allowGPUFallback
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).parent))
from models.traffic_light_classifier import (
TrafficLightClassifier,
TL_CLASSES, TL_CROP_H, TL_CROP_W,
)
def export(checkpoint_path: str, output_path: str, opset: int = 17):
    """Export the trained classifier checkpoint to an ONNX file.

    The model is built on CPU, loaded from the checkpoint, switched to eval
    mode and exported with a single random input of shape
    ``(1, 3, TL_CROP_H, TL_CROP_W)``. No dynamic axes are passed to
    ``torch.onnx.export``, so the graph is exported for that input shape.

    Args:
        checkpoint_path: Path to the checkpoint file. It may be a dict that
            stores the weights under ``'model_state_dict'`` or a bare state
            dict.
        output_path: Destination ``.onnx`` file. Missing parent directories
            are created before exporting.
        opset: ONNX opset version forwarded to ``torch.onnx.export``.

    Returns:
        ``output_path``, exactly as it was passed in.
    """
    device = torch.device('cpu')
    model = TrafficLightClassifier(pretrained=False, num_classes=len(TL_CLASSES))

    # NOTE(review): torch.load relies on pickle; only pass checkpoints that
    # come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)
    # Accept both a full training checkpoint and a plain state dict.
    state = ckpt.get('model_state_dict', ckpt)
    model.load_state_dict(state)
    model.eval()

    # The exporter does not create directories; make sure the target exists
    # so a nested output path does not fail after the model was loaded.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    dummy = torch.randn(1, 3, TL_CROP_H, TL_CROP_W)
    torch.onnx.export(
        model,
        dummy,
        output_path,
        opset_version=opset,
        input_names=['crop'],
        output_names=['logits'],
    )

    print(f"[Export] ONNX saved: {output_path}")
    print(f"[Export] Input: crop (1, 3, {TL_CROP_H}, {TL_CROP_W})")
    print(f"[Export] Output: logits (1, {len(TL_CLASSES)}) → {TL_CLASSES}")
    return output_path
def verify(checkpoint_path: str, onnx_path: str):
    """Compare PyTorch and ONNX Runtime outputs on one random input.

    Rebuilds the model from the checkpoint on CPU, runs the same random
    ``(1, 3, TL_CROP_H, TL_CROP_W)`` tensor through PyTorch and through the
    exported ONNX file, and prints the maximum absolute difference. A
    difference below ``1e-4`` is reported as PASSED, anything else as a
    WARNING. If ``onnxruntime`` is not installed the check is skipped with a
    message.

    Args:
        checkpoint_path: Path to the checkpoint used for the export. It may
            be a dict with a ``'model_state_dict'`` entry or a bare state
            dict.
        onnx_path: Path to the exported ONNX file.

    Returns:
        None. The result is only printed.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print("[Verify] onnxruntime not installed — skipping verification.")
        print(" pip install onnxruntime")
        return

    device = torch.device('cpu')
    model = TrafficLightClassifier(pretrained=False, num_classes=len(TL_CLASSES))
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt.get('model_state_dict', ckpt))
    model.eval()

    dummy = torch.randn(1, 3, TL_CROP_H, TL_CROP_W)
    with torch.no_grad():
        pt_out = model(dummy).numpy()

    # Pin the CPU provider: onnxruntime builds that ship additional execution
    # providers require an explicit ``providers`` list, and the PyTorch
    # reference above is computed on CPU as well.
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    # Ask the session for its input name instead of assuming the name that
    # was chosen at export time.
    input_name = sess.get_inputs()[0].name
    ort_out = sess.run(None, {input_name: dummy.numpy()})[0]

    diff = np.abs(pt_out - ort_out).max()
    print(f"[Verify] Max abs diff PyTorch vs ONNX: {diff:.6e}")
    if diff < 1e-4:
        print("[Verify] PASSED — outputs match within tolerance.")
    else:
        print("[Verify] WARNING — outputs diverge. Check export settings.")
def main():
    """Command-line entry point: export the model and optionally verify it.

    Reads ``--checkpoint`` (required), ``--output``, ``--opset`` and
    ``--verify`` from the command line, always runs the export, and runs the
    PyTorch-vs-ONNX comparison only when ``--verify`` is given.
    """
    parser = argparse.ArgumentParser(description='Export TL classifier to ONNX')

    # Flags are registered in this order so the generated help stays the same.
    option_specs = (
        ('--checkpoint', {'required': True, 'help': 'Path to best.pth'}),
        ('--output', {'default': 'tl_classifier.onnx', 'help': 'ONNX output path'}),
        ('--opset', {'type': int, 'default': 13, 'help': 'ONNX opset version'}),
        ('--verify', {'action': 'store_true', 'help': 'Verify ONNX vs PyTorch'}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    cli = parser.parse_args()

    export(cli.checkpoint, cli.output, cli.opset)
    if not cli.verify:
        return
    verify(cli.checkpoint, cli.output)


# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()