fengyuentau commented on July 18, 2024

After some testing, I found that only modules defined in __init__ get hooks attached and are profiled (it doesn't matter whether they are actually used in forward). This introduces a tricky problem for APIs like nn.functional.interpolate: to be profiled, the output size would have to be passed when the model is initialized, but the output size is only revealed as the tensor goes through the model.
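
The limitation can be reproduced with a minimal sketch, assuming a toy model (Net, record and add_hooks are hypothetical names used only for illustration). Hooks are attached to leaf modules through model.apply, so the registered Conv2d is seen, while the F.interpolate call inside forward never triggers a hook because there is no module object to attach one to:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered submodule: model.apply() visits it, so it can be hooked.
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)
        # Functional call: there is no nn.Module object here to attach a hook to.
        return F.interpolate(x, scale_factor=2, mode="nearest")


seen = []


def record(module, inputs, output):
    # Record the type of every hooked module that runs during forward.
    seen.append(type(module).__name__)


def add_hooks(module):
    # Hook leaf modules only, in the spirit of a module-level op counter.
    if len(list(module.children())) == 0:
        module.register_forward_hook(record)


model = Net()
model.apply(add_hooks)
model(torch.randn(1, 3, 32, 32))
print(seen)  # ['Conv2d']: the interpolate call is never seen
```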

from pytorch-opcounter.

fengyuentau commented on July 18, 2024

A post on the PyTorch forum says:

model.apply(fn) combined with module.register_forward_hook(hook) allows for easy tracking of layers, but only works for nn.Module layers (conv, batchnorm, etc). This is effective for the majority of cases, but does not allow for tracking of functional calls, e.g. F.interpolate(...). Is there any way to detect functional calls in a forward pass?

So it seems there is currently no way to set up a FLOPs counter for interpolate.

Lyken17 commented on July 18, 2024

That's true. Currently, THOP is based on module-level hooks. Can you try nn.Upsample instead?
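
A minimal sketch of that suggestion, assuming a toy model (UpNet is a hypothetical name): once upsampling is a registered nn.Upsample submodule instead of a functional call, a module-level forward hook fires for it just like for the convolution.

```python
import torch
import torch.nn as nn


class UpNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        # Upsampling as a registered module, so module-level hooks can see it.
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        return self.up(self.conv(x))


seen = []
model = UpNet()
for m in model.modules():
    if len(list(m.children())) == 0:  # leaf modules only
        m.register_forward_hook(lambda mod, inp, out: seen.append(type(mod).__name__))

out = model(torch.randn(1, 3, 32, 32))
print(seen, tuple(out.shape))  # ['Conv2d', 'Upsample'] (1, 8, 32, 32)
```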

Lyken17 commented on July 18, 2024

I have implemented a version in count_hooks.py#L95.

fengyuentau commented on July 18, 2024

@Lyken17 Thanks for your implementation. The main difficulty is that I cannot specify the output size when the model is initialized, since multi-scale testing is widely used and the input size may not be fixed.

Lyken17 commented on July 18, 2024

My current implementation relies on hooks that run during the forward pass, at which point the output size is already determined.
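
A minimal sketch of why a forward hook does not need the output size up front, assuming a made-up counting rule of one operation per output element (count_upsample_sketch and the total_ops attribute are hypothetical here; THOP's actual hook in count_hooks.py may count differently). The hook runs after forward, so it reads the output shape directly, and the same module instance handles different input sizes:

```python
import torch
import torch.nn as nn


def count_upsample_sketch(module, inputs, output):
    # Hypothetical counting rule for illustration: one op per output element.
    # The hook runs after forward, so output.shape is already known here.
    module.total_ops += output.numel()


up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
up.total_ops = 0  # hypothetical accumulator attribute
up.register_forward_hook(count_upsample_sketch)

results = []
for size in (20, 33):  # different input sizes, same module instance
    up.total_ops = 0
    y = up(torch.randn(1, 8, size, size))
    results.append((tuple(y.shape), up.total_ops))

print(results)  # [((1, 8, 40, 40), 12800), ((1, 8, 66, 66), 34848)]
```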

fengyuentau commented on July 18, 2024

I realized that I can set scale_factor instead of size for nn.Upsample. I also made a mistake: I thought nn.Upsample had been deprecated, but the one actually deprecated is nn.functional.upsample. Thanks for your help.
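
The difference between the two constructor options can be sketched as follows (the input sizes 32 and 48 are arbitrary stand-ins for two test scales): with size the output is fixed when the module is built, while with scale_factor it follows whatever input arrives at forward time.

```python
import torch
import torch.nn as nn

fixed = nn.Upsample(size=(64, 64), mode="nearest")    # output size fixed at construction
scaled = nn.Upsample(scale_factor=2, mode="nearest")  # output size follows the input

shapes = []
for s in (32, 48):  # e.g. two test scales
    x = torch.randn(1, 4, s, s)
    shapes.append((s, tuple(fixed(x).shape[-2:]), tuple(scaled(x).shape[-2:])))

print(shapes)  # [(32, (64, 64), (64, 64)), (48, (64, 64), (96, 96))]
```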

fengyuentau commented on July 18, 2024

A potential and tricky problem: if scale_factor is set for nn.Upsample (e.g. scale_factor=2) and the height/width is odd, the upsampled size will not match the original size. For example, given a 41x41 input, a downsampling Conv2d outputs 21x21 feature maps; upsampling a 21x21 tensor with scale_factor=2 produces 42x42, which does not match the original 41x41.
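
The mismatch can be reproduced with a short sketch, assuming the downsampling layer is a 3x3 Conv2d with stride 2 and padding 1 (the thread does not give the exact layer configuration):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 41, 41)
# Assumed downsampling layer: 3x3 conv with stride 2 and padding 1.
down = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1)
up = nn.Upsample(scale_factor=2, mode="nearest")

d = down(x)  # 41 -> floor((41 + 2*1 - 3) / 2) + 1 = 21
u = up(d)    # 21 * 2 = 42
print(tuple(d.shape[-2:]), tuple(u.shape[-2:]))  # (21, 21) (42, 42)

try:
    x + u  # e.g. a skip connection back to the original resolution
    mismatch = False
except RuntimeError:
    mismatch = True
print(mismatch)  # True
```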

However, to register a forward hook for nn.Upsample, it must be defined in the model's __init__ function. This becomes tricky when the input size is not fixed, because the only option is to instantiate nn.Upsample in __init__ with a scale_factor.

fengyuentau commented on July 18, 2024

A workaround: make the input size a multiple of the downsampling factor. For example, ResNet downsamples its input by a factor of 32, so if you want to upsample a tensor that is 32 times smaller, the input size must be a multiple of 32 to keep the tensor sizes matched. Likewise, if the upsampling is applied to a tensor that is 16 times smaller, the input size should be a multiple of 16.
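
A sketch of this workaround, assuming five stride-2 3x3 convolutions as a stand-in for a backbone with an overall downsampling factor of 32 (round_up_to_multiple is a hypothetical helper, not part of THOP or PyTorch). Only the input whose size is a multiple of 32 comes back at its original size:

```python
import torch
import torch.nn as nn


def round_up_to_multiple(size, factor):
    # Hypothetical helper: smallest multiple of `factor` that is >= size.
    return ((size + factor - 1) // factor) * factor


# Stand-in for a backbone with overall stride 32: five stride-2 3x3 convs.
down = nn.Sequential(*[nn.Conv2d(3, 3, 3, stride=2, padding=1) for _ in range(5)])
up = nn.Upsample(scale_factor=32, mode="nearest")

results = []
for s in (41, round_up_to_multiple(41, 32)):  # 41 and 64
    x = torch.randn(1, 3, s, s)
    results.append((s, up(down(x)).shape[-1]))

print(results)  # [(41, 64), (64, 64)]: only the multiple of 32 round-trips
```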
