Giter VIP home page Giter VIP logo

augmax's Introduction

AugMax: Adversarial Composition of Random Augmentations for Robust Training

License: MIT

Haotao Wang, Chaowei Xiao, Jean Kossaifi, Zhiding Yu, Anima Anandkumar, and Zhangyang Wang

In NeurIPS 2021

Overview

We propose AugMax, a data augmentation framework to unify the diversity and hardness. Being a stronger form of data augmentation, AugMax leads to a significantly augmented input distribution which makes model training more challenging. To solve this problem, we further design a disentangled normalization module, termed DuBIN (Dual-Batch-and-Instance Normalization) that disentangles the instance-wise feature heterogeneity arising from AugMax. AugMax-DuBIN leads to significantly improved out-of-distribution robustness, outperforming prior arts by 3.03%, 3.49%, 1.82% and 0.71% on CIFAR10-C, CIFAR100-C, Tiny ImageNet-C and ImageNet-C.

AugMax
AugMax achieves a unification between hard and diverse training samples.

results
AugMax achieves state-fo-the-art performance on CIFAR10-C, CIFAR100-C, Tiny ImageNet-C and ImageNet-C.

Training

Assume all datasets are stored in <data_root_path>. For example, CIFAR-10 is in <data_root_path>/cifar-10-batches-py/ and ImageNet training set is in <data_root_path>/imagenet/train.

AugMax-DuBIN training on <dataset> with <backbone> (The outputs will be saved in <save_root_path>):

python augmax_training_ddp.py --gpu 0 --srp <where_you_save_the_outputs> --drp <data_root_path> --ds <dataset> --md <backbone> --Lambda <lambda_value> --steps <inner_max_step_number>

For example:

AugMax-DuBIN on CIFAR10 with ResNeXt29 (By default we use Lambda=10 on CIFAR10/100 and Tiny ImageNet.):

python augmax_training_ddp.py --gpu 0 --drp /ssd1/haotao/datasets --ds cifar10 --md ResNeXt29 --Lambda 10 --steps 10

AugMax-DuBIN on CIFAR100 with ResNet18 (We use Lambda=1 instead of Lambda=10 in this particular experiment, as noted in the paper.):

python augmax_training_ddp.py --gpu 0 --drp /ssd1/haotao/datasets --ds cifar100 --md ResNet18 --Lambda 1 --steps 10

AugMax-DuBIN on ImageNet with ResNet18 (By default we use Lambda=12 on ImageNet. On ImageNet, weight decay wd=1e-4 instead of the default value in the code, which is for CIFAR-level datasets.):

NCCL_P2P_DISABLE=1 python augmax_training_ddp.py --gpu 0 --drp /ssd1/haotao/datasets --ds IN --md ResNet18 --Lambda 12 -e 90 --wd 1e-4 --decay multisteps --de 30 60 --ddp --dist_url tcp://localhost:23456

AugMax-DuBIN + DeepAug on ImageNet with ResNet18:

NCCL_P2P_DISABLE=1 python augmax_training_ddp.py --gpu 0 --drp /ssd1/haotao/datasets --ds IN --md ResNet18 --deepaug --Lambda 12 -e 30 --wd 1e-4 --decay multisteps --de 10 20 --ddp --dist_url tcp://localhost:23456

Pretrained models

The pretrained models are available on Google Drive.

Testing

To test the model trained on <dataset> with <backbone> and saved to <ckpt_path>:

python test.py --gpu 0 --ds <dataset> --drp <data_root_path> --md <backbone> --mode all --ckpt_path <ckpt_path>

For example:

python test.py --gpu 0 --ds cifar10 --drp /ssd1/haotao/datasets --md ResNet18_DuBIN --mode all --ckpt_path augmax_training/cifar10/ResNet18_DuBIN/fat-1-untargeted-5-0.1_Lambda10.0_e200-b256_sgd-lr0.1-m0.9-wd0.0005_cos

Citation

@inproceedings{wang2021augmax,
  title={AugMax: Adversarial Composition of Random Augmentations for Robust Training},
  author={Wang, Haotao and Xiao, Chaowei and Kossaifi, Jean and Yu, Zhiding and Anandkumar, Anima and Wang, Zhangyang},
  booktitle={NeurIPS},
  year={2021}
}

augmax's People

Contributors

htwang14 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

augmax's Issues

Missing package - resnet_DuBN

Hi, first of all congratulations on your acceptance to NeurIPS 2021, I was checking the code and I saw that you are referencing the package models.cifar10.resnet_DuBN that is not in the repository.

Could you please add it to make the code reproducible?

Use custom dataset

Hello,
Thank you for this code it is really helpful. I have a question is it possible to use the same code with custom dataset?
If so please tell me what to change.
Best regards

missing L(f(x ∗ ); θ), y)

in the paper the total loss(equation 5) - 1/2[L(f(x∗); θ), y) + L(f(x); θ), y)] + λLc(x, x∗)
but in the implementation you only use L(f(x); θ), y) + λLc(x, x∗). which is the correct equation?

I train ResNet18 in Cifar100

I train ResNet18 in Cifar100 with augmax and without augmax but i find that i can got 77.5% accuracy without augmax ,while i can only got 76.3% with augmax, can u give some explainations?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.