
Comments (5)

Light-- avatar Light-- commented on September 24, 2024 1

@chesharma

vis_graph = make_dot((y[0], y[1], y[2]), params=dict(self.model.named_parameters()))

Genius bro! 👍 👍 👍 How did you notice this problem and figure it out?

# The output of my model is a list of length 40; I used this and it worked!
vis_graph = make_dot(tuple(y[i] for i in range(40)))

Thanks, @chesharma

from pytorchviz.

szagoruyko avatar szagoruyko commented on September 24, 2024

Does your model output a list?


H-YunHui avatar H-YunHui commented on September 24, 2024

@caffelearn @szagoruyko
I also encountered this problem. Did you solve it?


Light-- avatar Light-- commented on September 24, 2024

Does your model output a list?

I also ran into this problem, and yes, my model outputs a list.

Could you help me? @szagoruyko

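import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      Module, PReLU, Sequential)

# get_blocks, bottleneck_IR, bottleneck_IR_SE, and Flatten are not shown in this post.

class Backbone(Module):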
    def __init__(self, num_layers, drop_ratio, mode='ir'):
        super(Backbone, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * 7 * 7, 512),
                                       BatchNorm1d(512))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel,
                                bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

        # for MTL
        self.tower = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

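        # Note: this registers the same self.tower instance 40 times,
        # so all 40 heads share one set of weights.
        self.towers = nn.ModuleList([self.tower for _ in range(40)])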

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        h_shared = self.output_layer(x)
        # for MTL
        out = [tower(h_shared) for tower in self.towers]
        return out

How did you solve this? @caffelearn @H-YunHui


 avatar commented on September 24, 2024

@Light-- I was able to resolve this by passing a tuple containing the elements of the output list.

For example, if your model has 3 outputs that it returns as elements of a list called 'y', the make_dot call would look like this:

vis_graph = make_dot((y[0], y[1], y[2]), params=dict(self.model.named_parameters()))
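
The same idea can be written without hard-coding the number of outputs. This is only a minimal sketch, assuming (as this thread reports) that make_dot accepts a tuple of tensors but not a list. The helper names as_output_tuple and visualize are hypothetical and are not part of torchviz.

```python
def as_output_tuple(outputs):
    """Return model outputs as a tuple.

    Lists and tuples are converted element by element; a single
    output is wrapped in a one-element tuple.
    """
    if isinstance(outputs, (list, tuple)):
        return tuple(outputs)
    return (outputs,)


def visualize(model, outputs):
    """Hypothetical wrapper: build the graph for one or many outputs."""
    # Imported lazily so as_output_tuple can be used without torchviz installed.
    from torchviz import make_dot

    return make_dot(as_output_tuple(outputs),
                    params=dict(model.named_parameters()))
```

With the Backbone model posted earlier in this thread, a call such as visualize(model, model(x)) would pass all 40 tower outputs to make_dot as one tuple.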

