Comments (19)

yuanyao-nv commented on June 27, 2024

@justinchuby Is this what you mean?

    import torch

    exported_program = torch.export.export(model, (data,))
    exported_program.graph_module.print_readable()  # public accessor for _graph_module
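The shape annotations on the `conv3d` nodes in the printed graph can be sanity-checked by hand with the usual convolution output-size formula. Below is a minimal sketch in plain Python. `conv_out_size` and `conv3d_out_shape` are hypothetical helpers, not part of PyTorch. The kernel size 3 is read off the `[..., 3, 3, 3]` weight shapes, the stride and padding are read off the `[2, 2, 2], [1, 1, 1]` and `[1, 1, 1], [1, 1, 1]` arguments in the dump, and dilation 1 is assumed because the calls do not pass one.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    """Output length of one spatial dimension of a convolution (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv3d_out_shape(spatial, kernel=3, stride=1, padding=1):
    """Apply conv_out_size to each of the three spatial dimensions."""
    return tuple(conv_out_size(s, kernel, stride, padding) for s in spatial)


# Spatial shape of the input `x` in the graph signature: f32[1, 4, 224, 224, 128].
shape = (224, 224, 128)

# A stride-1, padding-1, 3x3x3 convolution keeps the spatial shape.
same = conv3d_out_shape(shape, stride=1)

# Each stride-2 downsampling convolution halves every spatial dimension.
down1 = conv3d_out_shape(shape, stride=2)
down2 = conv3d_out_shape(down1, stride=2)
down3 = conv3d_out_shape(down2, stride=2)

print(same, down1, down2, down3)
```

These values line up with the `conv3d`, `conv3d_3`, `conv3d_8`, and `conv3d_13` annotations in the graph that the export call above prints.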

Printing the exported graph gives:

class GraphModule(torch.nn.Module):
    def forward(self, p_convinit_conv_weight: "f32[16, 4, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm1_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm1_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm2_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm2_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_down_layers_1_0_conv_weight: "f32[32, 16, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_down_layers_2_0_conv_weight: "f32[64, 32, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm1_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm1_bias: "f32[64]", 
p_getattr_l__self___down_layers_2___2___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___2___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_down_layers_3_0_conv_weight: "f32[128, 64, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___4___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm2_bias: "f32[128]", 
p_getattr_l__self___down_layers_3___4___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_up_samples_0_0_conv_weight: "f32[64, 128, 1, 1, 1]", p_getattr_l__self___up_layers_0___0___norm1_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm1_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___up_layers_0___0___norm2_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm2_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_up_samples_1_0_conv_weight: "f32[32, 64, 1, 1, 1]", p_getattr_l__self___up_layers_1___0___norm1_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm1_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___up_layers_1___0___norm2_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm2_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_up_samples_2_0_conv_weight: "f32[16, 32, 1, 1, 1]", p_getattr_l__self___up_layers_2___0___norm1_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm1_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___up_layers_2___0___norm2_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm2_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_conv_final_0_weight: "f32[16]", p_conv_final_0_bias: "f32[16]", p_conv_final_2_conv_weight: "f32[3, 16, 1, 1, 1]", p_conv_final_2_conv_bias: "f32[3]", x: "f32[1, 4, 224, 224, 128]"):
        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:157 in encode, code: x = self.convInit(x)
        conv3d: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(x, p_convinit_conv_weight, None, [1, 1, 1], [1, 1, 1]);  x = p_convinit_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:159 in encode, code: x = self.dropout(x)
        feature_dropout: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.feature_dropout.default(conv3d, 0.2, False);  conv3d = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(feature_dropout, 8, p_getattr_l__self___down_layers_0___1___norm1_weight, p_getattr_l__self___down_layers_0___1___norm1_bias);  p_getattr_l__self___down_layers_0___1___norm1_weight = p_getattr_l__self___down_layers_0___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm);  group_norm = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu, p_getattr_l__self___down_layers_0___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu = p_getattr_l__self___down_layers_0___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_1, 8, p_getattr_l__self___down_layers_0___1___norm2_weight, p_getattr_l__self___down_layers_0___1___norm2_bias);  conv3d_1 = p_getattr_l__self___down_layers_0___1___norm2_weight = p_getattr_l__self___down_layers_0___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_1);  group_norm_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_1, p_getattr_l__self___down_layers_0___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_1 = p_getattr_l__self___down_layers_0___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_2, feature_dropout);  conv3d_2 = feature_dropout = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(add, p_down_layers_1_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_1_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_3, 8, p_getattr_l__self___down_layers_1___1___norm1_weight, p_getattr_l__self___down_layers_1___1___norm1_bias);  p_getattr_l__self___down_layers_1___1___norm1_weight = p_getattr_l__self___down_layers_1___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_2);  group_norm_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_2, p_getattr_l__self___down_layers_1___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_2 = p_getattr_l__self___down_layers_1___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_4, 8, p_getattr_l__self___down_layers_1___1___norm2_weight, p_getattr_l__self___down_layers_1___1___norm2_bias);  conv3d_4 = p_getattr_l__self___down_layers_1___1___norm2_weight = p_getattr_l__self___down_layers_1___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_3);  group_norm_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_3, p_getattr_l__self___down_layers_1___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_3 = p_getattr_l__self___down_layers_1___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_5, conv3d_3);  conv3d_5 = conv3d_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_1, 8, p_getattr_l__self___down_layers_1___2___norm1_weight, p_getattr_l__self___down_layers_1___2___norm1_bias);  p_getattr_l__self___down_layers_1___2___norm1_weight = p_getattr_l__self___down_layers_1___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_4);  group_norm_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_6: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_4, p_getattr_l__self___down_layers_1___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_4 = p_getattr_l__self___down_layers_1___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_6, 8, p_getattr_l__self___down_layers_1___2___norm2_weight, p_getattr_l__self___down_layers_1___2___norm2_bias);  conv3d_6 = p_getattr_l__self___down_layers_1___2___norm2_weight = p_getattr_l__self___down_layers_1___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_5);  group_norm_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_7: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_5, p_getattr_l__self___down_layers_1___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_5 = p_getattr_l__self___down_layers_1___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_7, add_1);  conv3d_7 = add_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_2, p_down_layers_2_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_2_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_8, 8, p_getattr_l__self___down_layers_2___1___norm1_weight, p_getattr_l__self___down_layers_2___1___norm1_bias);  p_getattr_l__self___down_layers_2___1___norm1_weight = p_getattr_l__self___down_layers_2___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_6);  group_norm_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_6, p_getattr_l__self___down_layers_2___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_6 = p_getattr_l__self___down_layers_2___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_9, 8, p_getattr_l__self___down_layers_2___1___norm2_weight, p_getattr_l__self___down_layers_2___1___norm2_bias);  conv3d_9 = p_getattr_l__self___down_layers_2___1___norm2_weight = p_getattr_l__self___down_layers_2___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_7);  group_norm_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_7, p_getattr_l__self___down_layers_2___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_7 = p_getattr_l__self___down_layers_2___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_3: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_10, conv3d_8);  conv3d_10 = conv3d_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_3, 8, p_getattr_l__self___down_layers_2___2___norm1_weight, p_getattr_l__self___down_layers_2___2___norm1_bias);  p_getattr_l__self___down_layers_2___2___norm1_weight = p_getattr_l__self___down_layers_2___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_8);  group_norm_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_11: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_8, p_getattr_l__self___down_layers_2___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_8 = p_getattr_l__self___down_layers_2___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_11, 8, p_getattr_l__self___down_layers_2___2___norm2_weight, p_getattr_l__self___down_layers_2___2___norm2_bias);  conv3d_11 = p_getattr_l__self___down_layers_2___2___norm2_weight = p_getattr_l__self___down_layers_2___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_9);  group_norm_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_12: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_9, p_getattr_l__self___down_layers_2___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_9 = p_getattr_l__self___down_layers_2___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_4: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_12, add_3);  conv3d_12 = add_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_4, p_down_layers_3_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_3_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_13, 8, p_getattr_l__self___down_layers_3___1___norm1_weight, p_getattr_l__self___down_layers_3___1___norm1_bias);  p_getattr_l__self___down_layers_3___1___norm1_weight = p_getattr_l__self___down_layers_3___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_10);  group_norm_10 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_10, p_getattr_l__self___down_layers_3___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_10 = p_getattr_l__self___down_layers_3___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_14, 8, p_getattr_l__self___down_layers_3___1___norm2_weight, p_getattr_l__self___down_layers_3___1___norm2_bias);  conv3d_14 = p_getattr_l__self___down_layers_3___1___norm2_weight = p_getattr_l__self___down_layers_3___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_11);  group_norm_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_11, p_getattr_l__self___down_layers_3___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_11 = p_getattr_l__self___down_layers_3___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_5: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_15, conv3d_13);  conv3d_15 = conv3d_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_5, 8, p_getattr_l__self___down_layers_3___2___norm1_weight, p_getattr_l__self___down_layers_3___2___norm1_bias);  p_getattr_l__self___down_layers_3___2___norm1_weight = p_getattr_l__self___down_layers_3___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_12);  group_norm_12 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_12, p_getattr_l__self___down_layers_3___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_12 = p_getattr_l__self___down_layers_3___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_16, 8, p_getattr_l__self___down_layers_3___2___norm2_weight, p_getattr_l__self___down_layers_3___2___norm2_bias);  conv3d_16 = p_getattr_l__self___down_layers_3___2___norm2_weight = p_getattr_l__self___down_layers_3___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_13);  group_norm_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_13, p_getattr_l__self___down_layers_3___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_13 = p_getattr_l__self___down_layers_3___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_6: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_17, add_5);  conv3d_17 = add_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_6, 8, p_getattr_l__self___down_layers_3___3___norm1_weight, p_getattr_l__self___down_layers_3___3___norm1_bias);  p_getattr_l__self___down_layers_3___3___norm1_weight = p_getattr_l__self___down_layers_3___3___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_14);  group_norm_14 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_18: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_14, p_getattr_l__self___down_layers_3___3___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_14 = p_getattr_l__self___down_layers_3___3___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_18, 8, p_getattr_l__self___down_layers_3___3___norm2_weight, p_getattr_l__self___down_layers_3___3___norm2_bias);  conv3d_18 = p_getattr_l__self___down_layers_3___3___norm2_weight = p_getattr_l__self___down_layers_3___3___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_15);  group_norm_15 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_19: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_15, p_getattr_l__self___down_layers_3___3___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_15 = p_getattr_l__self___down_layers_3___3___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_7: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_19, add_6);  conv3d_19 = add_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_7, 8, p_getattr_l__self___down_layers_3___4___norm1_weight, p_getattr_l__self___down_layers_3___4___norm1_bias);  p_getattr_l__self___down_layers_3___4___norm1_weight = p_getattr_l__self___down_layers_3___4___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_16);  group_norm_16 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_20: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_16, p_getattr_l__self___down_layers_3___4___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_16 = p_getattr_l__self___down_layers_3___4___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_20, 8, p_getattr_l__self___down_layers_3___4___norm2_weight, p_getattr_l__self___down_layers_3___4___norm2_bias);  conv3d_20 = p_getattr_l__self___down_layers_3___4___norm2_weight = p_getattr_l__self___down_layers_3___4___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_17);  group_norm_17 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_21: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_17, p_getattr_l__self___down_layers_3___4___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_17 = p_getattr_l__self___down_layers_3___4___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_8: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_21, add_7);  conv3d_21 = add_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight);  add_8 = p_up_samples_0_0_conv_weight = None
        upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0]);  conv3d_22 = None
        add_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4);  upsample_trilinear3d = add_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_9, 8, p_getattr_l__self___up_layers_0___0___norm1_weight, p_getattr_l__self___up_layers_0___0___norm1_bias);  p_getattr_l__self___up_layers_0___0___norm1_weight = p_getattr_l__self___up_layers_0___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_18);  group_norm_18 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_23: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_18, p_getattr_l__self___up_layers_0___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_18 = p_getattr_l__self___up_layers_0___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_23, 8, p_getattr_l__self___up_layers_0___0___norm2_weight, p_getattr_l__self___up_layers_0___0___norm2_bias);  conv3d_23 = p_getattr_l__self___up_layers_0___0___norm2_weight = p_getattr_l__self___up_layers_0___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_19);  group_norm_19 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_24: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_19, p_getattr_l__self___up_layers_0___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_19 = p_getattr_l__self___up_layers_0___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_24, add_9);  conv3d_24 = add_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_25: "f32[1, 32, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_10, p_up_samples_1_0_conv_weight);  add_10 = p_up_samples_1_0_conv_weight = None
        upsample_trilinear3d_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_25, None, False, [2.0, 2.0, 2.0]);  conv3d_25 = None
        add_11: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_1, add_2);  upsample_trilinear3d_1 = add_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_11, 8, p_getattr_l__self___up_layers_1___0___norm1_weight, p_getattr_l__self___up_layers_1___0___norm1_bias);  p_getattr_l__self___up_layers_1___0___norm1_weight = p_getattr_l__self___up_layers_1___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_20);  group_norm_20 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_26: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_20, p_getattr_l__self___up_layers_1___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_20 = p_getattr_l__self___up_layers_1___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_26, 8, p_getattr_l__self___up_layers_1___0___norm2_weight, p_getattr_l__self___up_layers_1___0___norm2_bias);  conv3d_26 = p_getattr_l__self___up_layers_1___0___norm2_weight = p_getattr_l__self___up_layers_1___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_21);  group_norm_21 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_27: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_21, p_getattr_l__self___up_layers_1___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_21 = p_getattr_l__self___up_layers_1___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_12: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_27, add_11);  conv3d_27 = add_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_28: "f32[1, 16, 112, 112, 64]" = torch.ops.aten.conv3d.default(add_12, p_up_samples_2_0_conv_weight);  add_12 = p_up_samples_2_0_conv_weight = None
        upsample_trilinear3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_28, None, False, [2.0, 2.0, 2.0]);  conv3d_28 = None
        add_13: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_2, add);  upsample_trilinear3d_2 = add = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_13, 8, p_getattr_l__self___up_layers_2___0___norm1_weight, p_getattr_l__self___up_layers_2___0___norm1_bias);  p_getattr_l__self___up_layers_2___0___norm1_weight = p_getattr_l__self___up_layers_2___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_22);  group_norm_22 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_29: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_22, p_getattr_l__self___up_layers_2___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_22 = p_getattr_l__self___up_layers_2___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_29, 8, p_getattr_l__self___up_layers_2___0___norm2_weight, p_getattr_l__self___up_layers_2___0___norm2_bias);  conv3d_29 = p_getattr_l__self___up_layers_2___0___norm2_weight = p_getattr_l__self___up_layers_2___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_23);  group_norm_23 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_30: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_23, p_getattr_l__self___up_layers_2___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_23 = p_getattr_l__self___up_layers_2___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_14: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_30, add_13);  conv3d_30 = add_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:175 in decode, code: x = self.conv_final(x)
        group_norm_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_14, 8, p_conv_final_0_weight, p_conv_final_0_bias);  add_14 = p_conv_final_0_weight = p_conv_final_0_bias = None
        relu_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_24);  group_norm_24 = None
        conv3d_31: "f32[1, 3, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_24, p_conv_final_2_conv_weight, p_conv_final_2_conv_bias);  relu_24 = p_conv_final_2_conv_weight = p_conv_final_2_conv_bias = None
        return (conv3d_31,)
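
The shape annotations in the decoder part of this dump can be checked by hand. The sketch below is plain Python with no torch dependency and simply replays the two `up(x)` stages. The helper names `conv1x1_shape` and `upsample_shape` are hypothetical and exist only for this illustration. It assumes the `up_samples` convolutions use a 1x1x1 kernel, which is inferred from the dump (they are called with default stride and padding and leave the spatial size unchanged), and that the upsample output size is `floor(input_size * scale_factor)`.

```python
import math

def conv1x1_shape(shape, out_channels):
    """Shape after a 1x1x1 conv: only the channel count changes."""
    n, _, *spatial = shape
    return (n, out_channels, *spatial)

def upsample_shape(shape, scale_factors):
    """Shape after upsampling: each spatial dim becomes floor(size * scale)."""
    n, c, *spatial = shape
    return (n, c, *(math.floor(s * f) for s, f in zip(spatial, scale_factors)))

x = (1, 64, 56, 56, 32)                  # add_10
x = conv1x1_shape(x, 32)                 # conv3d_25
x = upsample_shape(x, (2.0, 2.0, 2.0))   # upsample_trilinear3d_1
print(x)                                 # (1, 32, 112, 112, 64)

x = conv1x1_shape(x, 16)                 # conv3d_28
x = upsample_shape(x, (2.0, 2.0, 2.0))   # upsample_trilinear3d_2
print(x)                                 # (1, 16, 224, 224, 128)
```

Both results match the annotated shapes of `upsample_trilinear3d_1` and `upsample_trilinear3d_2`, so at the `torch.export` level each upsample is still a single node.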

from onnxscript.

justinchuby avatar justinchuby commented on June 27, 2024 1

It's possible that the upsample op was somehow decomposed by PyTorch. I will look deeper.


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

I also tried to get more info on which part of dynamo/onnxscript might be responsible for this.
If I run

scripted_model = torch.jit.script(model)
print(scripted_model.graph)

I get this:

graph(%self : __torch__.monai.networks.nets.segresnet.SegResNet,
      %x.1 : Tensor):
  %3 : (Tensor, Tensor[]) = prim::CallMethod[name="encode"](%self, %x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:180:20
  %x.5 : Tensor, %down_x.1 : Tensor[] = prim::TupleUnpack(%3)
   = aten::reverse(%down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:181:8
  %x.9 : Tensor = prim::CallMethod[name="decode"](%self, %x.5, %down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:183:12
  return (%x.9)

If I run

gm, _ = torch._dynamo.export(model)(data)
gm = torch.fx.experimental.proxy_tensor.make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()

I get an error:

Traceback (most recent call last):
  File "/ws/dynamo/0501/export_SegResNet.py", line 31, in <module>
    gm, _ = torch._dynamo.export(model)(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1282, in inner
    dim_constraints.solve()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 1772, in solve
    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[4]"


justinchuby avatar justinchuby commented on June 27, 2024

Thanks for catching this. Very intriguing. Will take a look!


justinchuby avatar justinchuby commented on June 27, 2024

cc @xiaowuhu @fatcat-z


justinchuby avatar justinchuby commented on June 27, 2024

@yuanyao-nv could you obtain the graph module from torch.export.export and post it here?


justinchuby avatar justinchuby commented on June 27, 2024

Yes, thank you


justinchuby avatar justinchuby commented on June 27, 2024

That's very strange. If you run torch.onnx.dynamo_export(exported_program, ...), do you get the same graph?


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

@justinchuby Is this the procedure you're suggesting?

    exported_program = torch.export.export(model, (data,))
    exported_program._graph_module.print_readable()

    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo1.onnx')

The exported UpSample module looks about the same as before, still a very big graph.
In addition, the weights in the model appear as extra inputs, giving rise to tens of extra model inputs. This is similar to the bug pytorch/pytorch#126071.


justinchuby avatar justinchuby commented on June 27, 2024

I don't see any resize ops, which is puzzling. Could you share the ONNX model? You may remove the weights if it is too big.


justinchuby avatar justinchuby commented on June 27, 2024

I was expecting to see this function:

def _aten_upsample_output_size(


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

@justinchuby I uploaded the two versions of the model here: https://drive.google.com/drive/folders/1s1lhKRuG6fOZmD4IjZvN_zlWfIxPB_8w?usp=sharing


borisfom avatar borisfom commented on June 27, 2024

I have found that, in the general case, one has to run exported_program.run_decompositions() before applying dynamo_export().
That may in fact fold some operations. @yuanyao-nv can you try that?


justinchuby avatar justinchuby commented on June 27, 2024

Thanks. We will be creating a series of changes to the exporter to support ExportedPrograms properly, including handling of the weights.


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

@borisfom I tried running run_decompositions() but it didn't do anything for this particular subgraph.

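    exported_program = torch.export.export(model, args=(data,))
    # run_decompositions() returns a new ExportedProgram instead of modifying
    # the existing one in place, so keep the returned value.
    exported_program = exported_program.run_decompositions()
    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo.onnx')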


titaiwangms avatar titaiwangms commented on June 27, 2024

There are two issues here:

  1. Unsupported upsample_trilinear3d.vec op: #1592
  2. Dynamo forces decomposition of upsample-related ops for some reason (related issues: pytorch/pytorch#115883 and pytorch/pytorch#116684). pytorch/pytorch#128259


titaiwangms avatar titaiwangms commented on June 27, 2024

Hi @yuanyao-nv,

This one should be fixed when you call torch.onnx.dynamo_export with an nn.Module. However, if you call torch.export.export first, the op is still decomposed into the big subgraph you saw. This decomposition is forced by dynamo for some reason. Feel free to open an issue like pytorch/pytorch#115883.

cc @gramalingam @justinchuby @xadupre This forced decomposition may require us to rewrite them as patterns. It will come back to us once we rely on torch.export.export.
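
To give a rough idea of why a decomposed upsample turns into a large subgraph instead of one resize node, here is a pure-Python sketch of 1-D linear upsampling. It is not the decomposition PyTorch actually emits. The function name `upsample_linear_1d` is hypothetical, and the sketch assumes the half-pixel (`align_corners=False`) convention. Each output element needs a source coordinate, two clamped neighbour indices, and a blend weight, and every one of those steps becomes separate primitive ops once decomposed. The trilinear case repeats this along three axes.

```python
import math

def upsample_linear_1d(values, out_size):
    """Linear upsampling of a 1-D list, half-pixel (align_corners=False) style."""
    in_size = len(values)
    scale = in_size / out_size
    out = []
    for dst in range(out_size):
        # Map the output index to a source coordinate, clamped at 0.
        src = max(scale * (dst + 0.5) - 0.5, 0.0)
        # Two neighbouring source indices, clamped to the valid range.
        i0 = min(int(math.floor(src)), in_size - 1)
        i1 = min(i0 + 1, in_size - 1)
        # Blend weight of the right-hand neighbour.
        w = src - i0
        out.append((1.0 - w) * values[i0] + w * values[i1])
    return out

print(upsample_linear_1d([0.0, 1.0], 4))  # [0.0, 0.25, 0.75, 1.0]
```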


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

@titaiwangms Thanks for the update.

What's a good way to test it out?
If I rerun the export script in the description using the latest torch nightly build (2.5.0.dev20240617+cu121), I actually hit another error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 214, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 169, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2462, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 356, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2677, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2793, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 234, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 962, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 941, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1759, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1846, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
    return self_._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 138, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1080, in add
    a, b = _maybe_broadcast(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 419, in _maybe_broadcast
    common_shape = _broadcast_shapes(
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 408, in _broadcast_shapes
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='cuda:0', size=(1, 64, 28, 56, 56, 32),
           grad_fn=<WarnNotImplemented>), FakeTensor(..., device='cuda:0', size=(1, 64, 56, 56, 32),
           grad_fn=<AddBackward0>)), **{}):
Attempting to broadcast a dimension of length 64 at -4! Mismatching argument at index 1 had torch.Size([1, 64, 56, 56, 32]); but expected shape should be broadcastable to [1, 64, 28, 56, 56, 32]

from user code:
   File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 183, in forward
    x = self.decode(x, down_x)
  File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 171, in decode
    x = up(x) + down_x[i + 1]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Do you know why fake tensor is being used in the latest torch version?

I also tried exporting just an nn.Upsample function:

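import torch

def f(x):
    # Upsample the last dimension to length 10 with linear interpolation.
    m = torch.nn.Upsample(size=10, mode='linear')
    return m(x)

x = torch.randn(2, 5, 5)  # (batch, channels, length)
export_output = torch.onnx.dynamo_export(f, x)
export_output.save('Upsample.onnx')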

The exported graph looks reasonable. Is this what you'd expect?
[image of the exported graph not preserved]


yuanyao-nv avatar yuanyao-nv commented on June 27, 2024

Filed a separate issue to track the above fake tensor broadcast error: pytorch/pytorch#129534
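
For readers following along, the shape mismatch in that traceback can be reproduced without torch. The sketch below is a minimal pure-Python version of the usual right-aligned broadcasting rule. The helper name `broadcast_shapes` is hypothetical and is not the `torch._refs` implementation. The two shapes are copied from the error message, where `up(x)` unexpectedly comes back as a 6-D tensor.

```python
from itertools import zip_longest

def broadcast_shapes(*shapes):
    """Return the broadcast shape, or raise ValueError naming the bad dim."""
    result = []
    # Walk the shapes right to left; missing leading dims count as size 1.
    reversed_shapes = (reversed(s) for s in shapes)
    for offset, dims in enumerate(zip_longest(*reversed_shapes, fillvalue=1), start=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast sizes {dims} at dim -{offset}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

up_x = (1, 64, 28, 56, 56, 32)  # shape of up(x) reported in the traceback
skip = (1, 64, 56, 56, 32)      # shape of down_x[i + 1] reported in the traceback

try:
    broadcast_shapes(up_x, skip)
except ValueError as err:
    print(err)  # cannot broadcast sizes (28, 64) at dim -4
```

The conflict lands at dim -4, the same position the error message reports: after right-alignment, the extra dimension of size 28 lines up against the channel dimension of size 64.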

