
A Geometric Analysis of Neural Collapse with Unconstrained Features

This is the code for the paper "A Geometric Analysis of Neural Collapse with Unconstrained Features".

Neural Information Processing Systems (NeurIPS), 2021

Introduction

  • We provide the first global optimization landscape analysis of Neural Collapse (NC) – an intriguing empirical phenomenon that arises in the last-layer classifiers and features of neural networks during the terminal phase of training.
  • We study the problem based on a simplified unconstrained feature model, which isolates the topmost layers from the classifier of the neural network. In this context, we show that the cross-entropy loss with weight decay has a benign global landscape: the only global minimizers are Simplex Equiangular Tight Frames (ETFs), while all other critical points are strict saddles whose Hessians exhibit negative curvature directions.
  • Our experiments demonstrate that one may fix the last-layer classifier to be a Simplex ETF with d = K for network training (see the sketch below), which reduces memory cost by over 20% on ResNet18 without sacrificing generalization performance.
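
For reference, a K-vertex Simplex ETF can be obtained by centering and rescaling the identity matrix; a minimal NumPy sketch (an illustration of the definition, not code from this repository):

import numpy as np

def simplex_etf(K: int) -> np.ndarray:
    """K x K Simplex ETF: unit-norm columns with pairwise cosine -1/(K-1)."""
    return np.sqrt(K / (K - 1)) * (np.eye(K) - np.ones((K, K)) / K)

M = simplex_etf(10)
print(np.round(M.T @ M, 3))  # ones on the diagonal, -1/9 off the diagonal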

Environment

  • CUDA 11.0
  • python 3.8.3
  • torch 1.6.0
  • torchvision 0.7.0
  • scipy 1.5.2
  • numpy 1.19.1

Measuring NC during network training

Datasets

By default, the code assumes the MNIST and CIFAR10 datasets are stored under ~/data/. If the datasets are not found there, they are downloaded automatically via torchvision.datasets. Users can change the default dataset location through the --data_dir argument in args.py.
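
The behavior described above follows standard torchvision usage; a minimal sketch of the idea (illustrative only, not the repository's exact data-loading code):

import os
import torchvision
from torchvision import transforms

data_dir = os.path.expanduser("~/data/")   # default location; override with --data_dir
transform = transforms.ToTensor()

# download=True fetches the data only if it is not already present under data_dir
mnist = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
cifar10 = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)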

Training with SGD

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --batch_size 256 --lr 0.05

Training with Adam

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer Adam --batch_size 64 --lr 0.001

Training with LBFGS

$ python train_2nd_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer LBFGS --lr 0.1 --history_size 10 --batch_size 2048
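
For reference, torch.optim.LBFGS differs from SGD and Adam in that each optimizer step requires a closure that re-evaluates the loss; a minimal toy sketch of one L-BFGS step (the model, data, and criterion below are placeholders, not the code in train_2nd_order.py):

import torch
import torch.nn as nn

# Toy setup; the repository's actual model and data loader live in train_2nd_order.py.
model = nn.Linear(784, 10)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(2048, 784)
targets = torch.randint(0, 10, (2048,))

optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1, history_size=10)

def closure():
    # L-BFGS may re-evaluate the loss several times per optimizer step,
    # so the forward/backward pass is wrapped in a closure.
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

optimizer.step(closure)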

Note: For each epoch during training, a model checkpoint is saved under the directory model_weights/<the uid name fed to the above commands>/ so that the NC phenomenon can be validated afterwards.

Many other training options, e.g., --epochs, --weight_decay, and so on, can be found in args.py.

Validating NC phenomenon

$ python validate_NC.py --gpu_id 0 --dataset <mnist or cifar10> --batch_size 256 --load_path <path to the uid name>

After training, running the above command computes the four NC metrics defined in the paper. All information about the NC metrics is saved to an output file named info.pkl.
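
The exact implementation lives in validate_NC.py; as an illustration, the first metric (within-class variability collapse, NC1) can be computed from the last-layer features roughly as follows, assuming the definition trace(Sigma_W @ pinv(Sigma_B)) / K from the paper:

import numpy as np

def nc1(features: np.ndarray, labels: np.ndarray, K: int) -> float:
    """NC1 = trace(Sigma_W @ pinv(Sigma_B)) / K for features (N, d) and labels (N,)."""
    global_mean = features.mean(axis=0)
    class_means = np.stack([features[labels == k].mean(axis=0) for k in range(K)])

    centered = features - class_means[labels]            # within-class deviations
    Sigma_W = centered.T @ centered / features.shape[0]

    centered_means = class_means - global_mean           # between-class deviations
    Sigma_B = centered_means.T @ centered_means / K

    return float(np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / K)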

Finally, the evolution of the NC metrics, as well as the training/testing accuracy, can be visualized by plotting them:

$ python plot.py

Note: Please refer to plot.py for the details of plotting each figure in the paper.
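
As a rough illustration of how the saved metrics could be visualized (the key name below is hypothetical; the actual keys and figure styles are defined in plot.py):

import pickle
import matplotlib.pyplot as plt

with open("info.pkl", "rb") as f:
    info = pickle.load(f)

# "collapse_metric" is a hypothetical key; see plot.py for the keys actually saved.
plt.plot(info["collapse_metric"])
plt.xlabel("Epoch")
plt.ylabel("NC1")
plt.savefig("NC1.pdf")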

Validating the unconstrained feature models for NC

Validity of unconstrained feature models

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset cifar10_random --optimizer SGD --batch_size 64 --lr 0.01 --model <MLP or ResNet18_adapt> --width <specify width for model> --depth <specify depth for MLP> --weight_decay 1e-4

$ python validate_NC.py --gpu_id 0 --dataset cifar10_random --batch_size 1000 --load_path <path to the uid name> --model <MLP or ResNet18_adapt> --width <specify width for model> --depth <specify depth for MLP>
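
For intuition, the unconstrained feature model treats the last-layer features H, together with the classifier (W, b), as free optimization variables of the regularized cross-entropy loss; a toy PyTorch sketch of that objective (an illustration of the paper's formulation, not the repository's training code; the regularization weights below are arbitrary):

import torch
import torch.nn.functional as F

K, d, n = 10, 10, 100                  # classes, feature dimension, samples per class
labels = torch.arange(K).repeat_interleave(n)

# Both the classifier (W, b) and the features H are free optimization variables.
W = torch.randn(K, d, requires_grad=True)
b = torch.zeros(K, requires_grad=True)
H = torch.randn(K * n, d, requires_grad=True)

lam_W, lam_H, lam_b = 5e-3, 5e-3, 5e-3  # illustrative regularization weights
opt = torch.optim.SGD([W, b, H], lr=0.1)

for _ in range(1000):
    opt.zero_grad()
    loss = (F.cross_entropy(H @ W.T + b, labels)
            + lam_W / 2 * W.pow(2).sum()
            + lam_H / 2 * H.pow(2).sum()
            + lam_b / 2 * b.pow(2).sum())
    loss.backward()
    opt.step()

# Per the paper, at any global minimizer the class-mean features and the rows of W
# form a (scaled) Simplex ETF and are aligned with each other.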

Weight decay on the network parameters Θ vs. on the features H

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --batch_size 64 --lr 0.05 --model <specify model> --weight_decay <specify weight decay> --sep_decay --feature_decay_rate <specify weight decay on features>
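
One way to realize a separate decay on the features H is to add an explicit penalty on the feature norm to the loss, on top of the usual weight decay on the parameters Θ; the sketch below illustrates this idea (the actual mechanism behind --sep_decay and --feature_decay_rate is in the training scripts and may differ):

import torch
import torch.nn as nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 64))
classifier = nn.Linear(64, 10)
criterion = nn.CrossEntropyLoss()

x = torch.randn(64, 1, 28, 28)
y = torch.randint(0, 10, (64,))

# Usual weight decay on the network parameters Theta ...
optimizer = torch.optim.SGD(
    list(backbone.parameters()) + list(classifier.parameters()),
    lr=0.05, momentum=0.9, weight_decay=5e-4,
)

# ... while the decay on the features H is an explicit penalty on the feature norm
# (the value 1e-4 mirrors --feature_decay_rate and is purely illustrative).
feature_decay_rate = 1e-4
optimizer.zero_grad()
features = backbone(x)                             # H
loss = criterion(classifier(features), y) + feature_decay_rate / 2 * features.pow(2).sum()
loss.backward()
optimizer.step()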

Improving network design

Fix the last-layer classifier as a Simplex ETF

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --ETF
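
The sketch below illustrates how a last-layer classifier could be fixed to a d × K Simplex ETF and frozen during training (an illustration of the idea behind --ETF, not the repository's exact implementation):

import torch
import torch.nn as nn

d, K = 512, 10
classifier = nn.Linear(d, K, bias=False)

# A d x K Simplex ETF: a column-orthonormal U times the centered, scaled identity.
U = nn.init.orthogonal_(torch.empty(d, K))
M = (K / (K - 1)) ** 0.5 * U @ (torch.eye(K) - torch.ones(K, K) / K)

with torch.no_grad():
    classifier.weight.copy_(M.t())        # rows of W are the K ETF directions
classifier.weight.requires_grad_(False)   # the classifier is fixed and never updated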

Feature dimension reduction by choosing d=K

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --fixdim 10

Introduce data augmentation and use modified ResNet architectures

$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --SOTA

Citation and reference

For technical details and full experimental results, please check our paper.

@article{zhu2021geometric,
      title={A Geometric Analysis of Neural Collapse with Unconstrained Features}, 
      author={Zhihui Zhu and Tianyu Ding and Jinxin Zhou and Xiao Li and Chong You and Jeremias Sulam and Qing Qu},
      year={2021},
      eprint={2105.02375},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
