
pytorch-cnn-finetune's Introduction

Fine-tune pretrained Convolutional Neural Networks with PyTorch.


Features

  • Gives access to the most popular CNN architectures pretrained on ImageNet.
  • Automatically replaces the classifier on top of the network, which lets you train the network on a dataset with a different number of classes.
  • Allows you to use images with any resolution (and not only the resolution that was used for training the original model on ImageNet).
  • Allows adding a Dropout layer or a custom pooling layer.

Supported architectures and models

From the torchvision package:

  • ResNet (resnet18, resnet34, resnet50, resnet101, resnet152)
  • ResNeXt (resnext50_32x4d, resnext101_32x8d)
  • DenseNet (densenet121, densenet169, densenet201, densenet161)
  • Inception v3 (inception_v3)
  • VGG (vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn)
  • SqueezeNet (squeezenet1_0, squeezenet1_1)
  • MobileNet V2 (mobilenet_v2)
  • ShuffleNet v2 (shufflenet_v2_x0_5, shufflenet_v2_x1_0)
  • AlexNet (alexnet)
  • GoogLeNet (googlenet)

From the Pretrained models for PyTorch package:

  • ResNeXt (resnext101_32x4d, resnext101_64x4d)
  • NASNet-A Large (nasnetalarge)
  • NASNet-A Mobile (nasnetamobile)
  • Inception-ResNet v2 (inceptionresnetv2)
  • Dual Path Networks (dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107)
  • Inception v4 (inception_v4)
  • Xception (xception)
  • Squeeze-and-Excitation Networks (senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d)
  • PNASNet-5-Large (pnasnet5large)
  • PolyNet (polynet)

Requirements

  • Python 3.5+
  • PyTorch 1.1+

Installation

pip install cnn_finetune

Major changes:

Version 0.4

  • The default value of the pretrained argument in make_model changed from False to True. The call make_model('resnet18', num_classes=10) is now equivalent to make_model('resnet18', num_classes=10, pretrained=True).

Example usage:

Make a model with ImageNet weights for 10 classes

from cnn_finetune import make_model

model = make_model('resnet18', num_classes=10, pretrained=True)

Make a model with Dropout

model = make_model('nasnetalarge', num_classes=10, pretrained=True, dropout_p=0.5)

Make a model with Global Max Pooling instead of Global Average Pooling

import torch.nn as nn

model = make_model('inceptionresnetv2', num_classes=10, pretrained=True, pool=nn.AdaptiveMaxPool2d(1))
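What the pool argument changes, in plain-Python terms: global pooling collapses each channel's H x W feature map to a single number, either its mean (global average pooling) or its maximum (what nn.AdaptiveMaxPool2d(1) computes per channel). The hypothetical global_pool helper below illustrates this for one channel without PyTorch; it is a sketch of the idea, not cnn_finetune code.

```python
def global_pool(feature_map, mode="avg"):
    """Collapse one channel's 2-D feature map (a list of rows) to a scalar.

    mode="avg" mimics global average pooling; mode="max" mimics
    nn.AdaptiveMaxPool2d(1) applied to a single channel.
    """
    values = [v for row in feature_map for v in row]
    if not values:
        raise ValueError("feature_map must not be empty")
    if mode == "max":
        return max(values)
    if mode == "avg":
        return sum(values) / len(values)
    raise ValueError(f"unknown mode: {mode!r}")


fmap = [[1.0, 2.0],
        [3.0, 4.0]]
print(global_pool(fmap, "avg"))  # 2.5
print(global_pool(fmap, "max"))  # 4.0
```

Max pooling keeps only the strongest activation per channel, while average pooling reflects the whole map.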

Make a VGG16 model that takes images of size 256x256 pixels

VGG and AlexNet models use fully-connected layers, so you have to additionally pass the input size of images when constructing a new model. This information is needed to determine the input size of fully-connected layers.

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256))
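Why the input size matters, as a hedged sketch: VGG16's convolutional part ends with 512 channels and contains five 2x2, stride-2 max-pooling layers, so the spatial size of the last feature map, and with it the number of inputs to the first fully-connected layer, depends on the image size. The plain-Python function below reproduces that arithmetic under the assumption that the feature map is flattened directly, without global pooling; vgg16_classifier_in_features is a hypothetical name, not part of cnn_finetune.

```python
def vgg16_classifier_in_features(height, width, channels=512, num_pools=5):
    """Number of inputs to the first fully-connected layer of a VGG16 head.

    Assumes each of the `num_pools` max-pooling layers halves the spatial
    size (kernel 2, stride 2, floor division) and that the final feature
    map is flattened directly, without global pooling.
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    if height == 0 or width == 0:
        raise ValueError("input is too small for this many pooling layers")
    return channels * height * width


print(vgg16_classifier_in_features(224, 224))  # 25088 (= 512 * 7 * 7)
print(vgg16_classifier_in_features(256, 256))  # 32768 (= 512 * 8 * 8)
```

A classifier sized for 224x224 inputs therefore has a different weight shape than one sized for 256x256 inputs, which is why the size has to be known when the model is built.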

Make a VGG16 model that takes images of size 256x256 pixels and uses a custom classifier

import torch.nn as nn

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)
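The classifier_factory callable above receives the number of input features and the number of classes. As a hedged variant (the name make_classifier_with_dropout and the rate p=0.5 are illustrative choices, not part of cnn_finetune), the sketch below inserts a Dropout layer between the two linear layers and checks the output shape on a dummy batch. The small in_features value only keeps the demo light; with input_size=(256, 256) the real value for VGG16 would be much larger.

```python
import torch
import torch.nn as nn


def make_classifier_with_dropout(in_features, num_classes, p=0.5):
    # Same layout as make_classifier, plus Dropout before the output layer.
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(p=p),
        nn.Linear(4096, num_classes),
    )


classifier = make_classifier_with_dropout(in_features=512, num_classes=10)
classifier.eval()  # disable dropout for a deterministic shape check
with torch.no_grad():
    out = classifier(torch.zeros(2, 512))
print(tuple(out.shape))  # (2, 10)
```

It would be passed the same way as before, e.g. classifier_factory=make_classifier_with_dropout.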

Show preprocessing that was used to train the original model on ImageNet

>>> model = make_model('resnext101_64x4d', num_classes=10, pretrained=True)
>>> print(model.original_model_info)
ModelInfo(input_space='RGB', input_size=[3, 224, 224], input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
>>> print(model.original_model_info.mean)
[0.485, 0.456, 0.406]
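These values describe the preprocessing the pretrained weights expect: input_range=[0, 1] means 8-bit pixel values are scaled from 0 to 255 down to [0, 1], and each channel is then standardized as (x - mean) / std. The hypothetical normalize_pixel helper below applies that arithmetic to a single RGB pixel in plain Python; in a real pipeline, transforms.ToTensor() followed by transforms.Normalize(mean=..., std=...) performs the same steps on whole images.

```python
def normalize_pixel(rgb, mean, std, max_value=255.0):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return [(c / max_value - m) / s for c, m, s in zip(rgb, mean, std)]


mean = [0.485, 0.456, 0.406]  # model.original_model_info.mean
std = [0.229, 0.224, 0.225]   # model.original_model_info.std

print(normalize_pixel([0, 0, 0], mean, std))        # black pixel
print(normalize_pixel([255, 255, 255], mean, std))  # white pixel
```

A pixel equal to the channel means maps to roughly zero, which is the point of the standardization.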

CIFAR10 Example

See examples/cifar10.py file (requires PyTorch 1.1+).

pytorch-cnn-finetune's People

Contributors

cgnorthcutt, creafz


pytorch-cnn-finetune's Issues

Access to fc layer

I need to access the fc layer of the ResNet model. With the standard ResNet, this is done using model.fc.weight. However, with the models from this repo, I get the following error:

In [33]: model.fc.weight

Traceback (most recent call last):

  File "<ipython-input-33-ca343f58d0f7>", line 1, in <module>
    model.fc

  File "//anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 535, in __getattr__
    type(self).__name__, name))

AttributeError: 'ResNetWrapper' object has no attribute 'fc'

Is there a way around this?
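One possible workaround, sketched under the assumption that the wrapper, like any nn.Module, exposes its submodules through named_modules(): search for the last nn.Linear layer instead of relying on the attribute name fc. The wrapper appears to store its replacement classifier under a different attribute (possibly _classifier; print(model) shows the actual names), so a generic search avoids hard-coding it. find_last_linear is a hypothetical helper, and the toy nn.Sequential stands in for a real model.

```python
import torch.nn as nn


def find_last_linear(model):
    """Return (name, module) for the last nn.Linear registered in `model`.

    named_modules() walks submodules in registration order, so the final
    match is usually the output layer of the classifier on top.
    """
    last = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            last = (name, module)
    if last is None:
        raise ValueError("model contains no nn.Linear layer")
    return last


# Toy stand-in for a wrapped network.
toy = nn.Sequential(
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 2),
)
name, fc = find_last_linear(toy)
print(name, tuple(fc.weight.shape))  # 2 (2, 4)
```

On a real wrapped model, fc.weight then plays the role of model.fc.weight in the standard ResNet.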

Kernel size issue with "inceptionresnetv2"

When I try to run examples/cifar10.py with Inception-ResNet-v2, I get the following error:

RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

I tried to run this:

import argparse
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from cnn_finetune import make_model

parser = argparse.ArgumentParser(description='Inception-Resnet-v2-TRAIN')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--save-model', type=int, default=10, metavar='N',
                    help='number of epochs after which the model will be saved (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-name', type=str, default='resnet50', metavar='M',
                    help='model name (default: resnet50)')
parser.add_argument('--dropout-p', type=float, default=0.2, metavar='D',
                    help='Dropout probability (default: 0.2)')

args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


def train(model, epoch, optimizer, train_loader, criterion=nn.CrossEntropyLoss()):
    total_loss = 0
    total_size = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
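        # loss is the mean over the batch; weight it by the batch size so
        # that total_loss / total_size is the average loss per sample
        total_loss += loss.item() * data.size(0)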
        total_size += data.size(0)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss / total_size))


def main():
    model_name = args.model_name

    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    )

    model = make_model(
        model_name,
        pretrained=True,
        num_classes=len(classes),
        pool=nn.AdaptiveMaxPool2d(1),
        dropout_p=args.dropout_p
    )
    model = model.to(device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=model.original_model_info.mean,
            std=model.original_model_info.std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=2
    )

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Use exponential decay for fine-tuning optimizer
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.975)

    # Train
    for epoch in range(1, args.epochs + 1):
        train(model, epoch, optimizer, train_loader)
        scheduler.step()  # passing the epoch to step() is deprecated since PyTorch 1.1
        if epoch % args.save_model == 0:
            torch.save(model.state_dict(), './checkpoint/' + 'ckpt_' + str(epoch) + '.pth')


if __name__ == '__main__':
    main()
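A likely cause, offered as a hedged reading of the error: CIFAR-10 images are 32x32 pixels, while Inception-ResNet-v2 was trained on 299x299 inputs, so the network's repeated downsampling shrinks the feature map to 1x1 before a 3x3 convolution is reached. The plain-Python sketch below applies the standard output-size formula, floor((n + 2p - k) / s) + 1, to a hypothetical stack of 3x3, stride-2, unpadded layers (not the network's actual layer list) to show how quickly a 32-pixel input runs out of room compared with a 299-pixel one.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one axis."""
    padded = size + 2 * padding
    if padded < kernel:
        raise ValueError(
            f"padded input size {padded} is smaller than kernel size {kernel}"
        )
    return (padded - kernel) // stride + 1


def trace_sizes(size, layers):
    """Return the spatial size after each (kernel, stride, padding) layer."""
    sizes = [size]
    for kernel, stride, padding in layers:
        size = conv_output_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


# Hypothetical stack: five 3x3, stride-2, unpadded layers.
layers = [(3, 2, 0)] * 5
print(trace_sizes(299, layers))  # [299, 149, 74, 36, 17, 8]
try:
    trace_sizes(32, layers)
except ValueError as err:
    print("32x32 input:", err)  # fails at the fifth layer: 1 < 3
```

If this is the cause, upsampling the images before ToTensor(), for example with transforms.Resize(model.original_model_info.input_size[1:]), should avoid the error, at the cost of more computation per image.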

save model

When saving the resnext101_64x4d model, the following error occurred. There is no such problem in saving other models. What is the reason?

Traceback (most recent call last):
  File "resnext101_64x4d_CellData.py", line 251, in
    main()
  File "resnext101_64x4d_CellData.py", line 243, in main
    max_acc = test(model, test_loader,max_acc,epoch_test)
  File "resnext101_64x4d_CellData.py", line 130, in test
    torch.save(model,'./models/%s.pth'%args.model_name)
  File "/home/zlw/.local/lib/python3.5/site-packages/torch/serialization.py", line 260, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/zlw/.local/lib/python3.5/site-packages/torch/serialization.py", line 185, in _with_file_like
    return body(f)
  File "/home/zlw/.local/lib/python3.5/site-packages/torch/serialization.py", line 260, in
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/zlw/.local/lib/python3.5/site-packages/torch/serialization.py", line 332, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <function at 0x7ff02de7b0d0>: attribute lookup on pretrainedmodels.models.resnext_features.resnext101_64x4d_features failed
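A plausible explanation, hedged because only the traceback is available: torch.save(model, ...) pickles the entire module object, and the traceback points at a function inside pretrainedmodels.models.resnext_features.resnext101_64x4d_features that the pickler cannot look up by name, which is typical of lambdas. Saving model.state_dict() instead stores only parameter and buffer tensors and sidesteps the problem. The torch-free sketch below reproduces the distinction with a hypothetical LambdaHolder class: pickling the object fails, while pickling a plain dictionary of its data works.

```python
import pickle


class LambdaHolder:
    """Hypothetical stand-in for a module that keeps a lambda as an attribute."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]
        self.fn = lambda x: x * 2  # lambdas cannot be pickled by reference

    def state_dict(self):
        # Only plain data, analogous to saving model.state_dict().
        return {"weights": list(self.weights)}


def is_picklable(obj):
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


holder = LambdaHolder()
print(is_picklable(holder))               # False: the lambda attribute breaks pickling
print(is_picklable(holder.state_dict()))  # True: plain lists and dicts pickle fine
```

With the real model, the equivalent is torch.save(model.state_dict(), path), followed later by rebuilding the model with make_model and calling model.load_state_dict(torch.load(path)).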

How can I load the trained model?

Hello,
I fine-tuned the pretrained inception_v4 model with 3 classes using model = make_model('inception_v4', 3, pretrained=True) and saved the trained model as 'model.pth' with torch.save(model.state_dict(), 'model.pth').
But when I restore the model with
model = make_model('inception_v4', 3, pretrained=False)
model.load_state_dict(torch.load('model.pth'))
it fails with the following error:
Unexpected key(s) in state_dict: "model._features.0.conv.weight", "model._features.0.bn.weight", "model._features.0.bn.bias", "model._features.0.bn.running_mean", "model._features.0.bn.running_var", "model._features.0.bn.num_batches_tracked", "model._features.1.conv.weight", "model._features.1.bn.weight", "model._features.1.bn.bias", "model._features.1.bn.running_mean", "model._features.1.bn.running_var", "model._features.1.bn.num_batches_tracked", "model._features.2.conv.weight", "model._features.2.bn.weight", "model._features.2.bn.bias", "model._features.2.bn.running_mean", "model._features.2.bn.running_var", "model._features.2.bn.num_batches_tracked", "model._features.3.conv.conv.weight", "model._features.3.conv.bn.weight", "model._features.3.conv.bn.bias", "model._features.3.conv.bn.running_mean", "model._features.3.conv.bn.running_var", "model._features.3.conv.bn.num_batches_tracked", "model._features.4.branch0.0.conv.weight", "model._features.4.branch0.0.bn.weight", "model._features.4.branch0.0.bn.bias", "model._features.4.branch0.0.bn.running_mean", "model._features.4.branch0.0.bn.running_var", "model._features.4.branch0.0.bn.num_batches_tracked", "model._features.4.branch0.1.conv.weight", "model._features.4.branch0.1.bn.weight", "model._features.4.branch0.1.bn.bias", "model._features.4.branch0.1.bn.running_mean", "model._features.4.branch0.1.bn.running_var", "model._features.4.branch0.1.bn.num_batches_tracked", "model._features.4.branch1.0.conv.weight", "model._features.4.branch1.0.bn.weight", "model._features.4.branch1.0.bn.bias", "model._features.4.branch1.0.bn.running_mean", "model._features.4.branch1.0.bn.running_var", "model._features.4.branch1.0.bn.num_batches_tracked", "model._features.4.branch1.1.conv.weight", "model._features.4.branch1.1.bn.weight", "model._features.4.branch1.1.bn.bias", "model._features.4.branch1.1.bn.running_mean", "model._features.4.branch1.1.bn.running_var", "model._features.4.branch1.1.bn.num_batches_tracked", 
"model._features.4.branch1.2.conv.weight", "model._features.4.branch1.2.bn.weight", "model._features.4.branch1.2.bn.bias", "model._features.4.branch1.2.bn.running_mean", "model._features.4.branch1.2.bn.running_var", "model._features.4.branch1.2.bn.num_batches_tracked", "model._features.4.branch1.3.conv.weight", "model._features.4.branch1.3.bn.weight", "model._features.4.branch1.3.bn.bias", "model._features.4.branch1.3.bn.running_mean", "model._features.4.branch1.3.bn.running_var", "model._features.4.branch1.3.bn.num_batches_tracked", "model._features.5.conv.conv.weight", "model._features.5.conv.bn.weight", "model._features.5.conv.bn.bias", "model._features.5.conv.bn.running_mean", "model._features.5.conv.bn.running_var", "model._features.5.conv.bn.num_batches_tracked", "model._features.6.branch0.conv.weight", "model._features.6.branch0.bn.weight", "model._features.6.branch0.bn.bias", "model._features.6.branch0.bn.running_mean", "model._features.6.branch0.bn.running_var", "model._features.6.branch0.bn.num_batches_tracked", "model._features.6.branch1.0.conv.weight", "model._features.6.branch1.0.bn.weight", "model._features.6.branch1.0.bn.bias", "model._features.6.branch1.0.bn.running_mean", "model._features.6.branch1.0.bn.running_var", "model._features.6.branch1.0.bn.num_batches_tracked", "model._features.6.branch1.1.conv.weight", "model._features.6.branch1.1.bn.weight", "model._features.6.branch1.1.bn.bias", "model._features.6.branch1.1.bn.running_mean", "model._features.6.branch1.1.bn.running_var", "model._features.6.branch1.1.bn.num_batches_tracked", "model._features.6.branch2.0.conv.weight", "model._features.6.branch2.0.bn.weight", "model._features.6.branch2.0.bn.bias", "model._features.6.branch2.0.bn.running_mean", "model._features.6.branch2.0.bn.running_var", "model._features.6.branch2.0.bn.num_batches_tracked", "model._features.6.branch2.1.conv.weight", "model._features.6.branch2.1.bn.weight", "model._features.6.branch2.1.bn.bias", 
"model._features.6.branch2.1.bn.running_mean", "model._features.6.branch2.1.bn.running_var", "model._features.6.branch2.1.bn.num_batches_tracked", "model._features.6.branch2.2.conv.weight", "model._features.6.branch2.2.bn.weight", "model._features.6.branch2.2.bn.bias", "model._features.6.branch2.2.bn.running_mean", "model._features.6.branch2.2.bn.running_var", "model._features.6.branch2.2.bn.num_batches_tracked", "model._features.6.branch3.1.conv.weight", "model._features.6.branch3.1.bn.weight", "model._features.6.branch3.1.bn.bias", "model._features.6.branch3.1.bn.running_mean", "model._features.6.branch3.1.bn.running_var", "model._features.6.branch3.1.bn.num_batches_tracked", "model._features.7.branch0.conv.weight", "model._features.7.branch0.bn.weight", "model._features.7.branch0.bn.bias", "model._features.7.branch0.bn.running_mean", "model._features.7.branch0.bn.running_var", "model._features.7.branch0.bn.num_batches_tracked", "model._features.7.branch1.0.conv.weight", "model._features.7.branch1.0.bn.weight", "model._features.7.branch1.0.bn.bias", "model._features.7.branch1.0.bn.running_mean", "model._features.7.branch1.0.bn.running_var", "model._features.7.branch1.0.bn.num_batches_tracked", "model._features.7.branch1.1.conv.weight", "model._features.7.branch1.1.bn.weight", "model._features.7.branch1.1.bn.bias", "model._features.7.branch1.1.bn.running_mean", "model._features.7.branch1.1.bn.running_var", "model._features.7.branch1.1.bn.num_batches_tracked", "model._features.7.branch2.0.conv.weight", "model._features.7.branch2.0.bn.weight", "model._features.7.branch2.0.bn.bias", "model._features.7.branch2.0.bn.running_mean", "model._features.7.branch2.0.bn.running_var", "model._features.7.branch2.0.bn.num_batches_tracked", "model._features.7.branch2.1.conv.weight", "model._features.7.branch2.1.bn.weight", "model._features.7.branch2.1.bn.bias", "model._features.7.branch2.1.bn.running_mean", "model._features.7.branch2.1.bn.running_var", 
"model._features.7.branch2.1.bn.num_batches_tracked", "model._features.7.branch2.2.conv.weight", "model._features.7.branch2.2.bn.weight", "model._features.7.branch2.2.bn.bias", "model._features.7.branch2.2.bn.running_mean", "model._features.7.branch2.2.bn.running_var", "model._features.7.branch2.2.bn.num_batches_tracked", "model._features.7.branch3.1.conv.weight", "model._features.7.branch3.1.bn.weight", "model._features.7.branch3.1.bn.bias", "model._features.7.branch3.1.bn.running_mean", "model._features.7.branch3.1.bn.running_var", "model._features.7.branch3.1.bn.num_batches_tracked", "model._features.8.branch0.conv.weight", "model._features.8.branch0.bn.weight", "model._features.8.branch0.bn.bias", "model._features.8.branch0.bn.running_mean", "model._features.8.branch0.bn.running_var", "model._features.8.branch0.bn.num_batches_tracked", "model._features.8.branch1.0.conv.weight", "model._features.8.branch1.0.bn.weight", "model._features.8.branch1.0.bn.bias", "model._features.8.branch1.0.bn.running_mean", "model._features.8.branch1.0.bn.running_var", "model._features.8.branch1.0.bn.num_batches_tracked", "model._features.8.branch1.1.conv.weight", "model._features.8.branch1.1.bn.weight", "model._features.8.branch1.1.bn.bias", "model._features.8.branch1.1.bn.running_mean", "model._features.8.branch1.1.bn.running_var", "model._features.8.branch1.1.bn.num_batches_tracked", "model._features.8.branch2.0.conv.weight", "model._features.8.branch2.0.bn.weight", "model._features.8.branch2.0.bn.bias", "model._features.8.branch2.0.bn.running_mean", "model._features.8.branch2.0.bn.running_var", "model._features.8.branch2.0.bn.num_batches_tracked", "model._features.8.branch2.1.conv.weight", "model._features.8.branch2.1.bn.weight", "model._features.8.branch2.1.bn.bias", "model._features.8.branch2.1.bn.running_mean", "model._features.8.branch2.1.bn.running_var", "model._features.8.branch2.1.bn.num_batches_tracked", "model._features.8.branch2.2.conv.weight", 
"model._features.8.branch2.2.bn.weight", "model._features.8.branch2.2.bn.bias", "model._features.8.branch2.2.bn.running_mean", "model._features.8.branch2.2.bn.running_var", "model._features.8.branch2.2.bn.num_batches_tracked", "model._features.8.branch3.1.conv.weight", "model._features.8.branch3.1.bn.weight", "model._features.8.branch3.1.bn.bias", "model._features.8.branch3.1.bn.running_mean", "model._features.8.branch3.1.bn.running_var", "model._features.8.branch3.1.bn.num_batches_tracked", "model._features.9.branch0.conv.weight", "model._features.9.branch0.bn.weight", "model._features.9.branch0.bn.bias", "model._features.9.branch0.bn.running_mean", "model._features.9.branch0.bn.running_var", "model._features.9.branch0.bn.num_batches_tracked", "model._features.9.branch1.0.conv.weight", "model._features.9.branch1.0.bn.weight", "model._features.9.branch1.0.bn.bias", "model._features.9.branch1.0.bn.running_mean", "model._features.9.branch1.0.bn.running_var", "model._features.9.branch1.0.bn.num_batches_tracked", "model._features.9.branch1.1.conv.weight", "model._features.9.branch1.1.bn.weight", "model._features.9.branch1.1.bn.bias", "model._features.9.branch1.1.bn.running_mean", "model._features.9.branch1.1.bn.running_var", "model._features.9.branch1.1.bn.num_batches_tracked", "model._features.9.branch2.0.conv.weight", "model._features.9.branch2.0.bn.weight", "model._features.9.branch2.0.bn.bias", "model._features.9.branch2.0.bn.running_mean", "model._features.9.branch2.0.bn.running_var", "model._features.9.branch2.0.bn.num_batches_tracked", "model._features.9.branch2.1.conv.weight", "model._features.9.branch2.1.bn.weight", "model._features.9.branch2.1.bn.bias", "model._features.9.branch2.1.bn.running_mean", "model._features.9.branch2.1.bn.running_var", "model._features.9.branch2.1.bn.num_batches_tracked", "model._features.9.branch2.2.conv.weight", "model._features.9.branch2.2.bn.weight", "model._features.9.branch2.2.bn.bias", 
"model._features.9.branch2.2.bn.running_mean", "model._features.9.branch2.2.bn.running_var", "model._features.9.branch2.2.bn.num_batches_tracked", "model._features.9.branch3.1.conv.weight", "model._features.9.branch3.1.bn.weight", "model._features.9.branch3.1.bn.bias", "model._features.9.branch3.1.bn.running_mean", "model._features.9.branch3.1.bn.running_var", "model._features.9.branch3.1.bn.num_batches_tracked", "model._features.10.branch0.conv.weight", "model._features.10.branch0.bn.weight", "model._features.10.branch0.bn.bias", "model._features.10.branch0.bn.running_mean", "model._features.10.branch0.bn.running_var", "model._features.10.branch0.bn.num_batches_tracked", "model._features.10.branch1.0.conv.weight", "model._features.10.branch1.0.bn.weight", "model._features.10.branch1.0.bn.bias", "model._features.10.branch1.0.bn.running_mean", "model._features.10.branch1.0.bn.running_var", "model._features.10.branch1.0.bn.num_batches_tracked", "model._features.10.branch1.1.conv.weight", "model._features.10.branch1.1.bn.weight", "model._features.10.branch1.1.bn.bias", "model._features.10.branch1.1.bn.running_mean", "model._features.10.branch1.1.bn.running_var", "model._features.10.branch1.1.bn.num_batches_tracked", "model._features.10.branch1.2.conv.weight", "model._features.10.branch1.2.bn.weight", "model._features.10.branch1.2.bn.bias", "model._features.10.branch1.2.bn.running_mean", "model._features.10.branch1.2.bn.running_var", "model._features.10.branch1.2.bn.num_batches_tracked", "model._features.11.branch0.conv.weight", "model._features.11.branch0.bn.weight", "model._features.11.branch0.bn.bias", "model._features.11.branch0.bn.running_mean", "model._features.11.branch0.bn.running_var", "model._features.11.branch0.bn.num_batches_tracked", "model._features.11.branch1.0.conv.weight", "model._features.11.branch1.0.bn.weight", "model._features.11.branch1.0.bn.bias", "model._features.11.branch1.0.bn.running_mean", "model._features.11.branch1.0.bn.running_var", 
"model._features.11.branch1.0.bn.num_batches_tracked", "model._features.11.branch1.1.conv.weight", "model._features.11.branch1.1.bn.weight", "model._features.11.branch1.1.bn.bias", "model._features.11.branch1.1.bn.running_mean", "model._features.11.branch1.1.bn.running_var", "model._features.11.branch1.1.bn.num_batches_tracked", "model._features.11.branch1.2.conv.weight", "model._features.11.branch1.2.bn.weight", "model._features.11.branch1.2.bn.bias", "model._features.11.branch1.2.bn.running_mean", "model._features.11.branch1.2.bn.running_var", "model._features.11.branch1.2.bn.num_batches_tracked", "model._features.11.branch2.0.conv.weight", "model._features.11.branch2.0.bn.weight", "model._features.11.branch2.0.bn.bias", "model._features.11.branch2.0.bn.running_mean", "model._features.11.branch2.0.bn.running_var", "model._features.11.branch2.0.bn.num_batches_tracked", "model._features.11.branch2.1.conv.weight", "model._features.11.branch2.1.bn.weight", "model._features.11.branch2.1.bn.bias", "model._features.11.branch2.1.bn.running_mean", "model._features.11.branch2.1.bn.running_var", "model._features.11.branch2.1.bn.num_batches_tracked", "model._features.11.branch2.2.conv.weight", "model._features.11.branch2.2.bn.weight", "model._features.11.branch2.2.bn.bias", "model._features.11.branch2.2.bn.running_mean", "model._features.11.branch2.2.bn.running_var", "model._features.11.branch2.2.bn.num_batches_tracked", "model._features.11.branch2.3.conv.weight", "model._features.11.branch2.3.bn.weight", "model._features.11.branch2.3.bn.bias", "model._features.11.branch2.3.bn.running_mean", "model._features.11.branch2.3.bn.running_var", "model._features.11.branch2.3.bn.num_batches_tracked", "model._features.11.branch2.4.conv.weight", "model._features.11.branch2.4.bn.weight", "model._features.11.branch2.4.bn.bias", "model._features.11.branch2.4.bn.running_mean", "model._features.11.branch2.4.bn.running_var", "model._features.11.branch2.4.bn.num_batches_tracked", 
"model._features.11.branch3.1.conv.weight", "model._features.11.branch3.1.bn.weight", "model._features.11.branch3.1.bn.bias", "model._features.11.branch3.1.bn.running_mean", "model._features.11.branch3.1.bn.running_var", "model._features.11.branch3.1.bn.num_batches_tracked", "model._features.12.branch0.conv.weight", "model._features.12.branch0.bn.weight", "model._features.12.branch0.bn.bias", "model._features.12.branch0.bn.running_mean", "model._features.12.branch0.bn.running_var", "model._features.12.branch0.bn.num_batches_tracked", "model._features.12.branch1.0.conv.weight", "model._features.12.branch1.0.bn.weight", "model._features.12.branch1.0.bn.bias", "model._features.12.branch1.0.bn.running_mean", "model._features.12.branch1.0.bn.running_var", "model._features.12.branch1.0.bn.num_batches_tracked", "model._features.12.branch1.1.conv.weight", "model._features.12.branch1.1.bn.weight", "model._features.12.branch1.1.bn.bias", "model._features.12.branch1.1.bn.running_mean", "model._features.12.branch1.1.bn.running_var", "model._features.12.branch1.1.bn.num_batches_tracked", "model._features.12.branch1.2.conv.weight", "model._features.12.branch1.2.bn.weight", "model._features.12.branch1.2.bn.bias", "model._features.12.branch1.2.bn.running_mean", "model._features.12.branch1.2.bn.running_var", "model._features.12.branch1.2.bn.num_batches_tracked", "model._features.12.branch2.0.conv.weight", "model._features.12.branch2.0.bn.weight", "model._features.12.branch2.0.bn.bias", "model._features.12.branch2.0.bn.running_mean", "model._features.12.branch2.0.bn.running_var", "model._features.12.branch2.0.bn.num_batches_tracked", "model._features.12.branch2.1.conv.weight", "model._features.12.branch2.1.bn.weight", "model._features.12.branch2.1.bn.bias", "model._features.12.branch2.1.bn.running_mean", "model._features.12.branch2.1.bn.running_var", "model._features.12.branch2.1.bn.num_batches_tracked", "model._features.12.branch2.2.conv.weight", 
"model._features.12.branch2.2.bn.weight", "model._features.12.branch2.2.bn.bias", "model._features.12.branch2.2.bn.running_mean", "model._features.12.branch2.2.bn.running_var", "model._features.12.branch2.2.bn.num_batches_tracked", "model._features.12.branch2.3.conv.weight", "model._features.12.branch2.3.bn.weight", "model._features.12.branch2.3.bn.bias", "model._features.12.branch2.3.bn.running_mean", "model._features.12.branch2.3.bn.running_var", "model._features.12.branch2.3.bn.num_batches_tracked", "model._features.12.branch2.4.conv.weight", "model._features.12.branch2.4.bn.weight", "model._features.12.branch2.4.bn.bias", "model._features.12.branch2.4.bn.running_mean", "model._features.12.branch2.4.bn.running_var", "model._features.12.branch2.4.bn.num_batches_tracked", "model._features.12.branch3.1.conv.weight", "model._features.12.branch3.1.bn.weight", "model._features.12.branch3.1.bn.bias", "model._features.12.branch3.1.bn.running_mean", "model._features.12.branch3.1.bn.running_var", "model._features.12.branch3.1.bn.num_batches_tracked", "model._features.13.branch0.conv.weight", "model._features.13.branch0.bn.weight", "model._features.13.branch0.bn.bias", "model._features.13.branch0.bn.running_mean", "model._features.13.branch0.bn.running_var", "model._features.13.branch0.bn.num_batches_tracked", "model._features.13.branch1.0.conv.weight", "model._features.13.branch1.0.bn.weight", "model._features.13.branch1.0.bn.bias", "model._features.13.branch1.0.bn.running_mean", "model._features.13.branch1.0.bn.running_var", "model._features.13.branch1.0.bn.num_batches_tracked", "model._features.13.branch1.1.conv.weight", "model._features.13.branch1.1.bn.weight", "model._features.13.branch1.1.bn.bias", "model._features.13.branch1.1.bn.running_mean", "model._features.13.branch1.1.bn.running_var", "model._features.13.branch1.1.bn.num_batches_tracked", "model._features.13.branch1.2.conv.weight", "model._features.13.branch1.2.bn.weight", 
"model._features.13.branch1.2.bn.bias", "model._features.13.branch1.2.bn.running_mean", "model._features.13.branch1.2.bn.running_var", "model._features.13.branch1.2.bn.num_batches_tracked", "model._features.13.branch2.0.conv.weight", "model._features.13.branch2.0.bn.weight", "model._features.13.branch2.0.bn.bias", "model._features.13.branch2.0.bn.running_mean", "model._features.13.branch2.0.bn.running_var", "model._features.13.branch2.0.bn.num_batches_tracked", "model._features.13.branch2.1.conv.weight", "model._features.13.branch2.1.bn.weight", "model._features.13.branch2.1.bn.bias", "model._features.13.branch2.1.bn.running_mean", "model._features.13.branch2.1.bn.running_var", "model._features.13.branch2.1.bn.num_batches_tracked", "model._features.13.branch2.2.conv.weight", "model._features.13.branch2.2.bn.weight", "model._features.13.branch2.2.bn.bias", "model._features.13.branch2.2.bn.running_mean", "model._features.13.branch2.2.bn.running_var", "model._features.13.branch2.2.bn.num_batches_tracked", "model._features.13.branch2.3.conv.weight", "model._features.13.branch2.3.bn.weight", "model._features.13.branch2.3.bn.bias", "model._features.13.branch2.3.bn.running_mean", "model._features.13.branch2.3.bn.running_var", "model._features.13.branch2.3.bn.num_batches_tracked", "model._features.13.branch2.4.conv.weight", "model._features.13.branch2.4.bn.weight", "model._features.13.branch2.4.bn.bias", "model._features.13.branch2.4.bn.running_mean", "model._features.13.branch2.4.bn.running_var", "model._features.13.branch2.4.bn.num_batches_tracked", "model._features.13.branch3.1.conv.weight", "model._features.13.branch3.1.bn.weight", "model._features.13.branch3.1.bn.bias", "model._features.13.branch3.1.bn.running_mean", "model._features.13.branch3.1.bn.running_var", "model._features.13.branch3.1.bn.num_batches_tracked", "model._features.14.branch0.conv.weight", "model._features.14.branch0.bn.weight", "model._features.14.branch0.bn.bias", 
"model._features.14.branch0.bn.running_mean", "model._features.14.branch0.bn.running_var", "model._features.14.branch0.bn.num_batches_tracked", "model._features.14.branch1.0.conv.weight", "model._features.14.branch1.0.bn.weight", "model._features.14.branch1.0.bn.bias", "model._features.14.branch1.0.bn.running_mean", "model._features.14.branch1.0.bn.running_var", "model._features.14.branch1.0.bn.num_batches_tracked", "model._features.14.branch1.1.conv.weight", "model._features.14.branch1.1.bn.weight", "model._features.14.branch1.1.bn.bias", "model._features.14.branch1.1.bn.running_mean", "model._features.14.branch1.1.bn.running_var", "model._features.14.branch1.1.bn.num_batches_tracked", "model._features.14.branch1.2.conv.weight", "model._features.14.branch1.2.bn.weight", "model._features.14.branch1.2.bn.bias", "model._features.14.branch1.2.bn.running_mean", "model._features.14.branch1.2.bn.running_var", "model._features.14.branch1.2.bn.num_batches_tracked", "model._features.14.branch2.0.conv.weight", "model._features.14.branch2.0.bn.weight", "model._features.14.branch2.0.bn.bias", "model._features.14.branch2.0.bn.running_mean", "model._features.14.branch2.0.bn.running_var", "model._features.14.branch2.0.bn.num_batches_tracked", "model._features.14.branch2.1.conv.weight", "model._features.14.branch2.1.bn.weight", "model._features.14.branch2.1.bn.bias", "model._features.14.branch2.1.bn.running_mean", "model._features.14.branch2.1.bn.running_var", "model._features.14.branch2.1.bn.num_batches_tracked", "model._features.14.branch2.2.conv.weight", "model._features.14.branch2.2.bn.weight", "model._features.14.branch2.2.bn.bias", "model._features.14.branch2.2.bn.running_mean", "model._features.14.branch2.2.bn.running_var", "model._features.14.branch2.2.bn.num_batches_tracked", "model._features.14.branch2.3.conv.weight", "model._features.14.branch2.3.bn.weight", "model._features.14.branch2.3.bn.bias", "model._features.14.branch2.3.bn.running_mean", 
"model._features.14.branch2.3.bn.running_var", "model._features.14.branch2.3.bn.num_batches_tracked", "model._features.14.branch2.4.conv.weight", "model._features.14.branch2.4.bn.weight", "model._features.14.branch2.4.bn.bias", "model._features.14.branch2.4.bn.running_mean", "model._features.14.branch2.4.bn.running_var", "model._features.14.branch2.4.bn.num_batches_tracked", "model._features.14.branch3.1.conv.weight", "model._features.14.branch3.1.bn.weight", "model._features.14.branch3.1.bn.bias", "model._features.14.branch3.1.bn.running_mean", "model._features.14.branch3.1.bn.running_var", "model._features.14.branch3.1.bn.num_batches_tracked", "model._features.15.branch0.conv.weight", "model._features.15.branch0.bn.weight", "model._features.15.branch0.bn.bias", "model._features.15.branch0.bn.running_mean", "model._features.15.branch0.bn.running_var", "model._features.15.branch0.bn.num_batches_tracked", "model._features.15.branch1.0.conv.weight", "model._features.15.branch1.0.bn.weight", "model._features.15.branch1.0.bn.bias", "model._features.15.branch1.0.bn.running_mean", "model._features.15.branch1.0.bn.running_var", "model._features.15.branch1.0.bn.num_batches_tracked", "model._features.15.branch1.1.conv.weight", "model._features.15.branch1.1.bn.weight", "model._features.15.branch1.1.bn.bias", "model._features.15.branch1.1.bn.running_mean", "model._features.15.branch1.1.bn.running_var", "model._features.15.branch1.1.bn.num_batches_tracked", "model._features.15.branch1.2.conv.weight", "model._features.15.branch1.2.bn.weight", "model._features.15.branch1.2.bn.bias", "model._features.15.branch1.2.bn.running_mean", "model._features.15.branch1.2.bn.running_var", "model._features.15.branch1.2.bn.num_batches_tracked", "model._features.15.branch2.0.conv.weight", "model._features.15.branch2.0.bn.weight", "model._features.15.branch2.0.bn.bias", "model._features.15.branch2.0.bn.running_mean", "model._features.15.branch2.0.bn.running_var", 
"model._features.15.branch2.0.bn.num_batches_tracked", "model._features.15.branch2.1.conv.weight", "model._features.15.branch2.1.bn.weight", "model._features.15.branch2.1.bn.bias", "model._features.15.branch2.1.bn.running_mean", "model._features.15.branch2.1.bn.running_var", "model._features.15.branch2.1.bn.num_batches_tracked", "model._features.15.branch2.2.conv.weight", "model._features.15.branch2.2.bn.weight", "model._features.15.branch2.2.bn.bias", "model._features.15.branch2.2.bn.running_mean", "model._features.15.branch2.2.bn.running_var", "model._features.15.branch2.2.bn.num_batches_tracked", "model._features.15.branch2.3.conv.weight", "model._features.15.branch2.3.bn.weight", "model._features.15.branch2.3.bn.bias", "model._features.15.branch2.3.bn.running_mean", "model._features.15.branch2.3.bn.running_var", "model._features.15.branch2.3.bn.num_batches_tracked", "model._features.15.branch2.4.conv.weight", "model._features.15.branch2.4.bn.weight", "model._features.15.branch2.4.bn.bias", "model._features.15.branch2.4.bn.running_mean", "model._features.15.branch2.4.bn.running_var", "model._features.15.branch2.4.bn.num_batches_tracked", "model._features.15.branch3.1.conv.weight", "model._features.15.branch3.1.bn.weight", "model._features.15.branch3.1.bn.bias", "model._features.15.branch3.1.bn.running_mean", "model._features.15.branch3.1.bn.running_var", "model._features.15.branch3.1.bn.num_batches_tracked", "model._features.16.branch0.conv.weight", "model._features.16.branch0.bn.weight", "model._features.16.branch0.bn.bias", "model._features.16.branch0.bn.running_mean", "model._features.16.branch0.bn.running_var", "model._features.16.branch0.bn.num_batches_tracked", "model._features.16.branch1.0.conv.weight", "model._features.16.branch1.0.bn.weight", "model._features.16.branch1.0.bn.bias", "model._features.16.branch1.0.bn.running_mean", "model._features.16.branch1.0.bn.running_var", "model._features.16.branch1.0.bn.num_batches_tracked", 
"model._features.16.branch1.1.conv.weight", "model._features.16.branch1.1.bn.weight", "model._features.16.branch1.1.bn.bias", "model._features.16.branch1.1.bn.running_mean", "model._features.16.branch1.1.bn.running_var", "model._features.16.branch1.1.bn.num_batches_tracked", "model._features.16.branch1.2.conv.weight", "model._features.16.branch1.2.bn.weight", "model._features.16.branch1.2.bn.bias", "model._features.16.branch1.2.bn.running_mean", "model._features.16.branch1.2.bn.running_var", "model._features.16.branch1.2.bn.num_batches_tracked", "model._features.16.branch2.0.conv.weight", "model._features.16.branch2.0.bn.weight", "model._features.16.branch2.0.bn.bias", "model._features.16.branch2.0.bn.running_mean", "model._features.16.branch2.0.bn.running_var", "model._features.16.branch2.0.bn.num_batches_tracked", "model._features.16.branch2.1.conv.weight", "model._features.16.branch2.1.bn.weight", "model._features.16.branch2.1.bn.bias", "model._features.16.branch2.1.bn.running_mean", "model._features.16.branch2.1.bn.running_var", "model._features.16.branch2.1.bn.num_batches_tracked", "model._features.16.branch2.2.conv.weight", "model._features.16.branch2.2.bn.weight", "model._features.16.branch2.2.bn.bias", "model._features.16.branch2.2.bn.running_mean", "model._features.16.branch2.2.bn.running_var", "model._features.16.branch2.2.bn.num_batches_tracked", "model._features.16.branch2.3.conv.weight", "model._features.16.branch2.3.bn.weight", "model._features.16.branch2.3.bn.bias", "model._features.16.branch2.3.bn.running_mean", "model._features.16.branch2.3.bn.running_var", "model._features.16.branch2.3.bn.num_batches_tracked", "model._features.16.branch2.4.conv.weight", "model._features.16.branch2.4.bn.weight", "model._features.16.branch2.4.bn.bias", "model._features.16.branch2.4.bn.running_mean", "model._features.16.branch2.4.bn.running_var", "model._features.16.branch2.4.bn.num_batches_tracked", "model._features.16.branch3.1.conv.weight", 
"model._features.16.branch3.1.bn.weight", "model._features.16.branch3.1.bn.bias", "model._features.16.branch3.1.bn.running_mean", "model._features.16.branch3.1.bn.running_var", "model._features.16.branch3.1.bn.num_batches_tracked", "model._features.17.branch0.conv.weight", "model._features.17.branch0.bn.weight", "model._features.17.branch0.bn.bias", "model._features.17.branch0.bn.running_mean", "model._features.17.branch0.bn.running_var", "model._features.17.branch0.bn.num_batches_tracked", "model._features.17.branch1.0.conv.weight", "model._features.17.branch1.0.bn.weight", "model._features.17.branch1.0.bn.bias", "model._features.17.branch1.0.bn.running_mean", "model._features.17.branch1.0.bn.running_var", "model._features.17.branch1.0.bn.num_batches_tracked", "model._features.17.branch1.1.conv.weight", "model._features.17.branch1.1.bn.weight", "model._features.17.branch1.1.bn.bias", "model._features.17.branch1.1.bn.running_mean", "model._features.17.branch1.1.bn.running_var", "model._features.17.branch1.1.bn.num_batches_tracked", "model._features.17.branch1.2.conv.weight", "model._features.17.branch1.2.bn.weight", "model._features.17.branch1.2.bn.bias", "model._features.17.branch1.2.bn.running_mean", "model._features.17.branch1.2.bn.running_var", "model._features.17.branch1.2.bn.num_batches_tracked", "model._features.17.branch2.0.conv.weight", "model._features.17.branch2.0.bn.weight", "model._features.17.branch2.0.bn.bias", "model._features.17.branch2.0.bn.running_mean", "model._features.17.branch2.0.bn.running_var", "model._features.17.branch2.0.bn.num_batches_tracked", "model._features.17.branch2.1.conv.weight", "model._features.17.branch2.1.bn.weight", "model._features.17.branch2.1.bn.bias", "model._features.17.branch2.1.bn.running_mean", "model._features.17.branch2.1.bn.running_var", "model._features.17.branch2.1.bn.num_batches_tracked", "model._features.17.branch2.2.conv.weight", "model._features.17.branch2.2.bn.weight", 
"model._features.17.branch2.2.bn.bias", "model._features.17.branch2.2.bn.running_mean", "model._features.17.branch2.2.bn.running_var", "model._features.17.branch2.2.bn.num_batches_tracked", "model._features.17.branch2.3.conv.weight", "model._features.17.branch2.3.bn.weight", "model._features.17.branch2.3.bn.bias", "model._features.17.branch2.3.bn.running_mean", "model._features.17.branch2.3.bn.running_var", "model._features.17.branch2.3.bn.num_batches_tracked", "model._features.17.branch2.4.conv.weight", "model._features.17.branch2.4.bn.weight", "model._features.17.branch2.4.bn.bias", "model._features.17.branch2.4.bn.running_mean", "model._features.17.branch2.4.bn.running_var", "model._features.17.branch2.4.bn.num_batches_tracked", "model._features.17.branch3.1.conv.weight", "model._features.17.branch3.1.bn.weight", "model._features.17.branch3.1.bn.bias", "model._features.17.branch3.1.bn.running_mean", "model._features.17.branch3.1.bn.running_var", "model._features.17.branch3.1.bn.num_batches_tracked", "model._features.18.branch0.0.conv.weight", "model._features.18.branch0.0.bn.weight", "model._features.18.branch0.0.bn.bias", "model._features.18.branch0.0.bn.running_mean", "model._features.18.branch0.0.bn.running_var", "model._features.18.branch0.0.bn.num_batches_tracked", "model._features.18.branch0.1.conv.weight", "model._features.18.branch0.1.bn.weight", "model._features.18.branch0.1.bn.bias", "model._features.18.branch0.1.bn.running_mean", "model._features.18.branch0.1.bn.running_var", "model._features.18.branch0.1.bn.num_batches_tracked", "model._features.18.branch1.0.conv.weight", "model._features.18.branch1.0.bn.weight", "model._features.18.branch1.0.bn.bias", "model._features.18.branch1.0.bn.running_mean", "model._features.18.branch1.0.bn.running_var", "model._features.18.branch1.0.bn.num_batches_tracked", "model._features.18.branch1.1.conv.weight", "model._features.18.branch1.1.bn.weight", "model._features.18.branch1.1.bn.bias", 
"model._features.18.branch1.1.bn.running_mean", "model._features.18.branch1.1.bn.running_var", "model._features.18.branch1.1.bn.num_batches_tracked", "model._features.18.branch1.2.conv.weight", "model._features.18.branch1.2.bn.weight", "model._features.18.branch1.2.bn.bias", "model._features.18.branch1.2.bn.running_mean", "model._features.18.branch1.2.bn.running_var", "model._features.18.branch1.2.bn.num_batches_tracked", "model._features.18.branch1.3.conv.weight", "model._features.18.branch1.3.bn.weight", "model._features.18.branch1.3.bn.bias", "model._features.18.branch1.3.bn.running_mean", "model._features.18.branch1.3.bn.running_var", "model._features.18.branch1.3.bn.num_batches_tracked", "model._features.19.branch0.conv.weight", "model._features.19.branch0.bn.weight", "model._features.19.branch0.bn.bias", "model._features.19.branch0.bn.running_mean", "model._features.19.branch0.bn.running_var", "model._features.19.branch0.bn.num_batches_tracked", "model._features.19.branch1_0.conv.weight", "model._features.19.branch1_0.bn.weight", "model._features.19.branch1_0.bn.bias", "model._features.19.branch1_0.bn.running_mean", "model._features.19.branch1_0.bn.running_var", "model._features.19.branch1_0.bn.num_batches_tracked", "model._features.19.branch1_1a.conv.weight", "model._features.19.branch1_1a.bn.weight", "model._features.19.branch1_1a.bn.bias", "model._features.19.branch1_1a.bn.running_mean", "model._features.19.branch1_1a.bn.running_var", "model._features.19.branch1_1a.bn.num_batches_tracked", "model._features.19.branch1_1b.conv.weight", "model._features.19.branch1_1b.bn.weight", "model._features.19.branch1_1b.bn.bias", "model._features.19.branch1_1b.bn.running_mean", "model._features.19.branch1_1b.bn.running_var", "model._features.19.branch1_1b.bn.num_batches_tracked", "model._features.19.branch2_0.conv.weight", "model._features.19.branch2_0.bn.weight", "model._features.19.branch2_0.bn.bias", "model._features.19.branch2_0.bn.running_mean", 
"model._features.19.branch2_0.bn.running_var", "model._features.19.branch2_0.bn.num_batches_tracked", "model._features.19.branch2_1.conv.weight", "model._features.19.branch2_1.bn.weight", "model._features.19.branch2_1.bn.bias", "model._features.19.branch2_1.bn.running_mean", "model._features.19.branch2_1.bn.running_var", "model._features.19.branch2_1.bn.num_batches_tracked", "model._features.19.branch2_2.conv.weight", "model._features.19.branch2_2.bn.weight", "model._features.19.branch2_2.bn.bias", "model._features.19.branch2_2.bn.running_mean", "model._features.19.branch2_2.bn.running_var", "model._features.19.branch2_2.bn.num_batches_tracked", "model._features.19.branch2_3a.conv.weight", "model._features.19.branch2_3a.bn.weight", "model._features.19.branch2_3a.bn.bias", "model._features.19.branch2_3a.bn.running_mean", "model._features.19.branch2_3a.bn.running_var", "model._features.19.branch2_3a.bn.num_batches_tracked", "model._features.19.branch2_3b.conv.weight", "model._features.19.branch2_3b.bn.weight", "model._features.19.branch2_3b.bn.bias", "model._features.19.branch2_3b.bn.running_mean", "model._features.19.branch2_3b.bn.running_var", "model._features.19.branch2_3b.bn.num_batches_tracked", "model._features.19.branch3.1.conv.weight", "model._features.19.branch3.1.bn.weight", "model._features.19.branch3.1.bn.bias", "model._features.19.branch3.1.bn.running_mean", "model._features.19.branch3.1.bn.running_var", "model._features.19.branch3.1.bn.num_batches_tracked", "model._features.20.branch0.conv.weight", "model._features.20.branch0.bn.weight", "model._features.20.branch0.bn.bias", "model._features.20.branch0.bn.running_mean", "model._features.20.branch0.bn.running_var", "model._features.20.branch0.bn.num_batches_tracked", "model._features.20.branch1_0.conv.weight", "model._features.20.branch1_0.bn.weight", "model._features.20.branch1_0.bn.bias", "model._features.20.branch1_0.bn.running_mean", "model._features.20.branch1_0.bn.running_var", 
"model._features.20.branch1_0.bn.num_batches_tracked", "model._features.20.branch1_1a.conv.weight", "model._features.20.branch1_1a.bn.weight", "model._features.20.branch1_1a.bn.bias", "model._features.20.branch1_1a.bn.running_mean", "model._features.20.branch1_1a.bn.running_var", "model._features.20.branch1_1a.bn.num_batches_tracked", "model._features.20.branch1_1b.conv.weight", "model._features.20.branch1_1b.bn.weight", "model._features.20.branch1_1b.bn.bias", "model._features.20.branch1_1b.bn.running_mean", "model._features.20.branch1_1b.bn.running_var", "model._features.20.branch1_1b.bn.num_batches_tracked", "model._features.20.branch2_0.conv.weight", "model._features.20.branch2_0.bn.weight", "model._features.20.branch2_0.bn.bias", "model._features.20.branch2_0.bn.running_mean", "model._features.20.branch2_0.bn.running_var", "model._features.20.branch2_0.bn.num_batches_tracked", "model._features.20.branch2_1.conv.weight", "model._features.20.branch2_1.bn.weight", "model._features.20.branch2_1.bn.bias", "model._features.20.branch2_1.bn.running_mean", "model._features.20.branch2_1.bn.running_var", "model._features.20.branch2_1.bn.num_batches_tracked", "model._features.20.branch2_2.conv.weight", "model._features.20.branch2_2.bn.weight", "model._features.20.branch2_2.bn.bias", "model._features.20.branch2_2.bn.running_mean", "model._features.20.branch2_2.bn.running_var", "model._features.20.branch2_2.bn.num_batches_tracked", "model._features.20.branch2_3a.conv.weight", "model._features.20.branch2_3a.bn.weight", "model._features.20.branch2_3a.bn.bias", "model._features.20.branch2_3a.bn.running_mean", "model._features.20.branch2_3a.bn.running_var", "model._features.20.branch2_3a.bn.num_batches_tracked", "model._features.20.branch2_3b.conv.weight", "model._features.20.branch2_3b.bn.weight", "model._features.20.branch2_3b.bn.bias", "model._features.20.branch2_3b.bn.running_mean", "model._features.20.branch2_3b.bn.running_var", 
"model._features.20.branch2_3b.bn.num_batches_tracked", "model._features.20.branch3.1.conv.weight", "model._features.20.branch3.1.bn.weight", "model._features.20.branch3.1.bn.bias", "model._features.20.branch3.1.bn.running_mean", "model._features.20.branch3.1.bn.running_var", "model._features.20.branch3.1.bn.num_batches_tracked", "model._features.21.branch0.conv.weight", "model._features.21.branch0.bn.weight", "model._features.21.branch0.bn.bias", "model._features.21.branch0.bn.running_mean", "model._features.21.branch0.bn.running_var", "model._features.21.branch0.bn.num_batches_tracked", "model._features.21.branch1_0.conv.weight", "model._features.21.branch1_0.bn.weight", "model._features.21.branch1_0.bn.bias", "model._features.21.branch1_0.bn.running_mean", "model._features.21.branch1_0.bn.running_var", "model._features.21.branch1_0.bn.num_batches_tracked", "model._features.21.branch1_1a.conv.weight", "model._features.21.branch1_1a.bn.weight", "model._features.21.branch1_1a.bn.bias", "model._features.21.branch1_1a.bn.running_mean", "model._features.21.branch1_1a.bn.running_var", "model._features.21.branch1_1a.bn.num_batches_tracked", "model._features.21.branch1_1b.conv.weight", "model._features.21.branch1_1b.bn.weight", "model._features.21.branch1_1b.bn.bias", "model._features.21.branch1_1b.bn.running_mean", "model._features.21.branch1_1b.bn.running_var", "model._features.21.branch1_1b.bn.num_batches_tracked", "model._features.21.branch2_0.conv.weight", "model._features.21.branch2_0.bn.weight", "model._features.21.branch2_0.bn.bias", "model._features.21.branch2_0.bn.running_mean", "model._features.21.branch2_0.bn.running_var", "model._features.21.branch2_0.bn.num_batches_tracked", "model._features.21.branch2_1.conv.weight", "model._features.21.branch2_1.bn.weight", "model._features.21.branch2_1.bn.bias", "model._features.21.branch2_1.bn.running_mean", "model._features.21.branch2_1.bn.running_var", "model._features.21.branch2_1.bn.num_batches_tracked", 
"model._features.21.branch2_2.conv.weight", "model._features.21.branch2_2.bn.weight", "model._features.21.branch2_2.bn.bias", "model._features.21.branch2_2.bn.running_mean", "model._features.21.branch2_2.bn.running_var", "model._features.21.branch2_2.bn.num_batches_tracked", "model._features.21.branch2_3a.conv.weight", "model._features.21.branch2_3a.bn.weight", "model._features.21.branch2_3a.bn.bias", "model._features.21.branch2_3a.bn.running_mean", "model._features.21.branch2_3a.bn.running_var", "model._features.21.branch2_3a.bn.num_batches_tracked", "model._features.21.branch2_3b.conv.weight", "model._features.21.branch2_3b.bn.weight", "model._features.21.branch2_3b.bn.bias", "model._features.21.branch2_3b.bn.running_mean", "model._features.21.branch2_3b.bn.running_var", "model._features.21.branch2_3b.bn.num_batches_tracked", "model._features.21.branch3.1.conv.weight", "model._features.21.branch3.1.bn.weight", "model._features.21.branch3.1.bn.bias", "model._features.21.branch3.1.bn.running_mean", "model._features.21.branch3.1.bn.running_var", "model._features.21.branch3.1.bn.num_batches_tracked", "model._classifier.weight", "model._classifier.bias".
How can I solve it? I am waiting for your reply.
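Every key listed above begins with `model.`. That often means the checkpoint was saved from a wrapper module that stored the network as `self.model`. Depending on whether `load_state_dict` reports these keys as missing or unexpected, one workaround is to add or strip that prefix before loading. A minimal sketch, assuming that is the cause here; `rename_prefix` and the checkpoint path are hypothetical:

```python
from collections import OrderedDict


def rename_prefix(state_dict, old_prefix, new_prefix):
    """Return a copy of state_dict with old_prefix replaced by new_prefix.

    Hypothetical helper: keys that do not start with old_prefix are kept as-is.
    An empty old_prefix matches every key, so it can also be used to add a prefix.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed


# Usage sketch (the path is hypothetical):
#   state_dict = torch.load("checkpoint.pth", map_location="cpu")
#   model.load_state_dict(rename_prefix(state_dict, "model.", ""))   # strip "model."
#   model.load_state_dict(rename_prefix(state_dict, "", "model."))   # or add it
```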

How to replace the last FC layer with a 1x1 convolution?

Thanks for sharing a great API. I want to build an FCN (fully convolutional network) for semantic segmentation using your API. This can be done by replacing the last fully connected layer with a 1x1 convolutional layer. Taking resnet18 as an example, how could I modify it to perform semantic segmentation? Your classification example is good, and a segmentation example would make it even better.
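One way to do this is to keep the convolutional backbone without its global pooling and flattening. You then apply a 1x1 convolution as a per-pixel classifier and upsample the score map to the input size. The sketch below makes several assumptions. It is a simplified FCN-32s-style head that uses bilinear interpolation instead of a learned deconvolution and has no skip connections. `FCNHead` is a hypothetical name, not part of cnn_finetune. For torchvision's resnet18, dropping the last two children (avgpool and fc) leaves a backbone with 512 output channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FCNHead(nn.Module):
    """Hypothetical wrapper: conv backbone + 1x1-conv classifier + upsampling.

    `backbone` must map (N, 3, H, W) images to (N, in_channels, h, w) feature
    maps, i.e. it must not pool globally or flatten its output.
    """

    def __init__(self, backbone: nn.Module, in_channels: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        # The 1x1 convolution plays the role of the fully connected layer,
        # applied independently at every spatial position.
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[-2:]
        scores = self.classifier(self.backbone(x))
        # Upsample the coarse score map back to the input resolution.
        return F.interpolate(scores, size=(height, width), mode="bilinear", align_corners=False)


# With torchvision's resnet18 the backbone could be built as:
#   resnet = torchvision.models.resnet18()
#   backbone = nn.Sequential(*list(resnet.children())[:-2])  # 512 output channels
# A tiny stand-in backbone keeps this example self-contained:
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)
net = FCNHead(toy_backbone, in_channels=32, num_classes=21)
scores = net(torch.randn(2, 3, 64, 64))  # per-pixel class scores, shape (2, 21, 64, 64)
```

For training, the per-pixel scores pair naturally with nn.CrossEntropyLoss. It accepts (N, C, H, W) scores and (N, H, W) integer class targets.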

Strange mean and std for inception_v4.

Hi, and sorry if this is a stupid question.
The mean for inception_v4 is [0.5, 0.5, 0.5].
The std for inception_v4 is also [0.5, 0.5, 0.5].
Normalizing a tensor with these values maps 0 to -1 and 255 to 509.
Is this normal?
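The value 509 only appears if Normalize is applied to raw pixel values in the 0 to 255 range. In a typical torchvision pipeline, ToTensor first converts 8-bit images to floats in [0, 1]. With mean = std = 0.5, Normalize then maps inputs to [-1, 1]. To my knowledge, that is the range the TensorFlow-ported Inception weights were trained on. The plain-Python sketch below reproduces the arithmetic. It mimics the two transforms rather than calling them:

```python
def to_unit_range(pixel: int) -> float:
    """Mimic ToTensor's scaling of an 8-bit pixel value to [0, 1]."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic Normalize for one channel: (value - mean) / std."""
    return (value - mean) / std


print(normalize(to_unit_range(0)))    # -1.0
print(normalize(to_unit_range(255)))  # 1.0
print(normalize(255))                 # 509.0, i.e. Normalize applied without ToTensor
```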

Training loss might not be summed

I think the training loss is not accumulated over all the batches; I'm not sure if I am missing something here.

I added train_loss below:

```python
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # Variable wrappers are unnecessary since PyTorch 0.4; tensors are passed directly.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float; loss.data[0] is deprecated for 0-dim tensors since PyTorch 0.4.
        train_loss += loss.item()
        if batch_idx % args.log_interval == 0:
            # Report the running mean of the per-batch losses so far.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), train_loss / (batch_idx + 1)))
```
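Dividing train_loss by batch_idx + 1, as in the loop above, averages the per-batch means. That gives a smaller final batch the same weight as a full one. If this matters, a sample-weighted average is a small change. A sketch, assuming the criterion uses the default reduction='mean' (`epoch_average_loss` is a hypothetical helper):

```python
def epoch_average_loss(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch mean losses.

    Assumes each entry of batch_losses is a batch mean (the default
    reduction='mean' of PyTorch loss modules), so it is scaled back up by
    the batch size before averaging.
    """
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# A full batch of 3 samples and a smaller final batch of 1 sample.
print(epoch_average_loss([2.0, 5.0], [3, 1]))  # 2.75 (the unweighted mean would be 3.5)
```

Inside the training loop, this amounts to accumulating loss.item() * data.size(0). At the end of the epoch, you divide by len(train_loader.dataset).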

Output size issue in alexnet and inception

AlexNet

File "/lib/python3.6/site-packages/cnn_finetune/base.py", line 158, in calculate_classifier_in_features
output = original_model.features(input_var)

File "/lib/python3.6/site-packages/torch/nn/modules/pooling.py", line 143, in forward
self.return_indices)

File "/lib/python3.6/site-packages/torch/nn/functional.py", line 334, in max_pool2d
ret = torch._C._nn.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

RuntimeError: Given input size: (256x1x1). Calculated output size: (256x0x0). Output size is too small at /opt/conda/conda-bld/pytorch_1518243271935/work/torch/lib/THNN/generic

inceptionresnetv2
File "/lib/python3.6/site-packages/torch/nn/functional.py", line 90, in conv2d
return f(input, weight, bias)
RuntimeError: Given input size: (320, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

inception_v3
line 90, in conv2d
return f(input, weight, bias)
RuntimeError: Given input size: (288, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

inception_v4
Given input size: (384, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

NB1. For AlexNet, changing use_original_classifier to True (it was False in all the above) gave the following error:
File "/lib/python3.6/site-packages/cnn_finetune/contrib/torchvision.py", line 93, in check_args
'For the original classifier '

Exception: For the original classifier input_size value must be (224, 224)

NB2. pre_trained_model = True in all the above
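All of these errors say that a pooling or convolution layer received a 1x1 feature map and would produce a 0x0 one. That suggests the input images are too small for the network. The standard output-size formula for Conv2d and MaxPool2d (with ceil_mode=False) makes this easy to check. The plain-Python sketch below applies it to the layers of torchvision's AlexNet.features, assuming the default layer configuration. With 32x32 inputs, such as CIFAR-10 images, it reproduces the reported 1x1 to 0x0 failure at the last max-pooling layer:

```python
# (kernel_size, stride, padding) of each conv/pool layer in torchvision's AlexNet.features
ALEXNET_FEATURES = [
    (11, 4, 2),  # Conv2d(3, 64)
    (3, 2, 0),   # MaxPool2d
    (5, 1, 2),   # Conv2d(64, 192)
    (3, 2, 0),   # MaxPool2d
    (3, 1, 1),   # Conv2d(192, 384)
    (3, 1, 1),   # Conv2d(384, 256)
    (3, 1, 1),   # Conv2d(256, 256)
    (3, 2, 0),   # MaxPool2d
]


def layer_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of Conv2d / MaxPool2d with ceil_mode=False."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def feature_map_size(input_size, layers):
    """Propagate a square input size through the layers; stop once it drops below 1."""
    size = input_size
    for kernel, stride, padding in layers:
        size = layer_output_size(size, kernel, stride, padding)
        if size < 1:
            # This is where PyTorch raises "Output size is too small".
            return size
    return size


print(feature_map_size(224, ALEXNET_FEATURES))  # 6, AlexNet's usual 6x6 feature map
print(feature_map_size(63, ALEXNET_FEATURES))   # 1, the smallest input that still works
print(feature_map_size(32, ALEXNET_FEATURES))   # 0, "Output size is too small"
```

If this is the cause, a likely fix is to resize the images before they reach the network, for example with torchvision's transforms.Resize. The same check applies to the Inception models, using their own layer lists.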

Auto Flattening of features

Hi,

I found this feature (or bug) annoying, because one may want the unflattened output. The current forward method is:

def forward(self, x):
    x = self.features(x)
    if self.pool is not None:
        x = self.pool(x)
    if self.dropout is not None:
        x = self.dropout(x)
    if self.flatten_features_output:
        x = x.view(x.size(0), -1)
    x = self.classifier(x)
    return x

Why do you have self.flatten_features_output if it is always True? There is no way to set that variable. And once pooling is set to None, the tensor is still flattened automatically instead of giving access to the unflattened feature maps.
Are there any plans to change this, or to expose that option?

Do the parameters of the loaded pretrained model get updated by the gradients?

When using a custom classifier like the one in your guide:

import torch.nn as nn
from cnn_finetune import make_model

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)

I wonder whether the parameters of the pretrained model are updated too, or whether it just trains the custom classifier on top of the network.
If I only want to train a custom classifier on top of the network, what should I do?
Many thanks.
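By default every parameter has requires_grad=True. So if all of model.parameters() are passed to the optimizer, the pretrained layers are updated along with the new classifier. To train only the classifier, one approach is to disable gradients for everything except the classifier and hand only the trainable parameters to the optimizer. The sketch below uses stand-in modules. For a cnn_finetune model, the classifier is presumably reachable as model._classifier, the attribute name suggested by the state-dict keys quoted earlier in this thread. Treat that name as an assumption and check it against your installed version. `freeze_all_but` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def freeze_all_but(model: nn.Module, trainable: nn.Module) -> list:
    """Disable gradients for every parameter of `model` except those in `trainable`.

    Returns the parameters that still require gradients, for the optimizer.
    """
    for param in model.parameters():
        param.requires_grad = False
    for param in trainable.parameters():
        param.requires_grad = True
    return [param for param in model.parameters() if param.requires_grad]


# Hypothetical stand-ins for a pretrained feature extractor and a new classifier.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 10)
model = nn.Sequential(backbone, head)

params = freeze_all_but(model, head)
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)
```

Frozen parameters still take part in the forward pass; they just receive no gradient. If the backbone contains BatchNorm layers, note that their running statistics still update in train mode. To prevent that, also put those layers in eval mode.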

mean and std used in the transform of the Cifar10 example

In the Cifar10 example, the transform uses the mean and std of the original model pretrained on ImageNet. How would this transform affect learning from new data? Shouldn't the mean and std be computed from the new data (CIFAR-10 in the example)?

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=model.original_model_info.mean,
        std=model.original_model_info.std),
])
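Both approaches are used in practice. Keeping the ImageNet statistics gives the pretrained weights inputs on the same scale they were trained on, which is presumably why the example does it. Computing CIFAR-10's own statistics is also reasonable, especially when training from scratch. If you want the dataset statistics, here is a minimal sketch. It works on a tensor of images that have gone through ToTensor but not Normalize. `channel_mean_std` is a hypothetical helper:

```python
import torch


def channel_mean_std(images: torch.Tensor):
    """Per-channel mean and std of a batch of images shaped (N, C, H, W)."""
    num_channels = images.size(1)
    # Move channels first and flatten everything else: (C, N * H * W).
    per_channel = images.transpose(0, 1).reshape(num_channels, -1)
    return per_channel.mean(dim=1), per_channel.std(dim=1)


# Usage sketch: stack ToTensor()-converted training images, e.g.
#   images = torch.stack([img for img, _ in train_set])
images = torch.rand(16, 3, 32, 32)
mean, std = channel_mean_std(images)
# transforms.Normalize(mean=mean.tolist(), std=std.tolist())
```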

illegal instruction (core dumped)

It seems that the default ResNet works, but any model that needs to be downloaded (vgg, alexnet) causes this crash. Some examples below:

~/pytorch-cnn-finetune/examples$ python cifar10.py --model-name vgg19
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/ubuntu/.torch/models/vgg19-dcbb9e9d.pth
100%|574673361/574673361 [00:20<00:00, 28293664.65it/s]
Illegal instruction (core dumped)

~/pytorch-cnn-finetune/examples$ python cifar10.py --model-name vgg16
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/ubuntu/.torch/models/vgg16-397923af.pth
100%|553433881/553433881 [00:05<00:00, 98388620.83it/s]
Illegal instruction (core dumped)

Some warnings due to updating PyTorch to version 0.4.0

After updating PyTorch to version 0.4.0, a few warnings appeared in the Cifar10 example.
1- UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
test_loss += criterion(output, target).data[0]
corrected to: test_loss += criterion(output, target).item()
total_loss += loss.data[0]
corrected to: total_loss += loss.item()

2- UserWarning: volatile was removed and now has no effect. Use with torch.no_grad(): instead.
data, target = Variable(data, volatile=True), Variable(target)
Not sure how to correct this. With torch.no_grad()?
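Since PyTorch 0.4, wrapping evaluation in a torch.no_grad() block replaces volatile=True, and tensors can be passed to the model directly without Variable. Here is a sketch of an evaluation loop in that style. `evaluate` is a hypothetical function, not the example's actual code:

```python
import torch
import torch.nn as nn


def evaluate(model, loader, criterion, device="cpu"):
    """Return (mean loss per sample, accuracy) without building autograd graphs."""
    model.eval()
    test_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # replaces Variable(data, volatile=True)
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # .item() replaces .data[0]; scale the batch mean by the batch size.
            test_loss += criterion(output, target).item() * target.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    return test_loss / total, correct / total


# Usage sketch with a toy model and two in-memory batches.
toy_model = nn.Linear(4, 3)
toy_loader = [
    (torch.randn(5, 4), torch.randint(0, 3, (5,))),
    (torch.randn(2, 4), torch.randint(0, 3, (2,))),
]
loss, accuracy = evaluate(toy_model, toy_loader, nn.CrossEntropyLoss())
```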

Ask a question

What does your network's "output" include?
"output = model(data)"
Are the values in "output.data" logits?

I want to get the classification probabilities. How do I get them?
Thanks!
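Assume the model ends in a plain linear classifier, as the classifier examples in this thread do. Then output holds one unnormalized score (logit) per class for each image, not probabilities. Applying softmax over the class dimension turns the scores into probabilities. A sketch with made-up scores:

```python
import torch

# Made-up raw scores for a batch of 2 images and 3 classes, standing in for model(data).
output = torch.tensor([[2.0, 1.0, 0.1],
                       [0.0, 0.0, 3.0]])

probs = torch.softmax(output, dim=1)       # each row now sums to 1
confidence, predicted = probs.max(dim=1)   # top probability and its class index
print(predicted.tolist())                  # [0, 2]
```

At inference time, you would typically call model.eval() first and compute output inside torch.no_grad().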

Finetune on multi-GPUs

I want to fine-tune the models on multiple GPUs. Following the official documentation, I tried this:
model = make_model( args.model_name, pretrained=True, num_classes=len(classes), dropout_p=args.dropout_p, use_original_classifier=True )
model = nn.DataParallel(model)
model = model.to(device)

But I get this error:
Traceback (most recent call last):
  File "pg_cls.py", line 83, in <module>
    mean=model.original_model_info.mean,
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'original_model_info'
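nn.DataParallel wraps the original model and does not forward custom attributes such as original_model_info. The wrapped model is available as the wrapper's .module attribute. So either read the attribute before wrapping, or go through .module afterwards. The sketch below uses a hypothetical stand-in model and also runs on CPU, where DataParallel simply holds the wrapped module:

```python
from types import SimpleNamespace

import torch.nn as nn


class StandInModel(nn.Module):
    """Hypothetical stand-in for a cnn_finetune model with original_model_info."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        # Placeholder metadata mimicking model.original_model_info.mean / .std.
        self.original_model_info = SimpleNamespace(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, x):
        return self.fc(x)


model = StandInModel()
mean = model.original_model_info.mean       # option 1: read it before wrapping
model = nn.DataParallel(model)
std = model.module.original_model_info.std  # option 2: go through .module
```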

Error(s) in loading state_dict for Xception

As noted in your documentation, xception is affected by the issue
Cadene/pretrained-models.pytorch#62.
I am using your API with the xception network and an input size of 256, and I get the error shown in the log below. Could you tell me how I could fix it using your API?
This is my code:

model = make_model(
    'xception',
    pretrained=True,
    num_classes=100,
    dropout_p=0.2,
    input_size=(256, 256)
)

This is the log:

RuntimeError: Error(s) in loading state_dict for Xception:
	size mismatch for block1.rep.0.pointwise.weight: copying a param of torch.Size([128, 64, 1, 1]) from checkpoint, where the shape is torch.Size([128, 64]) in current model.
	size mismatch for block1.rep.3.pointwise.weight: copying a param of torch.Size([128, 128, 1, 1]) from checkpoint, where the shape is torch.Size([128, 128]) in current model.
	size mismatch for block2.rep.1.pointwise.weight: copying a param of torch.Size([256, 128, 1, 1]) from checkpoint, where the shape is torch.Size([256, 128]) in current model.
	size mismatch for block2.rep.4.pointwise.weight: copying a param of torch.Size([256, 256, 1, 1]) from checkpoint, where the shape is torch.Size([256, 256]) in current model.
	size mismatch for block3.rep.1.pointwise.weight: copying a param of torch.Size([728, 256, 1, 1]) from checkpoint, where the shape is torch.Size([728, 256]) in current model.
	size mismatch for block3.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block12.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block12.rep.4.pointwise.weight: copying a param of torch.Size([1024, 728, 1, 1]) from checkpoint, where the shape is torch.Size([1024, 728]) in current model.
	size mismatch for conv3.pointwise.weight: copying a param of torch.Size([1536, 1024, 1, 1]) from checkpoint, where the shape is torch.Size([1536, 1024]) in current model.
	size mismatch for conv4.pointwise.weight: copying a param of torch.Size([2048, 1536, 1, 1]) from checkpoint, where the shape is torch.Size([2048, 1536]) in current model.

'DPN' object has no attribute 'input_space' in the CIFAR-10 example

When setting pre_trained_model to False, running the DPN models in the CIFAR-10 example gave the error 'DPN' object has no attribute 'input_space'. All models except the ResNets gave this error. Passing input_size=(32, 32) to make_model resolved the issue, since 32x32 is the CIFAR-10 image resolution for each channel.

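from cnn_finetune import make_model

# args and classes come from the CIFAR-10 example script
model = make_model(
    args.model_name,
    pretrained=True,
    num_classes=len(classes),
    dropout_p=args.dropout_p,
    input_size=(32, 32),  # CIFAR-10 image resolution
)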

However, passing input_size raised the error 'ResNet' object has no attribute 'features' for all the ResNets, so an if statement can be used to decide whether to pass input_size when making the model.
NB: for VGG* and SqueezeNet* models, input_size must be provided in any case; otherwise make_model raises: Exception: You must provide input_size, e.g. make_model(vgg11, num_classes=10, pretrained=True, input_size=(224, 224)
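The if statement mentioned above could look roughly like the following minimal sketch. The helper name `make_model_kwargs` and the rule it encodes (omit input_size only for model names starting with "resnet", pass it for everything else, including VGG and SqueezeNet) are assumptions drawn from this report, not part of cnn_finetune's API.

```python
# Hypothetical helper encoding the workaround from this report; not part of
# cnn_finetune itself.
CIFAR10_INPUT_SIZE = (32, 32)  # CIFAR-10 images are 32x32 pixels


def make_model_kwargs(model_name, num_classes, dropout_p=None,
                      pretrained=True, input_size=CIFAR10_INPUT_SIZE):
    """Build keyword arguments for make_model(model_name, **kwargs)."""
    kwargs = {"pretrained": pretrained, "num_classes": num_classes}
    if dropout_p is not None:
        kwargs["dropout_p"] = dropout_p
    # Per the report, ResNets fail with "'ResNet' object has no attribute
    # 'features'" when input_size is given, while the other models (and VGG /
    # SqueezeNet in any case) need it, so it is skipped only for ResNets.
    if not model_name.startswith("resnet"):
        kwargs["input_size"] = input_size
    return kwargs
```

In the example script this would be called as `make_model(args.model_name, **make_model_kwargs(args.model_name, len(classes), args.dropout_p))`.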
