Evaluating State-of-the-Art Classification Models Against Bayes Optimality

Introduction

This repository contains the code to reproduce the experiments in the paper Evaluating State-of-the-Art Classification Models Against Bayes Optimality. It is based on Glow in Pytorch and LinConGauss.

Prerequisite

The Bayes error computation procedure is based on LinConGauss:

git clone https://github.com/alpiges/LinConGauss.git ~/LinConGauss
cd ~/LinConGauss
python setup.py install

Train a Conditional Glow Model

python train.py --dataset CIFAR10 --alpha 10 --output_dir cifar10-ckpts --sample_dir cifar10-samples

Below we explain the command line arguments one by one:

--dataset

Currently valid choices for the command line argument --dataset are:

  • MNIST
  • SVHN
  • CIFAR10
  • CIFAR100
  • FashionMNIST

--alpha

Following the original Glow paper, we add a classification loss that predicts the class labels from the second-to-last layer of the encoder, weighted by alpha (denoted λ in the paper). Although the classification loss enters the objective as a regularizer, the model is selected by the smallest NLL loss on the test set, not by the classification loss or the total loss.
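As a rough illustration of the distinction above, the sketch below combines an NLL term and a classification term with weight alpha, then shows that picking the epoch by NLL alone can differ from picking it by total loss. The function name total_objective and the per-epoch numbers are hypothetical, chosen only for the example. They are not taken from train.py.

```python
def total_objective(nll, classification_loss, alpha):
    """Training objective: NLL plus the classification loss weighted by alpha."""
    return nll + alpha * classification_loss


# Hypothetical per-epoch test metrics: (NLL, classification loss).
epochs = [(3.40, 0.90), (3.35, 1.20), (3.37, 0.50)]
alpha = 10.0

totals = [total_objective(nll, ce, alpha) for nll, ce in epochs]

# Model selection as described in this README: smallest NLL only.
best_by_nll = min(range(len(epochs)), key=lambda i: epochs[i][0])
# For comparison: what selection by the total loss would pick instead.
best_by_total = min(range(len(epochs)), key=lambda i: totals[i])

print("epoch selected by NLL:", best_by_nll)
print("epoch with smallest total loss:", best_by_total)
```

With these made-up numbers the two criteria pick different epochs, which is why it matters that the criterion is stated explicitly.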

--output_dir

The directory in which the trained model is saved. Two versions of the model are saved in this folder:

  • best.pth.tar: the model that achieves the smallest NLL loss on the test dataset. This NLL loss does NOT include the classification loss from the objective.
  • latest.pth.tar: the model produced by the last training epoch.
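The checkpointing behaviour described above can be sketched as follows. This is a minimal illustration, not the code in train.py: pickle stands in for torch.save so the example runs without PyTorch, and run_epochs, save_checkpoint, load_checkpoint and the NLL values are hypothetical.

```python
import os
import pickle
import tempfile


def save_checkpoint(state, path):
    # Stand-in for torch.save so the sketch runs without PyTorch.
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)


def run_epochs(test_nlls, output_dir):
    """Write latest.pth.tar every epoch and best.pth.tar when the test NLL improves."""
    os.makedirs(output_dir, exist_ok=True)
    best_nll = float("inf")
    for epoch, nll in enumerate(test_nlls):
        state = {"epoch": epoch, "test_nll": nll}
        save_checkpoint(state, os.path.join(output_dir, "latest.pth.tar"))
        if nll < best_nll:
            best_nll = nll
            save_checkpoint(state, os.path.join(output_dir, "best.pth.tar"))


# Hypothetical test NLL values, one per epoch.
output_dir = os.path.join(tempfile.mkdtemp(), "cifar10-ckpts")
run_epochs([3.40, 3.35, 3.37], output_dir)

best = load_checkpoint(os.path.join(output_dir, "best.pth.tar"))
latest = load_checkpoint(os.path.join(output_dir, "latest.pth.tar"))
print("best epoch:", best["epoch"], "latest epoch:", latest["epoch"])
```

In this run the two files hold different epochs, so downstream steps that should use the selected model need to load best.pth.tar rather than latest.pth.tar.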

--sample_dir

The directory to save a set of images sampled from the model after each epoch.

The training script is based on Glow in Pytorch. We use the default hyperparameters in the Glow model across all datasets:

'--num_channels', '-C', default=512
'--num_levels', '-L', default=3
'--num_steps', '-K', default=16
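For reference, here is one way the three defaults above could be declared and overridden with argparse. The flag names, short options and default values come from the list above. The type=int arguments and the help strings are assumptions and may not match the actual training script.

```python
import argparse

parser = argparse.ArgumentParser(description="Glow hyperparameters (sketch)")
# Names and defaults are from the README; type=int and the help text are assumptions.
parser.add_argument("--num_channels", "-C", default=512, type=int,
                    help="channels in hidden layers")
parser.add_argument("--num_levels", "-L", default=3, type=int,
                    help="number of levels in the Glow model")
parser.add_argument("--num_steps", "-K", default=16, type=int,
                    help="number of flow steps per level")

# No flags given: the defaults used across all datasets.
defaults = parser.parse_args([])
# Overriding one value with its short option.
overridden = parser.parse_args(["-K", "32"])

print(defaults.num_channels, defaults.num_levels, defaults.num_steps)
print(overridden.num_steps)
```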

Compute the Bayes Error

The following script will extract the dataset information from the trained Glow model and compute the (exact) Bayes error of the dataset generated by the Glow model.

python compute_bayes_error.py --model cifar10-ckpts/best.pth.tar
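To make the notion of an exact Bayes error concrete, the sketch below works through a deliberately simplified case: two equally likely classes whose class-conditional distributions are Gaussians with a shared isotropic covariance. That setting is an assumption made for illustration only. The script above relies on LinConGauss and its internals are not reproduced here. In this simplified case the Bayes error has the closed form Φ(-‖μ0 - μ1‖ / (2σ)), and a Monte Carlo run of the nearest-mean rule should agree with it. All function names are hypothetical.

```python
import math
import random


def bayes_error_two_gaussians(mu0, mu1, sigma=1.0):
    """Exact Bayes error for N(mu0, sigma^2 I) vs N(mu1, sigma^2 I), equal priors."""
    dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(mu0, mu1)))
    # Phi(-d / (2 sigma)), written with the complementary error function.
    return 0.5 * math.erfc(dist / (2.0 * sigma * math.sqrt(2.0)))


def monte_carlo_error(mu0, mu1, sigma=1.0, n=50_000, seed=0):
    """Empirical error of the Bayes-optimal (nearest-mean) rule on sampled data."""
    rng = random.Random(seed)
    means = (mu0, mu1)
    errors = 0
    for _ in range(n):
        label = rng.randrange(2)
        x = [rng.gauss(m, sigma) for m in means[label]]
        # With equal priors and shared covariance, the optimal rule picks the nearest mean.
        d0 = sum((xi - m) ** 2 for xi, m in zip(x, mu0))
        d1 = sum((xi - m) ** 2 for xi, m in zip(x, mu1))
        pred = 0 if d0 <= d1 else 1
        errors += pred != label
    return errors / n


mu0, mu1 = [0.0, 0.0], [2.0, 0.0]
exact = bayes_error_two_gaussians(mu0, mu1)   # Phi(-1), about 0.1587
estimate = monte_carlo_error(mu0, mu1)
print(f"exact: {exact:.4f}  monte carlo: {estimate:.4f}")
```

The closed form needs no sampling, which is the sense in which a Bayes error can be exact when the data-generating distribution is known.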

Generate datasets from the Conditional Glow Model

With a trained Glow model we may now generate as many samples as we want and construct a new dataset:

python generate_dataset.py --model_path cifar10-ckpts/best.pth.tar --batch_sz 100 --n_batches 700 --temperature 1.0 --save_fp saved_datasets/cifar10-gen.h
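The command above requests 100 samples per batch over 700 batches, that is 70,000 samples in total. The sketch below checks that arithmetic and illustrates what a sampling temperature usually means for a flow model: the standard deviation of the latent Gaussian is scaled by the temperature before the latents are decoded. The sample_latents helper, the latent mean and the latent dimension are hypothetical, and the decoding step through the Glow model is omitted. The behaviour of generate_dataset.py itself may differ.

```python
import random


def sample_latents(mean, std, temperature, batch_sz, rng):
    """Draw batch_sz latent vectors from N(mean, (temperature * std)^2 I)."""
    return [[rng.gauss(m, temperature * std) for m in mean]
            for _ in range(batch_sz)]


# Values taken from the command above.
batch_sz, n_batches, temperature = 100, 700, 1.0
total_samples = batch_sz * n_batches
print("total samples:", total_samples)

rng = random.Random(0)
class_mean = [0.5, -1.0, 2.0]  # hypothetical latent mean for one class

# One batch at the requested temperature.
batch = sample_latents(class_mean, 1.0, temperature, batch_sz, rng)

# Temperature 0 collapses every sample onto the mean.
collapsed = sample_latents(class_mean, 1.0, 0.0, 5, rng)
print(len(batch), collapsed[0])
```

Lower temperatures concentrate samples near the class mean, while a temperature of 1.0 samples from the latent distribution as it is.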

Technical Details and Citations:

You can find more details in the paper:

Evaluating State-of-the-Art Classification Models Against Bayes Optimality

If you're using this repo in your research or applications, please cite using this BibTeX:

@article{theisen2021evaluating,
      title={Evaluating State-of-the-Art Classification Models Against Bayes Optimality}, 
      author={Ryan Theisen and Huan Wang and Lav R. Varshney and Caiming Xiong and Richard Socher},
      year={2021},
      eprint={2106.03357},
      archivePrefix={arXiv},
      primaryClass={stat.ML}
}

References:

Evaluating State-of-the-Art Classification Models Against Bayes Optimality, by Ryan Theisen, Huan Wang, Lav R Varshney, Caiming Xiong, and Richard Socher. NeurIPS, 2021.

The Glow training code is based on Glow in Pytorch.

The Bayes error computation procedure is based on LinConGauss.
