The official PyTorch implementation - Can Neural Nets Learn the Same Model Twice? Investigating Reproducibility and Double Descent from the Decision Boundary Perspective (CVPR'22).

Can Neural Nets Learn the Same Model Twice? Investigating Reproducibility and Double Descent from the Decision Boundary Perspective

To appear in CVPR 2022 (Oral). The arXiv version is available as arXiv:2203.08124.

Requirements

We recommend using Anaconda or Miniconda for Python. Our code has been tested with python=3.8 on Linux.

Create a conda environment from the yml file and activate it.

conda env create -f environment.yml
conda activate dbviz_env

Make sure the following requirements are met:

  • torch>=1.8.1
  • torchvision>=0.9.1
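One way to confirm these minimums before training is a small version check. This is a sketch, not part of the repository: the helper names parse_version and meets_minimum are made up here, and the comparison only looks at the leading numeric components, so a local build suffix such as +cu111 is ignored.

```python
import re
from importlib import metadata

# Minimum versions taken from the requirements list above.
REQUIRED = {"torch": "1.8.1", "torchvision": "0.9.1"}


def parse_version(text):
    """Return the leading numeric components, e.g. '1.8.1+cu111' -> (1, 8, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, required):
    """True if `installed` is at least `required` (hypothetical helper)."""
    return parse_version(installed) >= parse_version(required)


if __name__ == "__main__":
    for package, minimum in REQUIRED.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed (need >= {minimum})")
            continue
        status = "ok" if meets_minimum(installed, minimum) else "too old"
        print(f"{package} {installed}: {status} (need >= {minimum})")
```

Run it inside the activated conda environment so that it inspects the packages the training scripts will actually import.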

We used wandb to log most of the outputs.

conda install -c conda-forge wandb 

Training the model and plotting the decision boundary

We provide a wide variety of models; see the models folder for the exhaustive list. Train a model with the following command.

python main.py --net <model_name> --set_seed <init_seed> --save_net <model_save_path> --imgs 500,5000,1600 --resolution 500 --active_log --epochs <number_epochs> --lr <suitable_learningrate>

Reproducibility experiments

Once you have a saved model, save its prediction arrays by running the following command:

python save_preds.py --load_net /path/to/your/saved/models --epochs 500 --resolution 50

Here, epochs is a stand-in for the number of runs, and resolution sets the resolution of the sampling grid.

Then, we calculate the reproducibility matrix by running the following command:

python calculate_iou.py --load_net /path/to/your/saved/models 

These two scripts also document the structure they expect for the saved models.
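To make the idea of a reproducibility matrix concrete, the following sketch compares prediction arrays from several runs and builds a symmetric matrix of pairwise agreement. It is an illustration under stated assumptions, not the code in calculate_iou.py: it assumes each run yields a flat sequence of predicted class labels over the same sampled grid points, and it scores a pair of runs by the fraction of points on which both predict the same class. The function names agreement and reproducibility_matrix are hypothetical.

```python
def agreement(preds_a, preds_b):
    """Fraction of grid points where two runs predict the same class (assumed metric)."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction arrays must cover the same grid points")
    if not preds_a:
        raise ValueError("prediction arrays must not be empty")
    matches = sum(a == b for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)


def reproducibility_matrix(runs):
    """Symmetric matrix of pairwise agreement scores, with 1.0 on the diagonal."""
    n = len(runs)
    matrix = [[1.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            score = agreement(runs[i], runs[j])
            matrix[i][j] = matrix[j][i] = score
    return matrix


if __name__ == "__main__":
    # Toy predictions for three runs over four grid points (made-up data).
    runs = [
        [0, 0, 1, 2],
        [0, 1, 1, 2],
        [0, 0, 1, 1],
    ]
    for row in reproducibility_matrix(runs):
        print(row)
```

Entry (i, j) compares run i with run j, so the same function works whether the runs differ by seed or by architecture.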

To recreate the plots from the paper, first train each architecture at least 3 times with different initialization seeds (such as 0, 1, 2). Then run the following bash file.

bash script_plots.sh
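The repeated training runs that precede this step can be scripted. The sketch below only builds the command lines for seeds 0, 1 and 2 from the flags of the training command in the previous section. The model name resnet, the save-path pattern, and the epoch and learning-rate values are placeholders chosen for illustration, not settings from the paper; check the models folder for valid names.

```python
import shlex


def training_command(net, seed, save_path, epochs, lr):
    """Build the main.py argument list for one run (hypothetical helper)."""
    return [
        "python", "main.py",
        "--net", net,
        "--set_seed", str(seed),
        "--save_net", save_path,
        "--imgs", "500,5000,1600",
        "--resolution", "500",
        "--active_log",
        "--epochs", str(epochs),
        "--lr", str(lr),
    ]


if __name__ == "__main__":
    for seed in (0, 1, 2):
        cmd = training_command(
            net="resnet",                                 # placeholder model name
            seed=seed,
            save_path=f"saved_models/resnet_seed{seed}",  # placeholder path
            epochs=100,                                   # placeholder value
            lr=0.01,                                      # placeholder value
        )
        # Print the command; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```

Printing first makes it easy to review the commands before committing compute to them.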

Double Descent experiments

To reproduce double descent experiments, please refer to the README file in the double-descent folder.

Acknowledgements

We would like to thank the public repos from which we borrowed model training utilities.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2022can,
  title={Can Neural Nets Learn the Same Model Twice? Investigating Reproducibility and Double Descent from the Decision Boundary Perspective},
  author={Somepalli, Gowthami and Fowl, Liam and Bansal, Arpit and Yeh-Chiang, Ping and Dar, Yehuda and Baraniuk, Richard and Goldblum, Micah and Goldstein, Tom},
  journal={arXiv preprint arXiv:2203.08124},
  year={2022}
}
