DiverseViT

Implementation of our paper:

Learning Diverse Features in Vision Transformers for Improved Generalization

Armand Mihai Nicolicioiu, Andrei Liviu Nicolicioiu, Bogdan Alexe, Damien Teney

ICML Workshop on Spurious Correlations, Invariance, and Stability (SCIS) 2023

Installation

# Create new Python environment
conda create -n diverse-vit python=3.10
conda activate diverse-vit

# Install required libraries
pip install -r requirements.txt

Running

Run the training loop with default parameters:

cd scripts/
export PYTHONPATH=..
python main.py

The parameters are configurable using Hydra and can be overridden from the CLI:

python main.py diversification.weight=100 optimizer_params.lr=0.001 seed=42 <extra_args>

Make sure to also replace the data_path, logging_path, and checkpoints_path from the default config with your own.
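
Each dotted `key=value` override above targets one entry in the nested config. The sketch below illustrates that mapping with the standard library only. It is not Hydra's implementation: `apply_overrides` is a hypothetical helper, and the starting values in `config` are placeholders, not the repository's defaults.

```python
import ast


def apply_overrides(config, overrides):
    """Apply dotted key=value overrides to a nested dict, in place."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Walk (or create) the nested section named by each key part.
            node = node.setdefault(name, {})
        try:
            # Turn "100" into 100 and "0.001" into 0.001.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything that is not a Python literal (e.g. a path) stays a string.
            value = raw_value
        node[leaf] = value
    return config


# Placeholder starting values, assumed for illustration only.
config = {
    "diversification": {"weight": 0},
    "optimizer_params": {"lr": 0.0001},
    "seed": 0,
}
apply_overrides(
    config,
    ["diversification.weight=100", "optimizer_params.lr=0.001", "seed=42"],
)
print(config)
```

Path entries such as data_path can be passed in the same `key=value` form.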

Experiments

To reproduce our best results for both Empirical Risk Minimization (ERM) and Diversification, use the following overrides:

  • ERM: python main.py diversification.weight=0 optimizer_params.lr=0.0001
  • Diversification: python main.py diversification.weight=100 optimizer_params.lr=0.001

We use the following configuration for the Vision Transformer:

model: 'DiverseViT'
model_params:
  image_size: [64, 32]
  patch_size: 4
  num_classes: 2
  channels: 3
  dim: 64
  depth: 6
  heads: 4
  mlp_dim: 128
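
As a quick sanity check on this configuration, the sketch below computes the token count and flattened patch size it implies. It assumes the usual ViT patching scheme (non-overlapping square patches, each flattened across all channels); the variable names are illustrative and not taken from the repository code.

```python
# Values copied from the model_params block above.
image_size = (64, 32)
patch_size = 4
channels = 3

# Both image sides must be divisible by the patch size.
assert all(side % patch_size == 0 for side in image_size)

# Number of non-overlapping patches along each side, then in total.
patches_per_side = [side // patch_size for side in image_size]
num_patches = patches_per_side[0] * patches_per_side[1]

# Length of one flattened patch before it is projected to `dim`.
patch_dim = channels * patch_size * patch_size

print(f"patch grid: {patches_per_side[0]} x {patches_per_side[1]}")
print(f"patches per image: {num_patches}")
print(f"flattened patch length: {patch_dim}")
```

Under these assumptions each image yields 128 patch tokens of length 48, which are then embedded into the 64-dimensional model space.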

We provide checkpoints to evaluate the model by running:

# Replace with the path to your own copy of the checkpoint
CHECKPOINT_PATH=/home/armand/repos/diverse-vit/checkpoints/ckpt_diverse_ep37.pth
python main.py checkpoint=$CHECKPOINT_PATH

Additionally, the logs for 10 seeds of each experiment are in results_logs. To print the results summary, run the following script (after replacing LOGS_PATH with your own path):

cd scripts/
python reporting.py

The output should look like this:

Results sorted by ALL heads accuracy.

DiverseViT__adam-0.001__div-100.0   [ALL HEADS] 0.646 +- 0.017  [BEST HEAD] 0.704 +- 0.017
DiverseViT__adam-0.0001__div-0.0    [ALL HEADS] 0.627 +- 0.010  [BEST HEAD] 0.645 +- 0.030

Results sorted by BEST head accuracy.

DiverseViT__adam-0.001__div-100.0   [ALL HEADS] 0.646 +- 0.017  [BEST HEAD] 0.704 +- 0.017
DiverseViT__adam-0.0001__div-0.0    [ALL HEADS] 0.627 +- 0.010  [BEST HEAD] 0.645 +- 0.030
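
To compare runs programmatically, the summary lines can be parsed back into numbers. The following is a minimal sketch that assumes the exact line layout shown above; `parse_summary_line` is a hypothetical helper, not part of reporting.py.

```python
import re

# One summary line: run name, then mean +- std for all heads and for the best head.
LINE_PATTERN = re.compile(
    r"^(?P<name>\S+)\s+"
    r"\[ALL HEADS\]\s+(?P<all_mean>[\d.]+)\s+\+-\s+(?P<all_std>[\d.]+)\s+"
    r"\[BEST HEAD\]\s+(?P<best_mean>[\d.]+)\s+\+-\s+(?P<best_std>[\d.]+)\s*$"
)


def parse_summary_line(line):
    """Return a dict for a summary line, or None for headers and blank lines."""
    match = LINE_PATTERN.match(line.strip())
    if match is None:
        return None
    fields = match.groupdict()
    name = fields.pop("name")
    return {"name": name, **{key: float(value) for key, value in fields.items()}}


# Sample text copied from the expected output above.
report = """
Results sorted by BEST head accuracy.

DiverseViT__adam-0.001__div-100.0   [ALL HEADS] 0.646 +- 0.017  [BEST HEAD] 0.704 +- 0.017
DiverseViT__adam-0.0001__div-0.0    [ALL HEADS] 0.627 +- 0.010  [BEST HEAD] 0.645 +- 0.030
"""

rows = [row for row in map(parse_summary_line, report.splitlines()) if row]
best = max(rows, key=lambda row: row["best_mean"])
print(best["name"], best["best_mean"], best["best_std"])
```

The run names appear to encode the model, the optimizer with its learning rate, and the diversification weight, so the first row corresponds to the Diversification setting and the second to ERM.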

Citation

If our project is relevant for your research, please cite it using:

@incollection{nicolicioiu2023diversevit,
    title = {Learning Diverse Features in Vision Transformers for Improved Generalization},
    author = {Nicolicioiu, Armand Mihai and Nicolicioiu, Andrei Liviu and Alexe, Bogdan and Teney, Damien},
    booktitle = {ICML Workshop on Spurious Correlations, Invariance and Stability (SCIS)},
    year = {2023}
}


diverse-vit's Issues

Question about "head_indices"

Dear authors,

Thank you for your contributions.
I have a question regarding the network architecture. I believe the AttentionSelection function is meant to trim down the attention heads. However, I noticed that the head_indices parameter in the AttentionSelection function is set to None. Could you please provide an explanation on how to utilize this function effectively?
