Comments (19)
@justinchuby Is this what you mean?
exported_program = torch.export.export(model, (data,))
exported_program._graph_module.print_readable()
which gives the following output:
class GraphModule(torch.nn.Module):
def forward(self, p_convinit_conv_weight: "f32[16, 4, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm1_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm1_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm2_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm2_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_down_layers_1_0_conv_weight: "f32[32, 16, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_down_layers_2_0_conv_weight: "f32[64, 32, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm1_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm1_bias: "f32[64]", 
p_getattr_l__self___down_layers_2___2___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___2___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_down_layers_3_0_conv_weight: "f32[128, 64, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___4___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm2_bias: "f32[128]", 
p_getattr_l__self___down_layers_3___4___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_up_samples_0_0_conv_weight: "f32[64, 128, 1, 1, 1]", p_getattr_l__self___up_layers_0___0___norm1_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm1_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___up_layers_0___0___norm2_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm2_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_up_samples_1_0_conv_weight: "f32[32, 64, 1, 1, 1]", p_getattr_l__self___up_layers_1___0___norm1_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm1_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___up_layers_1___0___norm2_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm2_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_up_samples_2_0_conv_weight: "f32[16, 32, 1, 1, 1]", p_getattr_l__self___up_layers_2___0___norm1_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm1_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___up_layers_2___0___norm2_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm2_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_conv_final_0_weight: "f32[16]", p_conv_final_0_bias: "f32[16]", p_conv_final_2_conv_weight: "f32[3, 16, 1, 1, 1]", p_conv_final_2_conv_bias: "f32[3]", x: "f32[1, 4, 224, 224, 128]"):
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:157 in encode, code: x = self.convInit(x)
conv3d: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(x, p_convinit_conv_weight, None, [1, 1, 1], [1, 1, 1]); x = p_convinit_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:159 in encode, code: x = self.dropout(x)
feature_dropout: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.feature_dropout.default(conv3d, 0.2, False); conv3d = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(feature_dropout, 8, p_getattr_l__self___down_layers_0___1___norm1_weight, p_getattr_l__self___down_layers_0___1___norm1_bias); p_getattr_l__self___down_layers_0___1___norm1_weight = p_getattr_l__self___down_layers_0___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm); group_norm = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu, p_getattr_l__self___down_layers_0___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu = p_getattr_l__self___down_layers_0___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_1, 8, p_getattr_l__self___down_layers_0___1___norm2_weight, p_getattr_l__self___down_layers_0___1___norm2_bias); conv3d_1 = p_getattr_l__self___down_layers_0___1___norm2_weight = p_getattr_l__self___down_layers_0___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_1); group_norm_1 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_1, p_getattr_l__self___down_layers_0___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_1 = p_getattr_l__self___down_layers_0___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_2, feature_dropout); conv3d_2 = feature_dropout = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(add, p_down_layers_1_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_1_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_3, 8, p_getattr_l__self___down_layers_1___1___norm1_weight, p_getattr_l__self___down_layers_1___1___norm1_bias); p_getattr_l__self___down_layers_1___1___norm1_weight = p_getattr_l__self___down_layers_1___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_2); group_norm_2 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_2, p_getattr_l__self___down_layers_1___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_2 = p_getattr_l__self___down_layers_1___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_4, 8, p_getattr_l__self___down_layers_1___1___norm2_weight, p_getattr_l__self___down_layers_1___1___norm2_bias); conv3d_4 = p_getattr_l__self___down_layers_1___1___norm2_weight = p_getattr_l__self___down_layers_1___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_3); group_norm_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_3, p_getattr_l__self___down_layers_1___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_3 = p_getattr_l__self___down_layers_1___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_5, conv3d_3); conv3d_5 = conv3d_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_1, 8, p_getattr_l__self___down_layers_1___2___norm1_weight, p_getattr_l__self___down_layers_1___2___norm1_bias); p_getattr_l__self___down_layers_1___2___norm1_weight = p_getattr_l__self___down_layers_1___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_4); group_norm_4 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_6: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_4, p_getattr_l__self___down_layers_1___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_4 = p_getattr_l__self___down_layers_1___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_6, 8, p_getattr_l__self___down_layers_1___2___norm2_weight, p_getattr_l__self___down_layers_1___2___norm2_bias); conv3d_6 = p_getattr_l__self___down_layers_1___2___norm2_weight = p_getattr_l__self___down_layers_1___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_5); group_norm_5 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_7: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_5, p_getattr_l__self___down_layers_1___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_5 = p_getattr_l__self___down_layers_1___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_7, add_1); conv3d_7 = add_1 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_2, p_down_layers_2_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_2_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_8, 8, p_getattr_l__self___down_layers_2___1___norm1_weight, p_getattr_l__self___down_layers_2___1___norm1_bias); p_getattr_l__self___down_layers_2___1___norm1_weight = p_getattr_l__self___down_layers_2___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_6); group_norm_6 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_6, p_getattr_l__self___down_layers_2___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_6 = p_getattr_l__self___down_layers_2___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_9, 8, p_getattr_l__self___down_layers_2___1___norm2_weight, p_getattr_l__self___down_layers_2___1___norm2_bias); conv3d_9 = p_getattr_l__self___down_layers_2___1___norm2_weight = p_getattr_l__self___down_layers_2___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_7); group_norm_7 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_7, p_getattr_l__self___down_layers_2___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_7 = p_getattr_l__self___down_layers_2___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_3: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_10, conv3d_8); conv3d_10 = conv3d_8 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_3, 8, p_getattr_l__self___down_layers_2___2___norm1_weight, p_getattr_l__self___down_layers_2___2___norm1_bias); p_getattr_l__self___down_layers_2___2___norm1_weight = p_getattr_l__self___down_layers_2___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_8); group_norm_8 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_11: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_8, p_getattr_l__self___down_layers_2___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_8 = p_getattr_l__self___down_layers_2___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_11, 8, p_getattr_l__self___down_layers_2___2___norm2_weight, p_getattr_l__self___down_layers_2___2___norm2_bias); conv3d_11 = p_getattr_l__self___down_layers_2___2___norm2_weight = p_getattr_l__self___down_layers_2___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_9); group_norm_9 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_12: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_9, p_getattr_l__self___down_layers_2___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_9 = p_getattr_l__self___down_layers_2___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_4: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_12, add_3); conv3d_12 = add_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_4, p_down_layers_3_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_3_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_13, 8, p_getattr_l__self___down_layers_3___1___norm1_weight, p_getattr_l__self___down_layers_3___1___norm1_bias); p_getattr_l__self___down_layers_3___1___norm1_weight = p_getattr_l__self___down_layers_3___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_10); group_norm_10 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_10, p_getattr_l__self___down_layers_3___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_10 = p_getattr_l__self___down_layers_3___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_14, 8, p_getattr_l__self___down_layers_3___1___norm2_weight, p_getattr_l__self___down_layers_3___1___norm2_bias); conv3d_14 = p_getattr_l__self___down_layers_3___1___norm2_weight = p_getattr_l__self___down_layers_3___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_11); group_norm_11 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_11, p_getattr_l__self___down_layers_3___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_11 = p_getattr_l__self___down_layers_3___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_5: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_15, conv3d_13); conv3d_15 = conv3d_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_5, 8, p_getattr_l__self___down_layers_3___2___norm1_weight, p_getattr_l__self___down_layers_3___2___norm1_bias); p_getattr_l__self___down_layers_3___2___norm1_weight = p_getattr_l__self___down_layers_3___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_12); group_norm_12 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_12, p_getattr_l__self___down_layers_3___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_12 = p_getattr_l__self___down_layers_3___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_16, 8, p_getattr_l__self___down_layers_3___2___norm2_weight, p_getattr_l__self___down_layers_3___2___norm2_bias); conv3d_16 = p_getattr_l__self___down_layers_3___2___norm2_weight = p_getattr_l__self___down_layers_3___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_13); group_norm_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_13, p_getattr_l__self___down_layers_3___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_13 = p_getattr_l__self___down_layers_3___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_6: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_17, add_5); conv3d_17 = add_5 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_6, 8, p_getattr_l__self___down_layers_3___3___norm1_weight, p_getattr_l__self___down_layers_3___3___norm1_bias); p_getattr_l__self___down_layers_3___3___norm1_weight = p_getattr_l__self___down_layers_3___3___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_14); group_norm_14 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_18: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_14, p_getattr_l__self___down_layers_3___3___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_14 = p_getattr_l__self___down_layers_3___3___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_18, 8, p_getattr_l__self___down_layers_3___3___norm2_weight, p_getattr_l__self___down_layers_3___3___norm2_bias); conv3d_18 = p_getattr_l__self___down_layers_3___3___norm2_weight = p_getattr_l__self___down_layers_3___3___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_15); group_norm_15 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_19: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_15, p_getattr_l__self___down_layers_3___3___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_15 = p_getattr_l__self___down_layers_3___3___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_7: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_19, add_6); conv3d_19 = add_6 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_7, 8, p_getattr_l__self___down_layers_3___4___norm1_weight, p_getattr_l__self___down_layers_3___4___norm1_bias); p_getattr_l__self___down_layers_3___4___norm1_weight = p_getattr_l__self___down_layers_3___4___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_16); group_norm_16 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_20: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_16, p_getattr_l__self___down_layers_3___4___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_16 = p_getattr_l__self___down_layers_3___4___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_20, 8, p_getattr_l__self___down_layers_3___4___norm2_weight, p_getattr_l__self___down_layers_3___4___norm2_bias); conv3d_20 = p_getattr_l__self___down_layers_3___4___norm2_weight = p_getattr_l__self___down_layers_3___4___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_17); group_norm_17 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_21: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_17, p_getattr_l__self___down_layers_3___4___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_17 = p_getattr_l__self___down_layers_3___4___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_8: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_21, add_7); conv3d_21 = add_7 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight); add_8 = p_up_samples_0_0_conv_weight = None
upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0]); conv3d_22 = None
add_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4); upsample_trilinear3d = add_4 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_9, 8, p_getattr_l__self___up_layers_0___0___norm1_weight, p_getattr_l__self___up_layers_0___0___norm1_bias); p_getattr_l__self___up_layers_0___0___norm1_weight = p_getattr_l__self___up_layers_0___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_18); group_norm_18 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_23: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_18, p_getattr_l__self___up_layers_0___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_18 = p_getattr_l__self___up_layers_0___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_23, 8, p_getattr_l__self___up_layers_0___0___norm2_weight, p_getattr_l__self___up_layers_0___0___norm2_bias); conv3d_23 = p_getattr_l__self___up_layers_0___0___norm2_weight = p_getattr_l__self___up_layers_0___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_19); group_norm_19 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_24: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_19, p_getattr_l__self___up_layers_0___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_19 = p_getattr_l__self___up_layers_0___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_24, add_9); conv3d_24 = add_9 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_25: "f32[1, 32, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_10, p_up_samples_1_0_conv_weight); add_10 = p_up_samples_1_0_conv_weight = None
upsample_trilinear3d_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_25, None, False, [2.0, 2.0, 2.0]); conv3d_25 = None
add_11: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_1, add_2); upsample_trilinear3d_1 = add_2 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_11, 8, p_getattr_l__self___up_layers_1___0___norm1_weight, p_getattr_l__self___up_layers_1___0___norm1_bias); p_getattr_l__self___up_layers_1___0___norm1_weight = p_getattr_l__self___up_layers_1___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_20); group_norm_20 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_26: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_20, p_getattr_l__self___up_layers_1___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_20 = p_getattr_l__self___up_layers_1___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_26, 8, p_getattr_l__self___up_layers_1___0___norm2_weight, p_getattr_l__self___up_layers_1___0___norm2_bias); conv3d_26 = p_getattr_l__self___up_layers_1___0___norm2_weight = p_getattr_l__self___up_layers_1___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_21); group_norm_21 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_27: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_21, p_getattr_l__self___up_layers_1___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_21 = p_getattr_l__self___up_layers_1___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_12: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_27, add_11); conv3d_27 = add_11 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_28: "f32[1, 16, 112, 112, 64]" = torch.ops.aten.conv3d.default(add_12, p_up_samples_2_0_conv_weight); add_12 = p_up_samples_2_0_conv_weight = None
upsample_trilinear3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_28, None, False, [2.0, 2.0, 2.0]); conv3d_28 = None
add_13: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_2, add); upsample_trilinear3d_2 = add = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_13, 8, p_getattr_l__self___up_layers_2___0___norm1_weight, p_getattr_l__self___up_layers_2___0___norm1_bias); p_getattr_l__self___up_layers_2___0___norm1_weight = p_getattr_l__self___up_layers_2___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_22); group_norm_22 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_29: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_22, p_getattr_l__self___up_layers_2___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_22 = p_getattr_l__self___up_layers_2___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_29, 8, p_getattr_l__self___up_layers_2___0___norm2_weight, p_getattr_l__self___up_layers_2___0___norm2_bias); conv3d_29 = p_getattr_l__self___up_layers_2___0___norm2_weight = p_getattr_l__self___up_layers_2___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_23); group_norm_23 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_30: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_23, p_getattr_l__self___up_layers_2___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_23 = p_getattr_l__self___up_layers_2___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_14: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_30, add_13); conv3d_30 = add_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:175 in decode, code: x = self.conv_final(x)
group_norm_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_14, 8, p_conv_final_0_weight, p_conv_final_0_bias); add_14 = p_conv_final_0_weight = p_conv_final_0_bias = None
relu_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_24); group_norm_24 = None
conv3d_31: "f32[1, 3, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_24, p_conv_final_2_conv_weight, p_conv_final_2_conv_bias); relu_24 = p_conv_final_2_conv_weight = p_conv_final_2_conv_bias = None
return (conv3d_31,)
from onnxscript.
It's possible that the upsample op was somehow decomposed by PyTorch. I will look deeper.
from onnxscript.
I also tried to get more info on which part of dynamo/onnxscript might be responsible for this.
If I run
scripted_model = torch.jit.script(model)
print(scripted_model.graph)
I get this:
graph(%self : __torch__.monai.networks.nets.segresnet.SegResNet,
%x.1 : Tensor):
%3 : (Tensor, Tensor[]) = prim::CallMethod[name="encode"](%self, %x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:180:20
%x.5 : Tensor, %down_x.1 : Tensor[] = prim::TupleUnpack(%3)
= aten::reverse(%down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:181:8
%x.9 : Tensor = prim::CallMethod[name="decode"](%self, %x.5, %down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:183:12
return (%x.9)
If I run
gm, _ = torch._dynamo.export(model)(data)
gm = torch.fx.experimental.proxy_tensor.make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()
I get an error:
Traceback (most recent call last):
File "/ws/dynamo/0501/export_SegResNet.py", line 31, in <module>
gm, _ = torch._dynamo.export(model)(data)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1282, in inner
dim_constraints.solve()
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 1772, in solve
tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[4]"
from onnxscript.
Thanks for catching this. Very intriguing. Will take a look!
from onnxscript.
from onnxscript.
@yuanyao-nv could you obtain the graph module from torch.export.export and post it here?
from onnxscript.
Yes, thank you
from onnxscript.
That's very strange. If you run torch.onnx.dynamo_export(exported_program, ...), do you get the same graph?
from onnxscript.
@justinchuby Is this the procedure you're suggesting?
exported_program = torch.export.export(model, (data,))
exported_program._graph_module.print_readable()
export_output = torch.onnx.dynamo_export(exported_program, data)
export_output.save('Clara_SegResNet_dynamo1.onnx')
The exported UpSample module looks about the same as before, still a very big graph.
In addition, the weights in the model appear as extra inputs, giving rise to tens of extra model inputs. This is similar to the bug reported in pytorch/pytorch#126071.
from onnxscript.
I don't see any resize ops, which is puzzling. Could you share the ONNX model? You may remove the weights if it is too big.
from onnxscript.
I was expecting to see this function:
from onnxscript.
@justinchuby I uploaded the two versions of the model here: https://drive.google.com/drive/folders/1s1lhKRuG6fOZmD4IjZvN_zlWfIxPB_8w?usp=sharing
from onnxscript.
I have found that, in the general case, one has to run exported_program.run_decompositions() before applying dynamo_export().
That may in fact fold some operations. @yuanyao-nv can you try that?
from onnxscript.
Thanks. We will be creating a series of changes to the exporter to support ExportedPrograms properly, including handling of the weights.
from onnxscript.
@borisfom I tried running run_decompositions(), but it didn't do anything for this particular subgraph.
exported_program = torch.export.export(model, args=(data,))
exported_program.run_decompositions()
export_output = torch.onnx.dynamo_export(exported_program, data)
export_output.save('Clara_SegResNet_dynamo.onnx')
from onnxscript.
There are two issues here:
- Unsupported upsample_trilinear_vec op: #1592
- Dynamo forces decomposition of upsample-related ops for some reason (related issues: pytorch/pytorch#115883 and pytorch/pytorch#116684): pytorch/pytorch#128259
from onnxscript.
Hi @yuanyao-nv,
This one should be fixed when you call torch.onnx.dynamo_export with an nn.Module. However, if you call torch.export.export first, it's going to be decomposed into the big subgraph you had. This decomposition is forced by dynamo for some reason. Feel free to open an issue like pytorch/pytorch#115883.
cc @gramalingam @justinchuby @xadupre This forced decomposition may require us to rewrite these subgraphs as patterns. It will come back to us once we rely on torch.export.export.
from onnxscript.
@titaiwangms Thanks for the update.
What's a good way to test it out?
If I rerun the export script in the description using the latest torch nightly build (2.5.0.dev20240617+cu121), I actually hit another error:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
).export()
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1236, in export
graph_module = self.options.fx_tracer.generate_fx(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 214, in generate_fx
graph_module, graph_guard = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 169, in wrapped
return output_adapter.apply(model_func(*args, **kwargs), model=model)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2462, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 356, in call_function
return super().call_function(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
return super().call_function(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2677, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2793, in inline_call_
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 234, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 962, in call_function
return handler(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 941, in _handle_insert_op_in_graph
return wrap_fx_proxy(tx, proxy)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1759, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1846, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
ret_val = wrap_fake_exception(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
return fn()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node
return node.target(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
return self_._op(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 266, in _fn
result = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 138, in _fn
result = fn(**bound.arguments)
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1080, in add
a, b = _maybe_broadcast(a, b)
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 419, in _maybe_broadcast
common_shape = _broadcast_shapes(
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 408, in _broadcast_shapes
raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='cuda:0', size=(1, 64, 28, 56, 56, 32),
grad_fn=<WarnNotImplemented>), FakeTensor(..., device='cuda:0', size=(1, 64, 56, 56, 32),
grad_fn=<AddBackward0>)), **{}):
Attempting to broadcast a dimension of length 64 at -4! Mismatching argument at index 1 had torch.Size([1, 64, 56, 56, 32]); but expected shape should be broadcastable to [1, 64, 28, 56, 56, 32]
from user code:
File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 183, in forward
x = self.decode(x, down_x)
File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 171, in decode
x = up(x) + down_x[i + 1]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Do you know why fake tensor is being used in the latest torch version?
I also tried exporting just an nn.Upsample function:
def f(x):
m = torch.nn.Upsample(size=(10), mode='linear')
return m(x)
x = torch.randn(2, 5, 5)
export_output = torch.onnx.dynamo_export(f, x)
export_output.save('Upsample.onnx')
The exported graph looks reasonable. Is this what you'd expect?
from onnxscript.
Filed a separate issue to track the above fake tensor broadcast error: pytorch/pytorch#129534
from onnxscript.
Related Issues (20)
- [ONNX] Implement <OpOverload(op='aten.pixel_unshuffle', overload='default')>
- [torchlib] Implement <OpOverload(op='aten.repeat_interleave', overload='Tensor')>
- [torchlib] Implement <OpOverload(op='aten.scatter', overload='src')>
- [torchlib] Implement <OpOverload(op='aten.scatter', overload='value')>
- [torchlib] Implement <OpOverload(op='aten.silu', overload='default')>
- [torchlib] Implement <OpOverload(op='aten.sort', overload='default')>
- [torchlib] Implement <OpOverload(op='aten.std', overload='correction')>
- [torchlib] Implement <OpOverload(op='aten.std_mean', overload='correction')>
- [torchlib] Implement <OpOverload(op='aten.sym_size', overload='int')>
- [torchlib] Implement <OpOverload(op='aten.take', overload='default')>
- [torchlib] Implement <OpOverload(op='aten.unsafe_split', overload='Tensor')>
- [torchlib] Implement <OpOverload(op='torchvision.nms', overload='default')>
- [torchlib] Implement <OpOverload(op='torchvision.roi_align', overload='default')>
- [torchlib] Implement <OpOverload(op='torchvision.roi_pool', overload='default')>
- [exporter] Create a pass to turn tensors into external tensors
- [core] Migrate OpSignature to ONNX Script
- [exporter] Create an IR modularization pass
- [exporter] Create an inliner pass for the IR
- [IR] Create a utility for merging models
- Use IR in the optimizer HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization; use data art.
-
Game
Something interesting about games; make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from onnxscript.