# export_model_to_onnx.py
# (35 lines, 27 loc, 1019 bytes)
import torch
from model import NvidiaModel
def main():
    """Export the pretrained NvidiaModel to an ONNX file.

    Loads the trained weights from ``./save/model.pt``, traces the network
    with a dummy (1, 3, 66, 200) input, and writes the serialized graph to
    ``./save/drive_net_model.onnx``.
    """
    # Pin everything to the CPU so the weights load and trace consistently
    # regardless of where the model was originally trained.
    device = torch.device('cpu')

    # Rebuild the network and restore its trained parameters from disk.
    model = NvidiaModel()
    state = torch.load("./save/model.pt", map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()  # inference mode: freezes dropout / batch-norm behavior for tracing

    # The network consumes NCHW tensors: batch=1, channels=3,
    # height=66, width=200.
    dummy_input = torch.randn(1, 3, 66, 200, device=device)

    # Trace the forward pass with the dummy input and serialize it.
    torch.onnx.export(
        model,
        dummy_input,
        './save/drive_net_model.onnx',
        verbose=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,
    )
    print('Model exported to "drive_net_model.onnx".')


if __name__ == '__main__':
    main()