Eric-mingjie commented on May 24, 2024

How does pruning the last conv layer affect the first linear layer of the classifier, which is (512 * 7 * 7, 4096)?

It changes the input dimension of the first fc layer.

How can I prune the input weights of the classifier according to the last conv layer?

The input weights of the classifier do not change with the last conv layer.
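The two layers are linked by the flatten step x.view(x.size(0), -1) used before the classifier. On an (N, C, 7, 7) tensor it lays the channels out one after another, so channel c feeds input columns c*49 through c*49 + 48 of the first linear layer. A minimal check, assuming PyTorch and using small hypothetical sizes in place of 512 channels:

```python
import torch

# Small hypothetical stand-in for the (batch, 512, 7, 7) tensor produced by avgpool
N, C, H, W = 2, 4, 7, 7
x = torch.randn(N, C, H, W)
flat = x.view(N, -1)  # same flattening as before the classifier
print(flat.shape)  # torch.Size([2, 196])

# Channel c occupies a contiguous block of H*W columns in the flattened vector
for c in range(C):
    block = flat[:, c * H * W:(c + 1) * H * W]
    assert torch.equal(block, x[:, c].reshape(N, -1))
```

This channel-major order is what makes it possible to map a pruned filter to a fixed block of fc input columns.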

from network-slimming.

Saharkakavand commented on May 24, 2024

The last conv layer, nn.Conv2d(512, 512, 3, padding=1), has been pruned to nn.Conv2d(450, 412, 3, padding=1).
How should the first fc layer, nn.Linear(in_features=512 * 7 * 7, out_features=4096, bias=True), change accordingly?

import torch.nn as nn


class VGGOWN(nn.Module):
    """VGG-16 layout: 13 conv layers in 5 blocks, 7x7 adaptive average pooling, 3 fc layers."""

    def __init__(self):
        super(VGGOWN, self).__init__()
        self.features = nn.Sequential(
            # conv1
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/2

            # conv2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/4

            # conv3
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/8

            # conv4
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/16

            # conv5 (the last Conv2d here is the layer being pruned)
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/32
        )
        # Always produces a 7x7 map, so the first fc layer sees
        # (out_channels of the last conv) * 7 * 7 inputs
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=512 * 7 * 7, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=1000, bias=True),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        # (N, C, 7, 7) -> (N, C * 7 * 7), channels laid out one after another
        x = x.view(x.size(0), -1)
        return self.classifier(x)

Thanks!

Eric-mingjie commented on May 24, 2024

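It should now be nn.Linear(in_features=412 * 7 * 7, out_features=4096, bias=True): the second argument of nn.Conv2d is out_channels, so nn.Conv2d(450, 412, 3, padding=1) outputs 412 channels, and each of them contributes 7 * 7 features after the adaptive average pooling.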
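The flattened size is set by the out_channels of the last conv layer (the second argument of nn.Conv2d), which can be checked by pushing a dummy activation through the pruned tail of the network. A minimal sketch, assuming PyTorch, a 224x224 input image (so conv5 works on 14x14 maps), and that the 450 -> 412 layer from the question is the last conv in conv5:

```python
import torch
import torch.nn as nn

# Pruned last conv layer from the question: 450 input channels -> 412 output channels
last_conv = nn.Conv2d(450, 412, 3, padding=1)
pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
avgpool = nn.AdaptiveAvgPool2d((7, 7))

# Dummy activation entering the pruned layer (assumes a 224x224 image, so maps are 14x14)
x = torch.randn(1, 450, 14, 14)
with torch.no_grad():
    x = avgpool(pool5(torch.relu(last_conv(x))))
print(x.shape)  # torch.Size([1, 412, 7, 7])

# Flatten the same way as before the classifier
in_features = x.view(x.size(0), -1).size(1)
print(in_features)  # 20188, i.e. 412 * 7 * 7
```

So the first fc layer becomes nn.Linear(412 * 7 * 7, 4096); note that PyTorch stores its weight as (out_features, in_features), i.e. (4096, 20188).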

Saharkakavand commented on May 24, 2024

When I load the pretrained state_dict, the first linear layer already has the full 25088-input weight (stored by PyTorch as shape (4096, 25088)). How can I know which of these weights should be pruned?
I know which filters I pruned in the previous layer, but I don't know how the conv layer's outputs are mapped to the inputs of the next layer.
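Given the indices of the filters kept in the last conv layer, the matching input columns of the first linear layer follow from the flatten layout: channel c owns columns c*49 through c*49 + 48. A minimal sketch, assuming PyTorch, a (N, C, 7, 7) tensor entering the flatten, and hypothetical small sizes and kept indices; it illustrates the mapping and is not the code in vggprune.py:

```python
import torch
import torch.nn as nn

C, HW, OUT = 8, 7 * 7, 16  # hypothetical small sizes (real model: 512, 49, 4096)
fc = nn.Linear(C * HW, OUT)

# Hypothetical indices of the filters kept in the last conv layer
kept = torch.tensor([0, 2, 3, 6])

# Each kept channel c owns columns c*HW ... c*HW + HW - 1 of fc.weight
cols = (kept.unsqueeze(1) * HW + torch.arange(HW)).reshape(-1)

pruned_fc = nn.Linear(len(kept) * HW, OUT)
with torch.no_grad():
    # nn.Linear weight is (out_features, in_features), so select along dim 1
    pruned_fc.weight.copy_(fc.weight[:, cols])
    pruned_fc.bias.copy_(fc.bias)

# Sanity check: the pruned layer on the kept channels matches the original
# layer on the full tensor with the pruned channels zeroed out
x = torch.randn(3, C, 7, 7)
mask = torch.zeros(C, dtype=torch.bool)
mask[kept] = True
ref = fc((x * mask.view(1, C, 1, 1)).view(3, -1))
out = pruned_fc(x[:, kept].reshape(3, -1))
print(torch.allclose(ref, out, atol=1e-5))  # True
```

The bias is copied unchanged because pruning removes inputs of the fc layer, not its outputs.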

Eric-mingjie commented on May 24, 2024

Please work through vggprune.py in detail; it should make the mapping clear.
