
pytorch-cnn-finetune's People


cgnorthcutt, creafz



pytorch-cnn-finetune's Issues

'DPN' object has no attribute 'input_space' Cifar10 example

With pre_trained_model set to False, running the DPN models in the Cifar10 example raised the error 'DPN' object has no attribute 'input_space'. Every model except the ResNets raised this error. Passing input_size=(32, 32) to make_model resolved the issue; 32x32 is the CIFAR-10 resolution of each channel.

model = make_model(
    input_size=(32, 32) 

However, passing input_size gave the error 'ResNet' object has no attribute 'features'
for all the ResNets. Thus, an if statement may be needed when making the model.
NB. For VGG* and squeezenet*, input_size must be provided in any case; otherwise: Exception: You must provide input_size, e.g. make_model(vgg11, num_classes=10, pretrained=True, input_size=(224, 224)
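
The "if statement" mentioned in this report could look roughly like the sketch below. It only builds the keyword arguments; the helper name build_make_model_kwargs and the prefix list are hypothetical, and the rule "ResNets take no input_size" comes from this report, not from the library's documentation.

```python
# Hypothetical helper: decide whether to pass input_size to make_model.
# Assumption (from the report above): ResNet variants reject input_size,
# while the other architectures accept or require it.
NO_INPUT_SIZE_PREFIXES = ('resnet',)


def build_make_model_kwargs(model_name, num_classes, input_size=(32, 32)):
    kwargs = {'num_classes': num_classes, 'pretrained': False}
    if not model_name.startswith(NO_INPUT_SIZE_PREFIXES):
        kwargs['input_size'] = input_size
    return kwargs


resnet_kwargs = build_make_model_kwargs('resnet50', num_classes=10)
vgg_kwargs = build_make_model_kwargs('vgg11', num_classes=10)
print(resnet_kwargs)
print(vgg_kwargs)
# Intended use: model = make_model(name, **build_make_model_kwargs(name, 10))
```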

Access to fc layer

I need to access the fc layer in the ResNet model. With the standard ResNet, this is done using model.fc.
However, with the models from this repo, I get this error:

In [33]: model.fc.weight

Traceback (most recent call last):

File "<ipython-input-33-ca343f58d0f7>", line 1, in <module>

File "//anaconda3/lib/python3.6/site-packages/torch/nn/modules/", line 535, in __getattr__
  type(self).__name__, name))

AttributeError: 'ResNetWrapper' object has no attribute 'fc'

Is there a way around this?
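
One generic way to reach the final fully connected layer without knowing its attribute name is to walk model.modules() and keep the last nn.Linear. The sketch below uses TinyWrapper, a hypothetical stand-in and not the library's ResNetWrapper; the attribute names inside it are assumptions.

```python
import torch.nn as nn


class TinyWrapper(nn.Module):
    """Hypothetical stand-in for a wrapped model whose head is not named `fc`."""

    def __init__(self):
        super().__init__()
        self._features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self._classifier = nn.Linear(8, 10)


def last_linear(model):
    """Return the last nn.Linear registered in the model, or None."""
    found = None
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found = module
    return found


head = last_linear(TinyWrapper())
print(head.weight.shape)  # shape of the head's weight matrix
```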

Auto Flattening of features


I found this annoying feature/bug as one may like to:

def forward(self, x):
    x = self.features(x)
    if self.pool is not None:
        x = self.pool(x)
    if self.dropout is not None:
        x = self.dropout(x)
    if self.flatten_features_output:
        x = x.view(x.size(0), -1)
    x = self.classifier(x)
    return x

Why do you have self.flatten_features_output if it is always going to be true? There is no access to that variable. And once the pooling is set to None, the tensor is automatically flattened instead of giving access to the layer.
Any plans of changing it? Or giving access?

Training loss might not be summed

I think the training loss is not accumulated over all the batches; I am not sure whether I am missing something here.

train_loss added below:


def train(epoch):
    model.train()
    train_loss = 0  # running sum of the per-batch losses
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # use loss.data[0] on PyTorch < 0.4
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), train_loss/(batch_idx+1) ))

Finetune on multi-GPUs

I want to fine-tune the models on multiple GPUs. Following the official documentation, I tried this:
model = make_model( args.model_name, pretrained=True, num_classes=len(classes), dropout_p=args.dropout_p, use_original_classifier=True )
model = nn.DataParallel(model)
model =

But I get this error:
Traceback (most recent call last):
  File "", line 83, in <module>
    mean=model.original_model_info.mean,
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/", line 532, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'original_model_info'
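
nn.DataParallel keeps the wrapped network in its .module attribute, so custom attributes have to be read from there or before wrapping. A minimal sketch follows; ToyModel and the value stored in original_model_info are placeholders, not the library's real objects.

```python
import torch.nn as nn


class ToyModel(nn.Module):
    """Placeholder model carrying a custom attribute, like original_model_info."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.original_model_info = {'mean': [0.5, 0.5, 0.5]}  # made-up value

    def forward(self, x):
        return self.fc(x)


model = ToyModel()
info = model.original_model_info            # option 1: read it before wrapping
parallel_model = nn.DataParallel(model)
same_info = parallel_model.module.original_model_info  # option 2: go through .module
print(same_info is info)
```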

Do the parameters in the loaded pretrained-model update according to the Grad?

When using a custom classifier like the one in your guide:

import torch.nn as nn

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.Linear(4096, num_classes),
    )

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)

I wonder whether the parameters of the pretrained model are updated as well, or whether only the custom classifier on top of the net is trained.
If I only want to train a custom classifier on top of the net, what should I do?
Many thanks.
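
In PyTorch, every parameter with requires_grad=True that is handed to the optimizer gets updated, so the pretrained layers train too unless they are frozen. A common way to train only the head is sketched below on a stand-in network; the features / classifier attribute names are assumptions and may differ from the library's actual layout.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in network; the `features` / `classifier` split is an assumption.
model = nn.Module()
model.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.classifier = nn.Linear(8, 10)

# Freeze the (pretrained) feature extractor so its weights get no gradients.
for param in model.features.parameters():
    param.requires_grad = False

# Hand only the still-trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.01, momentum=0.9)
print(len(trainable))  # weight and bias of the classifier
```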

Strange mean and std for inception_v4.

Hi, and sorry if this is a stupid question.
The mean for inception_v4 is [0.5, 0.5, 0.5].
The std for inception_v4 is also [0.5, 0.5, 0.5].
Applying Normalize with these values maps 0 to -1 and 255 to 509.
Is this normal?
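
The arithmetic depends on the input range. If pixels are first scaled to [0, 1] (as torchvision's ToTensor does for 8-bit images), a mean and std of 0.5 map them to [-1, 1]; the value 509 only appears when raw 0..255 values are normalized. A small plain-Python check of both cases:

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel normalization as used by Normalize: (x - mean) / std."""
    return (value - mean) / std


# Pixels already scaled to [0, 1]:
print(normalize(0.0), normalize(1.0))    # -1.0 1.0
# Raw 8-bit pixel values, not rescaled first:
print(normalize(0), normalize(255))      # -1.0 509.0
```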

Error(s) in loading state_dict for Xception

As your documentation mentions, Xception has an error, as reported in the issue.
I am using your API with the Xception network and an input size of 256. I got the error shown in the log below. Could you tell me how to fix it using your API?
This is my code:

model = make_model(
    input_size=(256, 256)

This is log

RuntimeError: Error(s) in loading state_dict for Xception:
	size mismatch for block1.rep.0.pointwise.weight: copying a param of torch.Size([128, 64, 1, 1]) from checkpoint, where the shape is torch.Size([128, 64]) in current model.
	size mismatch for block1.rep.3.pointwise.weight: copying a param of torch.Size([128, 128, 1, 1]) from checkpoint, where the shape is torch.Size([128, 128]) in current model.
	size mismatch for block2.rep.1.pointwise.weight: copying a param of torch.Size([256, 128, 1, 1]) from checkpoint, where the shape is torch.Size([256, 128]) in current model.
	size mismatch for block2.rep.4.pointwise.weight: copying a param of torch.Size([256, 256, 1, 1]) from checkpoint, where the shape is torch.Size([256, 256]) in current model.
	size mismatch for block3.rep.1.pointwise.weight: copying a param of torch.Size([728, 256, 1, 1]) from checkpoint, where the shape is torch.Size([728, 256]) in current model.
	size mismatch for block3.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block4.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block5.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block6.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block7.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block8.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block9.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block10.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.4.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block11.rep.7.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block12.rep.1.pointwise.weight: copying a param of torch.Size([728, 728, 1, 1]) from checkpoint, where the shape is torch.Size([728, 728]) in current model.
	size mismatch for block12.rep.4.pointwise.weight: copying a param of torch.Size([1024, 728, 1, 1]) from checkpoint, where the shape is torch.Size([1024, 728]) in current model.
	size mismatch for conv3.pointwise.weight: copying a param of torch.Size([1536, 1024, 1, 1]) from checkpoint, where the shape is torch.Size([1536, 1024]) in current model.
	size mismatch for conv4.pointwise.weight: copying a param of torch.Size([2048, 1536, 1, 1]) from checkpoint, where the shape is torch.Size([2048, 1536]) in current model.
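
Each pair of shapes in this log differs only by trailing singleton dimensions ([128, 64] versus [128, 64, 1, 1]). One possible workaround, sketched here under that assumption, is to reshape checkpoint tensors to the model's shapes before calling load_state_dict. The helper name reshape_to_match is hypothetical, and a single 1x1 convolution stands in for Xception.

```python
import torch
import torch.nn as nn


def reshape_to_match(checkpoint, model):
    """Reshape checkpoint tensors whose element count matches the model's."""
    target = model.state_dict()
    fixed = {}
    for name, tensor in checkpoint.items():
        if (name in target and tensor.shape != target[name].shape
                and tensor.numel() == target[name].numel()):
            tensor = tensor.reshape(target[name].shape)
        fixed[name] = tensor
    return fixed


# Toy stand-in: a 1x1 "pointwise" convolution and a checkpoint with 2-D weights.
pointwise = nn.Conv2d(64, 128, kernel_size=1, bias=False)
old_checkpoint = {'weight': torch.randn(128, 64)}
pointwise.load_state_dict(reshape_to_match(old_checkpoint, pointwise))
print(pointwise.weight.shape)
```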

how can i load the trained-model

I fine-tuned the pre-trained inception_v4 model with 3 classes using model=make_model('inception_v4',3,pretrained=True)
and saved the trained model as 'model.pth'.
But when I restored the model with
model = make_model('inception_v4',3,pretrained=False)
it failed with this error:
Unexpected key(s) in state_dict: "model._features.0.conv.weight", "", "", "", "", "", "model._features.1.conv.weight", "", "", "", "", "", "model._features.2.conv.weight", "", "", "", "", "", "model._features.3.conv.conv.weight", "", "", "", "", "", "model._features.4.branch0.0.conv.weight", "", "", "", "", "", "model._features.4.branch0.1.conv.weight", "", "", "", "", "", "model._features.4.branch1.0.conv.weight", "", "", "", "", "", "model._features.4.branch1.1.conv.weight", "", "", "", "", "", "model._features.4.branch1.2.conv.weight", "", "", "", "", "", "model._features.4.branch1.3.conv.weight", "", "", "", "", "", "model._features.5.conv.conv.weight", "", "", "", "", "", "model._features.6.branch0.conv.weight", "", "", "", "", "", "model._features.6.branch1.0.conv.weight", "", "", "", "", "", "model._features.6.branch1.1.conv.weight", "", "", "", "", "", "model._features.6.branch2.0.conv.weight", "", "", "", "", "", "model._features.6.branch2.1.conv.weight", "", "", "", "", "", "model._features.6.branch2.2.conv.weight", "", "", "", "", "", "model._features.6.branch3.1.conv.weight", "", "", "", "", "", "model._features.7.branch0.conv.weight", "", "", "", "", "", "model._features.7.branch1.0.conv.weight", "", "", "", "", "", "model._features.7.branch1.1.conv.weight", "", "", "", "", "", "model._features.7.branch2.0.conv.weight", "", "", "", "", "", "model._features.7.branch2.1.conv.weight", "", "", "", "", "", "model._features.7.branch2.2.conv.weight", "", "", "", "", "", "model._features.7.branch3.1.conv.weight", "", "", "", "", "", "model._features.8.branch0.conv.weight", "", "", "", "", "", "model._features.8.branch1.0.conv.weight", "", "", "", "", "", "model._features.8.branch1.1.conv.weight", "", "", "", "", "", "model._features.8.branch2.0.conv.weight", "", "", "", "", "", "model._features.8.branch2.1.conv.weight", "", "", "", "", "", "model._features.8.branch2.2.conv.weight", "", "", "", "", "", "model._features.8.branch3.1.conv.weight", "", "", "", "", 
"", "model._features.9.branch0.conv.weight", "", "", "", "", "", "model._features.9.branch1.0.conv.weight", "", "", "", "", "", "model._features.9.branch1.1.conv.weight", "", "", "", "", "", "model._features.9.branch2.0.conv.weight", "", "", "", "", "", "model._features.9.branch2.1.conv.weight", "", "", "", "", "", "model._features.9.branch2.2.conv.weight", "", "", "", "", "", "model._features.9.branch3.1.conv.weight", "", "", "", "", "", "model._features.10.branch0.conv.weight", "", "", "", "", "", "model._features.10.branch1.0.conv.weight", "", "", "", "", "", "model._features.10.branch1.1.conv.weight", "", "", "", "", "", "model._features.10.branch1.2.conv.weight", "", "", "", "", "", "model._features.11.branch0.conv.weight", "", "", "", "", "", "model._features.11.branch1.0.conv.weight", "", "", "", "", "", "model._features.11.branch1.1.conv.weight", "", "", "", "", "", "model._features.11.branch1.2.conv.weight", "", "", "", "", "", "model._features.11.branch2.0.conv.weight", "", "", "", "", "", "model._features.11.branch2.1.conv.weight", "", "", "", "", "", "model._features.11.branch2.2.conv.weight", "", "", "", "", "", "model._features.11.branch2.3.conv.weight", "", "", "", "", "", "model._features.11.branch2.4.conv.weight", "", "", "", "", "", "model._features.11.branch3.1.conv.weight", "", "", "", "", "", "model._features.12.branch0.conv.weight", "", "", "", "", "", "model._features.12.branch1.0.conv.weight", "", "", "", "", "", "model._features.12.branch1.1.conv.weight", "", "", "", "", "", "model._features.12.branch1.2.conv.weight", "", "", "", "", "", "model._features.12.branch2.0.conv.weight", "", "", "", "", "", "model._features.12.branch2.1.conv.weight", "", "", "", "", "", "model._features.12.branch2.2.conv.weight", "", "", "", "", "", "model._features.12.branch2.3.conv.weight", "", "", "", "", "", "model._features.12.branch2.4.conv.weight", "", "", "", "", "", "model._features.12.branch3.1.conv.weight", "", "", "", "", "", 
"model._features.13.branch0.conv.weight", "", "", "", "", "", "model._features.13.branch1.0.conv.weight", "", "", "", "", "", "model._features.13.branch1.1.conv.weight", "", "", "", "", "", "model._features.13.branch1.2.conv.weight", "", "", "", "", "", "model._features.13.branch2.0.conv.weight", "", "", "", "", "", "model._features.13.branch2.1.conv.weight", "", "", "", "", "", "model._features.13.branch2.2.conv.weight", "", "", "", "", "", "model._features.13.branch2.3.conv.weight", "", "", "", "", "", "model._features.13.branch2.4.conv.weight", "", "", "", "", "", "model._features.13.branch3.1.conv.weight", "", "", "", "", "", "model._features.14.branch0.conv.weight", "", "", "", "", "", "model._features.14.branch1.0.conv.weight", "", "", "", "", "", "model._features.14.branch1.1.conv.weight", "", "", "", "", "", "model._features.14.branch1.2.conv.weight", "", "", "", "", "", "model._features.14.branch2.0.conv.weight", "", "", "", "", "", "model._features.14.branch2.1.conv.weight", "", "", "", "", "", "model._features.14.branch2.2.conv.weight", "", "", "", "", "", "model._features.14.branch2.3.conv.weight", "", "", "", "", "", "model._features.14.branch2.4.conv.weight", "", "", "", "", "", "model._features.14.branch3.1.conv.weight", "", "", "", "", "", "model._features.15.branch0.conv.weight", "", "", "", "", "", "model._features.15.branch1.0.conv.weight", "", "", "", "", "", "model._features.15.branch1.1.conv.weight", "", "", "", "", "", "model._features.15.branch1.2.conv.weight", "", "", "", "", "", "model._features.15.branch2.0.conv.weight", "", "", "", "", "", "model._features.15.branch2.1.conv.weight", "", "", "", "", "", "model._features.15.branch2.2.conv.weight", "", "", "", "", "", "model._features.15.branch2.3.conv.weight", "", "", "", "", "", "model._features.15.branch2.4.conv.weight", "", "", "", "", "", "model._features.15.branch3.1.conv.weight", "", "", "", "", "", "model._features.16.branch0.conv.weight", "", "", "", "", "", 
"model._features.16.branch1.0.conv.weight", "", "", "", "", "", "model._features.16.branch1.1.conv.weight", "", "", "", "", "", "model._features.16.branch1.2.conv.weight", "", "", "", "", "", "model._features.16.branch2.0.conv.weight", "", "", "", "", "", "model._features.16.branch2.1.conv.weight", "", "", "", "", "", "model._features.16.branch2.2.conv.weight", "", "", "", "", "", "model._features.16.branch2.3.conv.weight", "", "", "", "", "", "model._features.16.branch2.4.conv.weight", "", "", "", "", "", "model._features.16.branch3.1.conv.weight", "", "", "", "", "", "model._features.17.branch0.conv.weight", "", "", "", "", "", "model._features.17.branch1.0.conv.weight", "", "", "", "", "", "model._features.17.branch1.1.conv.weight", "", "", "", "", "", "model._features.17.branch1.2.conv.weight", "", "", "", "", "", "model._features.17.branch2.0.conv.weight", "", "", "", "", "", "model._features.17.branch2.1.conv.weight", "", "", "", "", "", "model._features.17.branch2.2.conv.weight", "", "", "", "", "", "model._features.17.branch2.3.conv.weight", "", "", "", "", "", "model._features.17.branch2.4.conv.weight", "", "", "", "", "", "model._features.17.branch3.1.conv.weight", "", "", "", "", "", "model._features.18.branch0.0.conv.weight", "", "", "", "", "", "model._features.18.branch0.1.conv.weight", "", "", "", "", "", "model._features.18.branch1.0.conv.weight", "", "", "", "", "", "model._features.18.branch1.1.conv.weight", "", "", "", "", "", "model._features.18.branch1.2.conv.weight", "", "", "", "", "", "model._features.18.branch1.3.conv.weight", "", "", "", "", "", "model._features.19.branch0.conv.weight", "", "", "", "", "", "model._features.19.branch1_0.conv.weight", "", "", "", "", "", "model._features.19.branch1_1a.conv.weight", "", "", "", "", "", "model._features.19.branch1_1b.conv.weight", "", "", "", "", "", "model._features.19.branch2_0.conv.weight", "", "", "", "", "", "model._features.19.branch2_1.conv.weight", "", "", "", "", "", 
"model._features.19.branch2_2.conv.weight", "", "", "", "", "", "model._features.19.branch2_3a.conv.weight", "", "", "", "", "", "model._features.19.branch2_3b.conv.weight", "", "", "", "", "", "model._features.19.branch3.1.conv.weight", "", "", "", "", "", "model._features.20.branch0.conv.weight", "", "", "", "", "", "model._features.20.branch1_0.conv.weight", "", "", "", "", "", "model._features.20.branch1_1a.conv.weight", "", "", "", "", "", "model._features.20.branch1_1b.conv.weight", "", "", "", "", "", "model._features.20.branch2_0.conv.weight", "", "", "", "", "", "model._features.20.branch2_1.conv.weight", "", "", "", "", "", "model._features.20.branch2_2.conv.weight", "", "", "", "", "", "model._features.20.branch2_3a.conv.weight", "", "", "", "", "", "model._features.20.branch2_3b.conv.weight", "", "", "", "", "", "model._features.20.branch3.1.conv.weight", "", "", "", "", "", "model._features.21.branch0.conv.weight", "", "", "", "", "", "model._features.21.branch1_0.conv.weight", "", "", "", "", "", "model._features.21.branch1_1a.conv.weight", "", "", "", "", "", "model._features.21.branch1_1b.conv.weight", "", "", "", "", "", "model._features.21.branch2_0.conv.weight", "", "", "", "", "", "model._features.21.branch2_1.conv.weight", "", "", "", "", "", "model._features.21.branch2_2.conv.weight", "", "", "", "", "", "model._features.21.branch2_3a.conv.weight", "", "", "", "", "", "model._features.21.branch2_3b.conv.weight", "", "", "", "", "", "model._features.21.branch3.1.conv.weight", "", "", "", "", "", "model._classifier.weight", "model._classifier.bias".
How can I solve it? I am waiting for your reply.
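
Every unexpected key starts with "model.", which suggests (this is an assumption) that the checkpoint was saved from an object holding the network in an attribute named model. A common workaround is to strip that prefix before calling load_state_dict. The helper below is a hypothetical sketch that works on any dict of keys:

```python
def strip_prefix(state_dict, prefix='model.'):
    """Return a copy of state_dict with `prefix` removed from matching keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy checkpoint; real values would be tensors loaded with torch.load.
saved = {'model._features.0.conv.weight': 1, 'model._classifier.bias': 2}
cleaned = strip_prefix(saved)
print(sorted(cleaned))
# Intended use (not run here):
#   model.load_state_dict(strip_prefix(torch.load('model.pth')))
```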

save mode

When saving the resnext101_64x4d model, the following error occurred. There is no such problem in saving other models. What is the reason?

Traceback (most recent call last):
File "", line 251, in
File "", line 243, in main
max_acc = test(model, test_loader,max_acc,epoch_test)
File "", line 130, in test,'./models/%s.pth'%args.model_name)
File "/home/zlw/.local/lib/python3.5/site-packages/torch/", line 260, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/home/zlw/.local/lib/python3.5/site-packages/torch/", line 185, in _with_file_like
return body(f)
File "/home/zlw/.local/lib/python3.5/site-packages/torch/", line 260, in
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/home/zlw/.local/lib/python3.5/site-packages/torch/", line 332, in _save
_pickle.PicklingError: Can't pickle <function at 0x7ff02de7b0d0>: attribute lookup on pretrainedmodels.models.resnext_features.resnext101_64x4d_features failed
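
The traceback says pickle could not serialize a function stored inside the model. Saving a whole model object pickles everything it references, so a single unpicklable function is enough to fail, while a state dict holds only plain data. The plain-Python sketch below illustrates the difference with a toy class (ModelWithLambda is hypothetical); with a real model the usual workaround is to save model.state_dict() instead of the model itself.

```python
import pickle


class ModelWithLambda:
    """Toy object holding a lambda, standing in for the failing model."""

    def __init__(self):
        self.weights = {'fc.weight': [0.1, 0.2]}
        self.activation = lambda x: x  # functions like this cannot be pickled

    def state_dict(self):
        return dict(self.weights)


model = ModelWithLambda()
try:
    pickle.dumps(model)            # analogous to saving the whole model object
    whole_object_ok = True
except (pickle.PicklingError, AttributeError):
    whole_object_ok = False

state_bytes = pickle.dumps(model.state_dict())  # analogous to saving the state dict
print(whole_object_ok, len(state_bytes) > 0)
```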

ask a question

What does your network “output” include?
“output = model(data)”
Is “ () ” a logical value?

I want to get the classification probability, how do I get it?
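
Assuming the model ends in a linear classifier trained with a cross-entropy loss, output holds one raw score (logit) per class for each sample, and a softmax turns those scores into probabilities. The sketch below shows the computation in plain Python on made-up scores; in PyTorch the equivalent call would be torch.softmax(output, dim=1).

```python
import math


def softmax(logits):
    """Convert raw class scores into probabilities that sum to 1."""
    shift = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]                      # made-up scores for 3 classes
probs = softmax(logits)
predicted = probs.index(max(probs))           # index of the most likely class
print(probs, predicted)
```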

Kernel size issue with "inceptionresnetv2"

When I try to run examples/ with Inception-ResNet-v2, I get the following error:

RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

Tried to run this:

import argparse
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from cnn_finetune import make_model

parser = argparse.ArgumentParser(description='Inception-Resnet-v2-TRAIN')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--save-model', type=int, default=10, metavar='N',
                    help='number of epochs after which the model will be saved (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-name', type=str, default='resnet50', metavar='M',
                    help='model name (default: resnet50)')
parser.add_argument('--dropout-p', type=float, default=0.2, metavar='D',
                    help='Dropout probability (default: 0.2)')

args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

def train(model, epoch, optimizer, train_loader, criterion=nn.CrossEntropyLoss()):
    total_loss = 0
    total_size = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        total_loss += loss.item()
        total_size += data.size(0)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss / total_size))

def main():
    model_name = args.model_name

    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'

    model = make_model(
    model =

    transform = transforms.Compose([

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    train_loader =
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=2

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Use exponential decay for fine-tuning optimizer
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.975)

    # Train
    for epoch in range(1, args.epochs + 1):
        train(model, epoch, optimizer, train_loader)
        if epoch % args.save_model == 0:
  , './checkpoint/' + 'ckpt_' + str(epoch) + '.pth')

if __name__ == '__main__':
    main()
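
The error says a feature map had shrunk to 1 x 1 before reaching a 3 x 3 kernel, which typically happens when the input images (32 x 32 for CIFAR-10) are much smaller than what the network was designed for. One possible workaround is to upscale the batch first. The sketch below assumes a target size of 299 x 299; check the input size your model actually expects.

```python
import torch
import torch.nn.functional as F

batch = torch.rand(4, 3, 32, 32)              # fake CIFAR-10 sized batch
target_size = (299, 299)                      # assumed network input size

# Bilinear upsampling of the whole batch before it is fed to the model.
resized = F.interpolate(batch, size=target_size, mode='bilinear', align_corners=False)
print(resized.shape)
```

Resizing inside the dataset transform (for example with torchvision's transforms.Resize) would be an alternative place to do the same thing.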

illegal instruction (core dumped)

It seems that the default ResNet works. Any model that needs to be downloaded (vgg, alexnet) causes this exception. Some examples below:

~/pytorch-cnn-finetune/examples$ python --model-name vgg19
Downloading: "" to /home/ubuntu/.torch/models/vgg19-dcbb9e9d.pth
100%|574673361/574673361 [00:20<00:00, 28293664.65it/s]
Illegal instruction (core dumped)

~/pytorch-cnn-finetune/examples$ python --model-name vgg16
Downloading: "" to /home/ubuntu/.torch/models/vgg16-397923af.pth
100%|553433881/553433881 [00:05<00:00, 98388620.83it/s]
Illegal instruction (core dumped)

Output size issue in alexnet and inception


File "/lib/python3.6/site-packages/cnn_finetune/", line 158, in calculate_classifier_in_features
output = original_model.features(input_var)

File "/lib/python3.6/site-packages/torch/nn/modules/", line 143, in forward

File "/lib/python3.6/site-packages/torch/nn/", line 334, in max_pool2d
ret = torch._C._nn.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

RuntimeError: Given input size: (256x1x1). Calculated output size: (256x0x0). Output size is too small at /opt/conda/conda-bld/pytorch_1518243271935/work/torch/lib/THNN/generic

File "/lib/python3.6/site-packages/torch/nn/", line 90, in conv2d
return f(input, weight, bias)
RuntimeError: Given input size: (320, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

line 90, in conv2d
return f(input, weight, bias)
RuntimeError: Given input size: (288, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

Given input size: (384, 1, 1). Calculated output size: (100, 0, 0). Output size is too small.

NB1. For AlexNet, changing the train_original_classifier to True (it was False in all the above) gave the following error:
File "/lib/python3.6/site-packages/cnn_finetune/contrib/", line 93, in check_args
'For the original classifier '

Exception: For the original classifier input_size value must be (224, 224)

NB2. pre_trained_model = True in all the above
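
The "Output size is too small" messages follow from the usual size formula for convolution and pooling layers, out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below traces a 32 x 32 input through an AlexNet-style feature stack; the layer parameters are listed as an assumption (they follow the common AlexNet layout) and the helper names are hypothetical. The map reaches 1 x 1 before the last pooling layer, which would then produce 0 x 0.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the AlexNet-style feature layers (assumed).
LAYERS = [(11, 4, 2), (3, 2, 0), (5, 1, 2), (3, 2, 0),
          (3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 2, 0)]


def trace(size):
    """Return the spatial size after every layer, starting with the input."""
    sizes = [size]
    for kernel, stride, padding in LAYERS:
        size = out_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(trace(32))    # [32, 7, 3, 3, 1, 1, 1, 1, 0]
print(trace(224))   # [224, 55, 27, 27, 13, 13, 13, 13, 6]
```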

Some warnings due to updating PyTorch to version 0.4.0

After updating PyTorch to version 0.4.0, a few warnings appeared in the Cifar10 example.
1- UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
test_loss += criterion(output, target).data[0]
corrected to: test_loss += criterion(output, target).item()
total_loss += loss.data[0]
corrected to: total_loss += loss.item()

2- UserWarning: volatile was removed and now has no effect. Use with torch.no_grad(): instead.
data, target = Variable(data, volatile=True), Variable(target)
Not sure how to correct this... with torch.no_grad():
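
A minimal sketch of how both warnings are usually addressed in PyTorch 0.4 and later: wrap the evaluation code in torch.no_grad() instead of passing volatile=True, and read scalar losses with .item(). The tiny linear model and random data are placeholders for the example's real model and test loader.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                 # tiny stand-in model
criterion = nn.CrossEntropyLoss()
data = torch.rand(8, 4)
target = torch.randint(0, 2, (8,))

model.eval()
test_loss = 0.0
with torch.no_grad():                   # replaces Variable(data, volatile=True)
    output = model(data)
    test_loss += criterion(output, target).item()   # .item() replaces .data[0]

print(output.requires_grad, type(test_loss))
```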

mean and std used in the transform of the Cifar10 example

In the Cifar10 example, the transform uses the mean and std of the original model(s) pre-trained on ImageNet. How does this transform affect learning from new data? Shouldn't the mean and std be obtained from the new data (Cifar10 in the example)?

transform = transforms.Compose([
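
Which statistics to use is a judgement call. When fine-tuning pretrained weights, reusing the statistics the network was trained with is common; computing them from the new dataset is a possible alternative. The plain-Python sketch below shows how per-channel mean and std could be computed; the function name and the tiny fake images are made up for illustration.

```python
import math


def channel_stats(images):
    """Per-channel mean and std over images given as [channel][pixel] lists."""
    num_channels = len(images[0])
    means, stds = [], []
    for c in range(num_channels):
        pixels = [p for image in images for p in image[c]]
        mean = sum(pixels) / len(pixels)
        var = sum((p - mean) ** 2 for p in pixels) / len(pixels)
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds


# Two tiny fake RGB "images" with two pixels per channel, values in [0, 1].
images = [
    [[0.0, 1.0], [0.2, 0.4], [0.5, 0.5]],
    [[0.0, 1.0], [0.6, 0.8], [0.5, 0.5]],
]
means, stds = channel_stats(images)
print(means, stds)
```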

How to replace last FC layer by 1x1 convolution?

Thanks for sharing a great API. I want to build an FCN (fully convolutional network for semantic segmentation) using your API. This can be done by replacing the last fully connected layer with a 1x1 convolutional layer. Taking resnet-18 as an example, how could I modify it to perform semantic segmentation? Your classification example is good, and an example for segmentation would be very helpful too.
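
The head conversion itself can be sketched independently of this library: a 1x1 convolution whose weights are the reshaped weights of a fully connected layer computes the same scores at every spatial location. The function name below is hypothetical, and this covers only the head; a full FCN also needs upsampling and a segmentation loss, which are not shown.

```python
import torch
import torch.nn as nn


def linear_to_conv1x1(fc):
    """Build a 1x1 convolution that computes the same function as `fc`."""
    conv = nn.Conv2d(fc.in_features, fc.out_features, kernel_size=1)
    with torch.no_grad():
        conv.weight.copy_(fc.weight.view(fc.out_features, fc.in_features, 1, 1))
        conv.bias.copy_(fc.bias)
    return conv


fc = nn.Linear(512, 10)                  # 512 matches resnet18's final feature width
conv = linear_to_conv1x1(fc)

features = torch.rand(2, 512, 7, 7)      # fake feature map before global pooling
score_map = conv(features)               # one score per class at every location
print(score_map.shape)
```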
