File size: 430 Bytes
bddb894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn

class AddModel(nn.Module):
    """Parameter-free module that adds its two inputs together."""

    def forward(self, x1, x2):
        """Return the sum ``x1 + x2``.

        Standard ``+`` semantics apply, so tensor broadcasting rules (and
        plain Python numbers) are handled exactly as the operator handles them.
        """
        summed = x1 + x2
        return summed

# Example inputs are used only to trace the graph; their values do not matter.
# Axis 0 is declared dynamic ("batch") below, so the example batch size must not
# be 0 or 1: the torch.export-based ONNX exporter specializes size-0/1 dimensions
# to constants, which would pin "batch" to 1 (or raise a dynamic-shape
# constraint error) instead of leaving it symbolic in the exported model.
BATCH_SIZE = 2
NUM_FEATURES = 3

model = AddModel()
x1 = torch.randn(BATCH_SIZE, NUM_FEATURES)
x2 = torch.randn(BATCH_SIZE, NUM_FEATURES)

# Export to "add.onnx" (relative to the current working directory) with named
# inputs/outputs. Axis 0 of both inputs and of the output is marked as the
# shared symbolic dimension "batch" so the ONNX model is not fixed to one batch size.
torch.onnx.export(
    model,
    (x1, x2),
    "add.onnx",
    input_names=["x1", "x2"],
    output_names=["y"],
    dynamic_axes={
        "x1": {0: "batch"},
        "x2": {0: "batch"},
        "y": {0: "batch"},
    },
    opset_version=18,
)