
Comments (5)

seuretm commented on August 17, 2024

Hello,

TL;DR: the longer you train, the higher the reported result, but it is a validation result, not a test result.

This code has a big issue, which I already raised: it uses the test set as the validation set. What you get is the highest validation accuracy, not a test accuracy. As there is always a bit of randomness in the validation results, the longer you train, the higher the chance that this randomness gives you a higher result.

Yet this does not reflect what you would get on a never-seen test set, and any scientific publication using this methodology should be immediately retracted.

The right way of training the network would be to use some samples from the training set as a validation set, and to use the test set only once, at the very end of training, instead of using the test set in the optimization process.
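
A minimal sketch of that protocol in PyTorch, assuming torchvision's CIFAR-10 loaders; the 45k/5k split, batch size, and seed are illustrative choices, not values from this thread:

```python
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

transform = transforms.ToTensor()
full_train = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10('./data', train=False, download=True, transform=transform)

# Hold out part of the training set for validation instead of peeking at the test set.
train_set, val_set = random_split(
    full_train, [45000, 5000], generator=torch.Generator().manual_seed(0)
)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)
test_loader = DataLoader(test_set, batch_size=128)

# Early stopping / checkpoint selection looks at val_loader only;
# test_loader is evaluated exactly once, after training is completely finished.
```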


zaccharieramzi commented on August 17, 2024

@skrbnv the differences pointed out in this comment are just the differences between the CIFAR-10 and the ImageNet versions.
In the original ResNet paper, there were indeed 2 different versions: one for CIFAR-10 (which has smaller-resolution images) and one for ImageNet. The CIFAR-10 one doesn't use striding in the first conv and doesn't have the initial MaxPooling.

If you use a ResNet with striding in the first conv and the initial MaxPooling, you will not obtain accuracies above 91% (I tried). The main reason is that you lose too much information at the very beginning of the network (see the stem comparison sketched after the quote below).

The plain/residual architectures follow the form in Fig. 3 (middle/right). The network inputs are 32×32 images, with the per-pixel mean subtracted. The first layer is 3×3 convolutions. Then we use a stack of 6n layers with 3×3 convolutions on the feature maps of sizes {32, 16, 8} respectively, with 2n layers for each feature map size. The numbers of filters are {16, 32, 64} respectively. The subsampling is performed by convolutions with a stride of 2. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. There are totally 6n+2 stacked weighted layers.
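
For concreteness, here is a minimal sketch of the two stems, assuming torchvision-style layers and the 64-filter width of pytorch-cifar's ResNet-18 (note the paper's CIFAR ResNets quoted above start at 16 filters instead):

```python
import torch.nn as nn

# ImageNet-style stem: a 7x7 stride-2 conv plus max pooling, which shrinks a
# 32x32 CIFAR image down to 8x8 before the residual stages even start.
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# CIFAR-style stem: a single 3x3 stride-1 conv, so the residual stages
# still see the full 32x32 resolution.
cifar_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)
```

On a 32×32 input, the ImageNet stem alone reduces the feature map to 8×8, which is the information loss referred to above.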


AminJun commented on August 17, 2024

@seuretm All points you mentioned are true. However, even if you don't use the test data during training and only use it once (after the training is done), you get about 95% accuracy, whereas the repo reports 93%. So my experience is that the point @arashash is making here is valid.


zaccharieramzi commented on August 17, 2024

In a different setup, using a different implementation of the same model (a ResNet-18 in CIFAR configuration) and different optimization code, I also find that I can reach 95.4-95.5% without using the test set during evaluation (I don't have a script to share since I am doing this as part of a bigger benchmark, but the gist of it is here). I therefore concur that this is not due to early stopping on the test set or to retaining the best test accuracy.

However, it is also true that this specific setup (in particular the weight decay and learning rate values) might have been tuned on the test set, which would be problematic. But I think this is another topic.

As to why there is an improvement compared to the numbers reported in the repo, my guess is that as newer models were added, new training strategies were implemented, and 93% is what you would get with the old strategy.
To back that guess up a little, here is the version of main.py from when the ResNet (and the 93% result) was added.
You can see that its learning rate and weight decay are much lower.
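
For reference, a minimal sketch of that newer recipe; the exact values are my reading of the current main.py, so treat them as an assumption rather than as the repo's definitive settings:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 8)  # stand-in for the ResNet-18; any nn.Module works here

# Newer recipe: SGD with momentum, a fairly large weight decay, and a cosine
# learning-rate schedule spanning the whole run.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# The older, 93%-era settings would lower both lr and weight_decay here
# (illustrative values only, e.g. lr=0.01, weight_decay=1e-4).
```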


skrbnv commented on August 17, 2024

This is not the original ResNet-18 network. That's the reason why the accuracy is so high: #136 (comment)

