
Helper-based Adversarial Training

This repository contains the code for the ICLR 2022 paper "Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off" by Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli.

A short version of the paper has been accepted for Oral presentation at the ICML 2021 Workshop on A Blessing in Disguise: The Prospects and Perils of Adversarial Machine Learning and can be found at this link.

Setup

Requirements

Our code has been implemented and tested with Python 3.8.5 and PyTorch 1.8.0. To install the required packages:

$ pip install -r requirements.txt

Repository Structure

.
├── core                   # Source code for the experiments
│   ├── attacks            # Adversarial attacks
│   ├── data               # Data setup and loading
│   ├── models             # Model architectures
│   ├── utils              # Helpers, training and testing functions
│   └── metrics.py         # Evaluation metrics
├── train.py               # Training script
├── train-wa.py            # Training with model weight averaging
├── eval-aa.py             # AutoAttack evaluation
├── eval-adv.py            # PGD+ and CW evaluation
└── eval-rb.py             # RobustBench evaluation

Usage

Training

Run train.py for standard, adversarial, TRADES, MART and HAT training. Example commands for HAT training are provided below:

First, train a ResNet-18 model on CIFAR-10 with standard training:

$ python train.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc std-cifar10 \
    --data cifar10 \
    --model resnet18 \
    --num-std-epochs 50

Then, run the following command to perform helper-based adversarial training (HAT) on CIFAR-10:

$ python train.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc hat-cifar10 \
    --data cifar10 \
    --model resnet18 \
    --num-adv-epochs 50 \
    --helper-model std-cifar10 \
    --beta 2.5 \
    --gamma 0.5
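Conceptually, --beta weighs the robustness term and --gamma weighs the loss on helper examples. The following is a minimal sketch of the HAT objective as we read the paper, not the repository's exact code; hat_loss and pgd_attack are placeholder names, and images are assumed to lie in [0, 1]:

import torch
import torch.nn.functional as F

def hat_loss(model, std_model, x, y, pgd_attack, beta=2.5, gamma=0.5):
    # pgd_attack is assumed to return an adversarial example x_adv
    # inside the eps-ball around x.
    x_adv = pgd_attack(model, x, y)
    delta = x_adv - x
    x_helper = (x + 2.0 * delta).clamp(0, 1)       # extrapolate past the margin
    with torch.no_grad():
        y_helper = std_model(x_adv).argmax(dim=1)  # standard net's prediction
    clean = F.cross_entropy(model(x), y)
    robust = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                      F.softmax(model(x), dim=1), reduction='batchmean')
    helper = F.cross_entropy(model(x_helper), y_helper)
    return clean + beta * robust + gamma * helper

Here std_model is the standard (non-robust) network trained in the previous step, selected via --helper-model.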

Robustness Evaluation

The trained models can be evaluated by running eval-aa.py, which uses AutoAttack to measure robust accuracy. For example:

$ python eval-aa.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc hat-cifar10
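For a quick standalone check, the autoattack package can also be called directly, independently of eval-aa.py. A minimal sketch, assuming model and the test tensors x_test, y_test are already loaded and the threat model is L-infinity with eps = 8/255:

from autoattack import AutoAttack

model.eval()
adversary = AutoAttack(model, norm='Linf', eps=8/255, version='standard')
x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=256)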

For evaluation with PGD+ and CW attacks, use:

$ python eval-adv.py --wb --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc hat-cifar10
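For reference, a minimal L-infinity PGD sketch is given below; the PGD+ used by eval-adv.py is a stronger variant of this basic scheme (e.g. more iterations and restarts), and the step size here is illustrative:

import torch
import torch.nn.functional as F

def pgd_linf(model, x, y, eps=8/255, alpha=2/255, steps=40):
    # Basic L-infinity PGD with a random start.
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1).detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + alpha * grad.sign()  # gradient-sign step
        x_adv = x + (x_adv - x).clamp(-eps, eps)      # project to eps-ball
        x_adv = x_adv.clamp(0, 1).detach()            # keep a valid image
    return x_adv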

Incorporating Improvements from Gowal et al., 2020 & Rebuffi et al., 2021

HAT can be combined with improvements from the papers "Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples" (Gowal et al., 2020) and "Fixing Data Augmentation to Improve Adversarial Robustness" (Rebuffi et al., 2021) to obtain state-of-the-art performance on multiple datasets.

Training a Standard Network for Computing Helper Labels

Train a model with standard training as described above, or download the appropriate pretrained model from this link and place the contents of the corresponding zip file in the directory <log_dir>.

HAT Training

Run train-wa.py for training a robust network via HAT. For example, to train a WideResNet-28-10 model via HAT on CIFAR-10 with the additional pseudolabeled data provided by Carmon et al., 2019 or the generated datasets provided by Rebuffi et al., 2021, use the following command:

$ python train-wa.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc <name_of_the_experiment> \
    --data cifar10s \
    --batch-size 1024 \
    --batch-size-validation 512 \
    --model wrn-28-10-swish \
    --num-adv-epochs 400 \
    --lr 0.4 --tau 0.995 \
    --label-smoothing 0.1 \
    --unsup-fraction 0.7 \
    --aux-data-filename <path_to_additional_data> \
    --helper-model <helper_model_log_dir_name> \
    --beta 3.5 \
    --gamma 0.5
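Here --tau controls the model weight averaging. A minimal sketch of exponential moving averaging of parameters, assuming tau plays the standard EMA-decay role (the repository's implementation may differ in details):

import copy
import torch

@torch.no_grad()
def ema_update(wa_model, model, tau=0.995):
    # Weight averaging: wa <- tau * wa + (1 - tau) * online weights.
    for wa_p, p in zip(wa_model.parameters(), model.parameters()):
        wa_p.mul_(tau).add_(p, alpha=1.0 - tau)

# wa_model = copy.deepcopy(model)   # averaged copy of the online model
# ema_update(wa_model, model)       # call after every optimization step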

Results

Below, we provide the results obtained with HAT. In the settings with additional data, we follow the experimental setup of Gowal et al., 2020 and Rebuffi et al., 2021; when not using additional data, we follow the setup described in our paper. Our pretrained models are available via RobustBench.

With extra data from Carmon et al., 2019 along with the improvements by Gowal et al., 2020

Dataset     Norm   ε        Model              Clean Acc.   Robust Acc.
CIFAR-10    ℓ∞     8/255    PreActResNet-18    89.02        57.67
CIFAR-10    ℓ∞     8/255    WideResNet-28-10   91.30        62.50
CIFAR-10    ℓ∞     8/255    WideResNet-34-10   91.47        62.83

Our models achieve ~0.3-0.5% lower robust accuracy than reported in Gowal et al., 2020, since they use a custom pseudolabeled dataset which is not publicly available (see Section 4.3.1 here).

With synthetic DDPM-generated data from Rebuffi et al., 2021

Dataset      Norm   ε         Model              CutMix   Clean Acc.   Robust Acc.
CIFAR-10     ℓ∞     8/255     PreActResNet-18    –        86.86        57.09
CIFAR-10     ℓ∞     8/255     WideResNet-28-10   –        88.16        60.97
CIFAR-10     ℓ2     128/255   PreActResNet-18    –        90.57        76.07
CIFAR-100    ℓ∞     8/255     PreActResNet-18    –        61.50        28.88
CIFAR-100    ℓ∞     8/255     WideResNet-34-10   –        62.21        31.16

Without additional data

Dataset     Norm   ε        Model              Clean Acc.   Robust Acc.
CIFAR-10    ℓ∞     8/255    ResNet-18          84.90        49.08
CIFAR-10    ℓ∞     12/255   ResNet-18          79.30        33.47
SVHN        ℓ∞     8/255    ResNet-18          93.08        52.83
TI-200      ℓ∞     8/255    PreActResNet-18    52.60        18.14

Citing this work

@inproceedings{
    rade2022reducing,
    title={Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off},
    author={Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=Azh9QBQ4tR7}
}


Known Issues

PermissionError: [WinError 32]

Reported on Windows 10: running train.py with 'python train.py --desc std-cifar10' fails with

PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: .....

The error appears to be related to the shutil.rmtree(LOG_DIR) call, and the reporter could not resolve it.
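No fix is recorded in the issue. One untested workaround, assuming the lock is transient (for example, an antivirus or indexing service briefly holding a file in the log directory), is to retry the deletion; rmtree_with_retry below is a hypothetical helper, not part of the repository:

import shutil
import time

def rmtree_with_retry(path, attempts=5, delay=1.0):
    # WinError 32 usually means another process briefly holds a file in
    # the tree, so retry the deletion a few times before giving up.
    for i in range(attempts):
        try:
            shutil.rmtree(path)
            return
        except PermissionError:
            if i == attempts - 1:
                raise
            time.sleep(delay)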
