1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
| import torch import torch.nn as nn
class CustomModel(nn.Module):
    """Conv net that runs two slices of a feature map through shared convs.

    Pipeline (every conv is 3x3 / stride 1 / padding 1, so only ``pool``
    changes the spatial size):

        conv1 -> pool -> conv2 -> conv3
        -> split along dim 2 at ``size(2) // 3``
        -> (conv4 -> conv5) applied to each slice with the SAME weights
        -> concatenate the slices back along dim 2 -> conv6

    There are no non-linear activations between the convolutions.
    """

    def __init__(self, input_shape):
        """Build the layers.

        Args:
            input_shape: Shape of the expected input. Only ``input_shape[1]``
                is used, as the input channel count of ``conv1`` (so an
                ``(N, C, H, W)`` layout is assumed). The value is also kept
                on ``self.input_shape``.
        """
        super().__init__()
        self.input_shape = input_shape

        def conv3x3(in_channels, out_channels):
            # kernel 3 / stride 1 / padding 1 preserves height and width.
            return nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

        # Creation order matches the original so parameter initialisation
        # (which consumes the global RNG) and state_dict ordering are unchanged.
        self.conv1 = conv3x3(input_shape[1], 16)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = conv3x3(16, 32)
        self.conv3 = conv3x3(32, 16)
        self.conv4 = conv3x3(16, 32)
        self.conv5 = conv3x3(32, 16)
        self.conv6 = conv3x3(16, 32)

    def forward(self, x, which=None):
        """Run the network on ``x``.

        Args:
            x: Input batch whose dim 1 has ``input_shape[1]`` channels.
            which: Unused. It is optional so the model can be called (and
                traced by ``torch.onnx.export``) with a single tensor, while
                callers that still pass a second argument keep working.

        Returns:
            Tensor with 32 channels. Height and width equal those of ``x``
            after ``pool`` (halved, rounded down).
        """
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # First slice: the top third (rounded down) along dim 2; second
        # slice: the remainder. NOTE(review): a pooled size(2) below 3 gives
        # an empty first slice, which Conv2d will presumably reject.
        split_at = x.size(2) // 3
        top = x[:, :, :split_at, :]
        bottom = x[:, :, split_at:, :]

        # conv4/conv5 are shared between the two slices.
        top = self.conv5(self.conv4(top))
        bottom = self.conv5(self.conv4(bottom))

        x = torch.cat((top, bottom), dim=2)
        return self.conv6(x)
# Build the model and export it to ONNX ("my.onnx") by tracing it on a
# random input.
# NOTE(review): 244 is an unusual image size (224 is the common one) --
# confirm it is intended.
input_shape = (1, 3, 244, 244)
model = CustomModel(input_shape)

# Derive the dummy input from input_shape so the traced input always matches
# the shape the model was configured with.
dummy_input = torch.randn(input_shape)
torch.onnx.export(model, dummy_input, "my.onnx", verbose=True)
|