Evaluating State-of-the-Art Classification Models Against Bayes Optimality
Introduction
This repository contains the code to reproduce the experiments in the paper Evaluating State-of-the-Art Classification Models Against Bayes Optimality. It builds on Glow in Pytorch and LinConGauss.
Prerequisite
The Bayes error computation procedure is based on LinConGauss:
git clone https://github.com/alpiges/LinConGauss.git ~/LinConGauss
cd ~/LinConGauss
python setup.py install
Train a Conditional Glow Model
python train.py --dataset CIFAR10 --alpha 10 --output_dir cifar10-ckpts --sample_dir cifar10-samples
Below we explain the command line arguments one by one:
--dataset
Currently valid choices for the command line argument --dataset are:
- MNIST
- SVHN
- CIFAR10
- CIFAR100
- FashionMNIST
--alpha
Following the original Glow paper, we add a classification loss that predicts the class label from the second-to-last layer of the encoder, weighted by alpha (denoted λ in the paper). Although the classification loss acts as a regularizer in the training objective, the model is selected by the smallest NLL loss on the test set, not by the classification loss or the total loss.
--output_dir
The directory to save the trained model. Two checkpoints are saved in this folder:
best.pth.tar: the model that achieves the smallest NLL loss on the test dataset. This NLL loss does NOT include the classification loss from the objective.
latest.pth.tar: the model produced by the last training epoch.
--sample_dir
The directory to save a set of images sampled from the model after each epoch.
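The relationship between --alpha and the two saved checkpoints can be sketched as follows. This is a minimal illustration in plain Python, not the repo's train.py: the function names total_loss and pick_checkpoints and the per-epoch numbers are hypothetical, and the objective is assumed to be the NLL plus alpha times the classification loss.

```python
def total_loss(nll, cls_loss, alpha):
    """Training objective (assumed form): NLL plus the weighted classification loss."""
    return nll + alpha * cls_loss


def pick_checkpoints(test_nll_per_epoch):
    """Return (best_epoch, latest_epoch) as 0-based indices.

    best   -> epoch with the smallest test NLL (classification loss excluded)
    latest -> the final epoch, regardless of its loss
    """
    if not test_nll_per_epoch:
        raise ValueError("need at least one epoch")
    best_epoch = min(range(len(test_nll_per_epoch)),
                     key=test_nll_per_epoch.__getitem__)
    latest_epoch = len(test_nll_per_epoch) - 1
    return best_epoch, latest_epoch


# Hypothetical per-epoch test metrics, for illustration only.
test_nll = [3.9, 3.5, 3.6]
test_cls = [1.2, 0.9, 0.4]
alpha = 10

totals = [total_loss(n, c, alpha) for n, c in zip(test_nll, test_cls)]
best, latest = pick_checkpoints(test_nll)
```

In this made-up run the total loss is lowest at the last epoch, yet best.pth.tar would correspond to the second epoch, because selection looks at the NLL alone.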
The training script is based on Glow in Pytorch. We use the default hyperparameters in the Glow model across all datasets:
'--num_channels', '-C', default=512
'--num_levels', '-L', default=3
'--num_steps', '-K', default=16
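Flags of this shape are typically declared with argparse. The sketch below is an assumption about how the three options might be wired up, not a copy of train.py: the type=int settings, the help strings, and the build_parser helper are illustrative. Only the flag names, short aliases, and defaults come from the list above.

```python
import argparse


def build_parser():
    """Hypothetical parser covering only the three Glow hyperparameters listed above."""
    parser = argparse.ArgumentParser(description="Glow hyperparameters (sketch)")
    # Help strings describe the usual Glow meaning of each option (assumed here).
    parser.add_argument('--num_channels', '-C', type=int, default=512,
                        help='number of channels in hidden layers')
    parser.add_argument('--num_levels', '-L', type=int, default=3,
                        help='number of levels in the multi-scale architecture')
    parser.add_argument('--num_steps', '-K', type=int, default=16,
                        help='number of flow steps per level')
    return parser


# An empty argument list yields the defaults used across all datasets.
args = build_parser().parse_args([])

# Short aliases override a single value while the rest keep their defaults.
override = build_parser().parse_args(['-K', '32'])
```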
Compute the Bayes Error
The following script extracts the dataset information from the trained Glow model and computes the (exact) Bayes error of the dataset generated by that model.
python compute_bayes_error.py --model cifar10-ckpts/best.pth.tar
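To make the term "Bayes error" concrete, consider the simplest case with a closed form: two equally likely classes whose features follow Gaussians with the same identity covariance and different means. The Bayes error is then Φ(-d/2), where d is the distance between the two means and Φ is the standard normal CDF. The sketch below evaluates that formula with the standard library. It is a toy illustration under those stated assumptions, not what compute_bayes_error.py does; the script works with the distribution defined by the trained Glow model and relies on LinConGauss. The function name binary_gaussian_bayes_error is hypothetical.

```python
import math


def binary_gaussian_bayes_error(mu0, mu1):
    """Bayes error for two equal-prior classes N(mu0, I) and N(mu1, I).

    The optimal rule assigns a point to the nearer mean, so the error is
    Phi(-d / 2) with d = ||mu0 - mu1||.
    """
    d = math.dist(mu0, mu1)
    # Phi(x) = 0.5 * erfc(-x / sqrt(2)); here x = -d / 2.
    return 0.5 * math.erfc(d / (2.0 * math.sqrt(2.0)))


same = binary_gaussian_bayes_error([0.0, 0.0], [0.0, 0.0])   # identical classes
close = binary_gaussian_bayes_error([0.0, 0.0], [2.0, 0.0])  # means 2 apart
far = binary_gaussian_bayes_error([0.0, 0.0], [6.0, 8.0])    # means 10 apart
```

Identical classes give an error of 0.5 (chance level), and the error shrinks toward 0 as the means move apart. No classifier can do better than this value on data drawn from the assumed distribution.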
Generate datasets from the Conditional Glow Model
With a trained Glow model we may now generate as many samples as we want and construct a new dataset:
python generate_dataset.py --model_path cifar10-ckpts/best.pth.tar --batch_sz 100 --n_batches 700 --temperature 1.0 --save_fp saved_datasets/cifar10-gen.h
Technical Details and Citations:
You can find more details in the paper:
Evaluating State-of-the-Art Classification Models Against Bayes Optimality
If you're using this repo in your research or applications, please cite using this BibTeX:
@article{theisen2021evaluating,
title={Evaluating State-of-the-Art Classification Models Against Bayes Optimality},
author={Ryan Theisen and Huan Wang and Lav R. Varshney and Caiming Xiong and Richard Socher},
year={2021},
eprint={2106.03357},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
References:
Evaluating State-of-the-Art Classification Models Against Bayes Optimality, by Ryan Theisen, Huan Wang, Lav R Varshney, Caiming Xiong, and Richard Socher. NeurIPS, 2021.
The Glow training code is based on Glow in Pytorch.
The Bayes error computation procedure is based on LinConGauss.