
realistic-ssl-evaluation's Introduction

realistic-ssl-evaluation

This repository contains the code for Realistic Evaluation of Deep Semi-Supervised Learning Algorithms, by Avital Oliver*, Augustus Odena*, Colin Raffel*, Ekin D. Cubuk, and Ian J. Goodfellow, arXiv preprint arXiv:1804.09170.

If you use the code in this repository for a published research project, please cite this paper.

The code is designed to run on Python 3 using the dependencies listed in requirements.txt. You can install the dependencies by running pip3 install -r requirements.txt.

The latest version of this repository can be found here.

Prepare datasets

For SVHN and CIFAR-10, we provide scripts to automatically download and preprocess the data. We also provide a script to create "label maps", which specify which entries of the dataset should be treated as labeled and unlabeled. Both of these scripts use an explicitly chosen random seed, so the same dataset order and label maps will be created each time. The random seeds can be overridden, for example to test robustness to different labeled splits. Run those scripts as follows:

python3 build_tfrecords.py --dataset_name=cifar10
python3 build_label_map.py --dataset_name=cifar10
python3 build_tfrecords.py --dataset_name=svhn
python3 build_label_map.py --dataset_name=svhn

For ImageNet 32x32 (only used in the fine-tuning experiment), you'll first need to download the 32x32 version of the ImageNet dataset by following the instructions here. Unzip the resulting files and put them in a directory called 'data/imagenet_32'. You'll then need to convert those files (which are pickle files) into .npy files. You can do this by executing:

mkdir data/imagenet_32
unzip Imagenet32_train.zip -d data/imagenet_32
unzip Imagenet32_val.zip -d data/imagenet_32
python3 convert_imagenet.py

Then you can build the TFRecord files like so:

python3 build_tfrecords.py --dataset_name=imagenet_32

ImageNet32x32 is the only dataset which must be downloaded manually, due to licensing issues.

Running experiments

All of the experiments in our paper are accompanied by a .yml file in runs/. These .yml files are intended to be used with tmuxp, which is a session manager for tmux. They provide a simple way to create a tmux session with all of the relevant tasks running (model training and evaluation). The .yml files are named according to their corresponding figure/table/section in the paper. For example, if you want to run an experiment evaluating VAT with 500 labels as shown in Figure 3, you could run

tmuxp load runs/figure-3-svhn-500-vat.yml

Of course, you can also run the code without using tmuxp. Each .yml file specifies the commands needed for running each experiment. For example, the file listed above, runs/figure-3-svhn-500-vat.yml, runs

CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='svhn' --secondary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/figure-3-svhn-500-vat --n_labeled=500 --consistency_model=vat --hparam_string=""  2>&1 | tee /mnt/experiment-logs/figure-3-svhn-500-vat_train.log
CUDA_VISIBLE_DEVICES=1 python3 evaluate_model.py --split=test --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/figure-3-svhn-500-vat --consistency_model=vat --hparam_string=""  2>&1 | tee /mnt/experiment-logs/figure-3-svhn-500-vat_eval_test.log
CUDA_VISIBLE_DEVICES=2 python3 evaluate_model.py --split=valid --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/figure-3-svhn-500-vat --consistency_model=vat --hparam_string=""  2>&1 | tee /mnt/experiment-logs/figure-3-svhn-500-vat_eval_valid.log
CUDA_VISIBLE_DEVICES=3 python3 evaluate_model.py --split=train --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/figure-3-svhn-500-vat --consistency_model=vat --hparam_string=""  2>&1 | tee /mnt/experiment-logs/figure-3-svhn-500-vat_eval_train.log

Note that these commands write their results to /mnt/experiment-logs. You will either need to create this directory (for example with mkdir -p /mnt/experiment-logs) or modify the commands to write to a different directory. Further, the .yml files assume that this source tree lives in /root/realistic-ssl-evaluation.

A note on reproducibility

While the focus of our paper is reproducibility, exact comparison to the results in our paper will ultimately be confounded by subtle differences such as the version of TensorFlow used, random seeds, etc. In other words, simply copying the numbers stated in our paper may not provide a reliable basis for comparison. As a result, if you'd like to use our implementation of baseline methods as a point of comparison for e.g. a new semi-supervised learning technique, we recommend re-running our experiments from scratch in the same environment as your new technique.

Simulating small validation sets

The following command runs evaluation on a set of checkpoints, with multiple resamples of small validation sets (as in Figure 5 of the paper):

python3 evaluate_checkpoints.py --primary_dataset_name='cifar10' --checkpoints='/mnt/experiment-logs/section-4-3-cifar-fine-tuning/default/model.ckpt-1000000,/mnt/.../model.ckpt-...,...'

Results are printed to stdout for each evaluation run; at the end, a string representation of the full list of validation accuracies for each resampled validation set and each checkpoint is printed:

{'/mnt/experiment-logs/table-1-svhn-1000-pi-model-run-5/default/model.ckpt-500001': [0.86, 0.93, 0.92, 0.91, 0.9, 0.94, 0.91, 0.88, 0.88, 0.89]}
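If you want to summarize this output, a minimal post-processing sketch (not part of the repository; it simply parses a dictionary in the format shown above) could look like this:

import numpy as np

# Example output copied from above; keys are checkpoint paths, values are the
# validation accuracies for each resampled validation set.
results = {
    '/mnt/experiment-logs/table-1-svhn-1000-pi-model-run-5/default/model.ckpt-500001':
        [0.86, 0.93, 0.92, 0.91, 0.9, 0.94, 0.91, 0.88, 0.88, 0.89],
}

for checkpoint, accuracies in results.items():
    accs = np.array(accuracies)
    # Mean and standard deviation across the resampled validation sets.
    print(checkpoint, 'mean=%.3f' % accs.mean(), 'std=%.3f' % accs.std())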

Disclaimer

This is not an official Google product.

realistic-ssl-evaluation's People

Contributors

alexandra-zaharia, avital, craffel, doctorkey, joschu


realistic-ssl-evaluation's Issues

where is data augmentation performed?

Hi, thanks for your inspiring work "Realistic Evaluation of Deep Semi-Supervised Learning Algorithms". In Appendix B of the paper you mention that data augmentation is performed. When I look at the code in the repo, I can't find where it is performed. I'm sorry if I overlooked anything, but I would really appreciate it if you could point it out.

Tabular results for Fig.4

Hi, thanks for the great work. Do you have tabular results for Fig. 4, i.e., the exact mean error and variance for the different algorithms under different numbers of labels? Thanks.

Randomness in build_tfrecords.py + label maps

Hi @avital , first of all thanks for releasing the code, I find it very elegant, and tmuxp sessions for every experiment are really cool!

The README says that, in order to get exactly the same labeled/unlabeled splits as in the paper, one has to use the label maps from the repo. But the issue is that, in order to run the experiments, one first has to build the tfrecords, which uses a random shuffle. This means we have to create new label maps for these newly created tfrecords.
Am I missing something, or is there really an issue here? How can we generate exactly the same tfrecords as in your experiments?

VAT GPU vRAM usage

I am trying to run the experiment associated with runs/figure-2-cifar10-4000-vat-ol0.yml, but the GPU (a 2080 Ti) appears to run out of memory (11 GB). This doesn't appear to be an issue of batch size, as the CIFAR-10 images are fairly small, and I don't think the model would take up that much space either. I've tried TensorFlow versions 1.14 and 1.15 with no luck. Any suggestions?

Why is there no dropout in the model?

Hi, thanks for your work,
but I'm wondering why there is no dropout in your model. Dropout largely improves the performance of consistency losses by acting as a data augmentation policy. Is this why the performance of the Pi model is so different from that of VAT?
Thank you

consistency_model=mean_teacher when doing fully supervised learning

Hi,

I saw that you're using consistency_model=mean_teacher even in cases where you are doing fully supervised learning. In fact, if we look at the .yml files:

table-1-cifar10-4000-fullysup.yml and table-1-cifar10-4000-mean-teacher.yml

we see that the only difference between them is:

--hparam_string="max_cons_multiplier=0" for fully_supervised and
--hparam_string="" for the mean teacher.

Am I right in this assumption, and could you please explain how the model knows whether to do fully supervised training vs. mean teacher?
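(For context, here is a hedged sketch of how a consistency-weight multiplier of zero can turn a consistency model into a no-op. The names max_cons_multiplier and warmup_steps are illustrative and may not match the repo's exact hparams.)

def consistency_weight(step, warmup_steps, max_cons_multiplier):
    # Linear ramp-up of the consistency weight to its maximum value.
    ramp = min(step / float(warmup_steps), 1.0)
    return max_cons_multiplier * ramp

def total_loss(supervised_loss, consistency_loss, step,
               warmup_steps=200000, max_cons_multiplier=0.0):
    # With max_cons_multiplier=0 the consistency term contributes nothing,
    # so training reduces to the fully supervised objective; other hparams
    # tied to the consistency model (e.g. learning rate) still apply.
    weight = consistency_weight(step, warmup_steps, max_cons_multiplier)
    return supervised_loss + weight * consistency_loss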

Learning rate for fully supervised baseline is wrong

The flags for the fully-supervised baselines choose mean_teacher for the consistency model, intended as a no-op because the consistency weight is set to zero. But the hparams for Mean Teacher dictate a different learning rate than the best one we found for fully-supervised training (and also than the one we state we used in the paper).

Code Release

Subscribe to this issue to get notified when the code is released and pushed to this repository.

Results in the paper for the low-data regime

Hi guys,

I've been trying to replicate your results. When training the net with 500 labeled data points, I saw some weird behavior.

[Screenshot from 2018-11-19: TensorBoard accuracy curve]

As you can see from TensorBoard, the accuracy of the net increases quickly, peaks at around 70k iterations, and then falls to 0.1, which is chance level. The results here are for the Pi model, but I saw the same behavior with Mean Teacher.

My questions are:

  1. Is this the expected behavior (did you also see this phenomenon)?
  2. If yes, then how do you get the numbers for the paper? I assume that you look for the highest peak on the validation set and then evaluate that checkpoint on the test set. Am I right in this assumption? (I also saw that one of the saved nets is the one that gives the best result, hence my assumption.)
  3. Why do you run so many evaluations on the test set, instead of just one evaluation with the net that gives the best results on the validation set?

Cheers!

VAT implementation is wrong

Thank you for your code.

Following the original VAT paper, consistency_func in hparams.py should be reverse_kl for VAT, but it is set to forward_kl in your code.

The adversarial noise r in VAT is obtained by maximizing D_KL(p(y|x) || p(y|x+r)); however, the consistency loss D_KL(p(y|x+r) || p(y|x)) is used when consistency_func=forward_kl. I think this matters because of the asymmetry of the KL divergence.
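(A minimal numpy sketch of the asymmetry being discussed; the distributions and the forward_kl/reverse_kl labels below are illustrative, following the description in this issue rather than the repo's actual code.)

import numpy as np

def kl(p, q, eps=1e-8):
    # D_KL(p || q) for discrete distributions.
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

p_clean = np.array([0.7, 0.2, 0.1])      # stands in for p(y|x)
p_perturbed = np.array([0.4, 0.4, 0.2])  # stands in for p(y|x+r)

# forward_kl-style loss: D_KL(p(y|x+r) || p(y|x))
# reverse_kl-style loss (as in the original VAT paper): D_KL(p(y|x) || p(y|x+r))
print(kl(p_perturbed, p_clean), kl(p_clean, p_perturbed))  # the two values differ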

It might use unlabeled data to train "fullysup"

Thank you for your code!
When I run the code using table-1-cifar10-4000-fullysup.yml, I find that it might use unlabeled data to train "fullysup". A batch contains both labeled and unlabeled data, and since the Wide ResNet has batch-norm layers, the unlabeled data is used to compute the batch-norm statistics.
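(An illustrative numpy sketch of the effect being described: batch-norm statistics are computed over the whole batch, so unlabeled examples influence the normalization even if their loss weight is zero. Shapes and names are hypothetical.)

import numpy as np

labeled = np.random.randn(32, 8)    # features from labeled examples
unlabeled = np.random.randn(96, 8)  # features from unlabeled examples
batch = np.concatenate([labeled, unlabeled], axis=0)

# Batch-norm moments are computed over the full batch, including unlabeled rows.
mean, var = batch.mean(axis=0), batch.var(axis=0)
normalized = (batch - mean) / np.sqrt(var + 1e-5)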

Why use WRN-28-2?

Hi, I wonder why you use WRN-28-2 instead of a standard ResNet or the widely used 10-layer ConvNet from Pi-Model, Mean Teacher, and VAT.
Thanks!

Is there a way to download ImageNet weights for WRN-28-2?

I see that the repository provides a method to download ImageNet in 32x32 format; however, would you have access to the weights of the WRN-28-2 model that was trained on ImageNet and used in the transfer learning section of the paper?

I would like to use the ImageNet weights in another application; a link or another way to download them would save a lot of time and cost compared to training the network on ImageNet from scratch.

Thanks!

How is the error rate and uncertainty calculated?

Hello,

From reading the paper, it seems that the reported error rates are the test error at the checkpoint with the lowest validation error; is that correct?

Also, do you run experiments with multiple random seeds, or is the uncertainty calculated using several different validation sets? How exactly was that uncertainty calculated? Forgive me if I missed this detail in your paper.

Thank you!

Script for downloading and preprocessing the CIFAR and SVHN datasets

Hi, thanks a lot for your work! In the README you say there are scripts for downloading and preprocessing the CIFAR and SVHN datasets, but I couldn't find them in the repository. Could you please describe how you preprocess the data for training?

Thank you!

Why this implementation of pseudo-labeling?

Hi
I'm wondering, in your implementation of pseudo-labeling, why you use a non-zero loss for unlabeled samples whose maximum predicted probability is below the threshold.

Why not use the original pseudo-label implementation as the baseline?

@avital In the original Pseudo-Label paper, it says "we just pick up the class which has maximum predicted probability for each unlabeled sample", and the loss is defined as the weighted average of the labeled loss and the unlabeled loss.

Could you please explain why this code doesn't use what the original paper describes as the baseline, but instead uses another implementation with a "teacher-student-like" loss that only picks a label when the confidence is greater than a threshold?
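(For reference, a hedged sketch of the thresholded pseudo-labeling objective that these two questions contrast with the original formulation; the function name and threshold value are illustrative, not the repo's actual code.)

import numpy as np

def thresholded_pseudo_label_loss(probs, threshold=0.95):
    # Cross-entropy against the argmax class, masked by a confidence threshold:
    # unlabeled examples whose max probability is below the threshold contribute zero.
    pseudo = probs.argmax(axis=1)
    confident = probs.max(axis=1) >= threshold
    ce = -np.log(probs[np.arange(len(probs)), pseudo] + 1e-8)
    return float((ce * confident).mean())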
