
bayesian-sparse-deep-learning's Introduction

Bayesian Sparse Deep Learning

Experiment code for "An Adaptive Empirical Bayesian Method for Sparse Deep Learning". We propose a novel adaptive empirical Bayesian method to efficiently train hierarchical Bayesian mixture DNN models, where the parameters are learned through sampling while the priors are learned through optimization. In addition, this model can be further generalized to a class of adaptive sampling algorithms for estimating various state-space models in deep learning.

@inproceedings{deng2019,
  title={An Adaptive Empirical Bayesian Method for Sparse Deep Learning},
  author={Wei Deng and Xiao Zhang and Faming Liang and Guang Lin},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
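To make the description above concrete, here is a minimal, self-contained sketch (not the paper's or the repository's implementation) of the alternation on a toy sparse linear regression: the weights are sampled with SGHMC under a two-component Gaussian (spike-and-slab-style) prior, while the prior's inclusion probabilities are updated with a Robbins-Monro stochastic-approximation step. The values of v0/v1, the step sizes, and the exact update rules are illustrative assumptions rather than the paper's equations.

```python
# Illustrative sketch only (assumptions noted above), not the repository code.
import numpy as np

rng = np.random.default_rng(0)
n, p, sigma = 100, 20, 0.5
X = rng.standard_normal((n, p))
beta_true = np.zeros(p)
beta_true[:3] = [2.0, -1.5, 1.0]                     # sparse ground truth
y = X @ beta_true + sigma * rng.standard_normal(n)

v0, v1 = 0.01, 1.0        # spike / slab prior variances (illustrative values)
beta = np.zeros(p)        # parameters: learned by sampling (SGHMC)
mom = np.zeros(p)         # SGHMC momentum
rho = np.full(p, 0.5)     # prior inclusion probabilities: learned by optimization
lr, friction, invT = 1e-4, 0.1, 1.0

for t in range(1, 5001):
    # Sampling step: one SGHMC update of the weights under the current prior.
    prior_prec = rho / v1 + (1.0 - rho) / v0
    grad = -X.T @ (y - X @ beta) / sigma**2 + prior_prec * beta
    mom = (1 - friction) * mom - lr * grad \
          + np.sqrt(2 * friction * lr / invT) * rng.standard_normal(p)
    beta = beta + mom

    # Optimization step: stochastic approximation of the inclusion probabilities.
    slab = rho * np.exp(-0.5 * beta**2 / v1) / np.sqrt(v1)
    spike = (1 - rho) * np.exp(-0.5 * beta**2 / v0) / np.sqrt(v0)
    resp = slab / (slab + spike)
    step = 1.0 / t**0.75                              # Robbins-Monro step size
    rho = (1 - step) * rho + step * resp

print("estimated inclusion probabilities:", np.round(rho, 2))
```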

Large-p-small-n Linear Regression


For the ground-truth calculation of the standard deviation, you may check the linked reference.

Regression: UCI dataset

Requirement

Since the model is small, a GPU does not provide a significant speed-up, so we run these experiments on the CPU.

The following commands run the different methods (stochastic-approximation SGHMC, EM-SGHMC, and vanilla SGHMC) on the Boston housing dataset:

python uci_run.py -data boston -c sa -invT 1 -v0 0.1 -anneal 1.003 -seed 5
python uci_run.py -data boston -c em -invT 1 -v0 0.1 -anneal 1.003 -seed 5
python uci_run.py -data boston -c sghmc -invT 1 -v0 0.1 -anneal 1.003 -seed 5

You can also test the performance on the other datasets, e.g. yacht, energy-efficiency, wine, and concrete. For a comprehensive evaluation, you may need to try many different seeds.
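If you want to run such a seed sweep programmatically, a small driver like the one below can help; it simply shells out to uci_run.py with the flags shown above (the dataset string must match whatever uci_run.py expects).

```python
# Convenience sketch (not part of the repository): run several seeds in a row.
import subprocess

for seed in range(1, 11):
    cmd = ["python", "uci_run.py", "-data", "boston", "-c", "sa",
           "-invT", "1", "-v0", "0.1", "-anneal", "1.003", "-seed", str(seed)]
    subprocess.run(cmd, check=True)
```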

Classification: MNIST/Fashion MNIST

You can adapt posterior_cnn.py and use the models in ./model/model_zoo_mnist.py. Accuracy in the 99.7x% range on MNIST can be obtained with the hyperparameters reported in the paper (most importantly, the temperature). To run the adversarial-example experiments, include tools/attacker.py and make the corresponding changes.
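The interface of tools/attacker.py is not documented here; as a rough stand-in, a generic fast gradient sign method (FGSM) attack in PyTorch looks like the sketch below (the eps value and the [0, 1] pixel range are assumptions).

```python
# Generic FGSM sketch; this is NOT the interface of tools/attacker.py.
import torch
import torch.nn.functional as F

def fgsm_attack(model, images, labels, eps=0.1):
    """Return adversarial images perturbed by the sign of the input gradient."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    adv = images + eps * images.grad.sign()
    return adv.clamp(0.0, 1.0).detach()      # assumes inputs live in [0, 1]
```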

Classification: Sparse Residual Network on CIFAR10

Requirement

Pretrain a dense model

python bayes_cnn.py -lr 2e-6 -invT 20000 -save 1 -prune 0  

Finetune a sparse model through stochastic approximation

python bayes_cnn.py -lr 2e-9 -invT 1000 -anneal 1.005 -v0 0.005 -v1 1e-5 -sparse 0.9 -c sa -prune 1

The default settings produce a ResNet20 model with 90% sparsity and state-of-the-art 91.56% accuracy (PyTorch version 1.01) using 27K non-zero parameters; by contrast, EM-based SGHMC (with step size 1) and vanilla SGHMC (with step size 0) obtain much worse results. The running log for the default seed is saved in the output folder; you can try other seeds to obtain higher accuracy.
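To double-check the sparsity of a saved checkpoint, a small helper like the one below (not part of the repository; the checkpoint path is hypothetical) counts the fraction of exactly-zero weights.

```python
# Hedged helper: measure weight sparsity of a PyTorch state dict.
import torch

def weight_sparsity(state_dict):
    total, zeros = 0, 0
    for name, tensor in state_dict.items():
        if "weight" in name:                  # skip biases and BN statistics
            total += tensor.numel()
            zeros += (tensor == 0).sum().item()
    return zeros / max(total, 1)

# Example usage (hypothetical path):
# sd = torch.load("output/resnet20_sparse.pt", map_location="cpu")
# print(f"sparsity: {weight_sparsity(sd):.2%}")
```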

For other sparsity levels, one needs to tune the v0 and v1 parameters, e.g. 30%: (0.5, 1e-3), 50%: (0.1, 5e-4), 70%: (0.1, 5e-5), to achieve state-of-the-art results.

Further comments

For DNN models, we suggest keeping the sparsity and sigma^2 fixed rather than continuing to update them via the corresponding iterations (14-15).

References:

Wei Deng, Xiao Zhang, Faming Liang, Guang Lin, An Adaptive Empirical Bayesian Method for Sparse Deep Learning, NeurIPS, 2019


bayesian-sparse-deep-learning's Issues

Question about resuming training

Hi, thanks for your interesting paper. I have run into a puzzle and would appreciate your help. I was fine-tuning a sparse model through stochastic approximation when training stopped at epoch 459 for an external reason. I reloaded the saved fine-tuned model from that break point, but the initially computed sparsity of the loaded model is 0.59%, which differs from the sparsity at the interrupted epoch (45%). Logically, most parameters of the trained model were already set to zero during the earlier training, so why is the computed sparsity of the loaded model 0.59% instead of 45% when resuming at epoch 459? I would greatly appreciate your help with this. Thank you!

Question about sparsity and pruning of the model

Hello, I have some questions after reading your paper and hope you can clarify them, thank you very much!

1. The "pruning" operation in the code seems to just set weight parameters to zero without actually deleting network neurons. So if I save a model with 90% sparsity and 27K non-zero parameters, the actual size of the saved model is the same as that of the dense model, right? In that case there is no way to obtain a smaller model containing only 27K parameters.
2. If we only set weights to zero and do not delete the corresponding neurons, those neurons still exist at the next parameter update. Won't the "zeroed" neurons regain non-zero values?

The puzzle is this: if neurons are not deleted during pruning, will neurons that have been set to zero obtain parameters again in subsequent training and take effect again? For example, when the weights of some neurons are set to zero, the hope is that those neurons lose their function and are not updated in later training, right? What puzzles me is that these zeroed neurons have not been deleted, so won't they continue to take effect once the update step assigns them new weight values?

Please forgive any unintended offense. I look forward to hearing from you. Thank you very much again!
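For readers with the same question: in generic mask-based pruning (independent of this repository's code), zeroed weights do receive gradients and would become non-zero again unless the mask is re-applied after every optimizer step, as in the sketch below (the threshold and helper names are placeholders).

```python
# Generic mask-based pruning sketch; NOT this repository's pruning code.
import torch

def build_masks(model, threshold=1e-3):
    """Mask out weights whose magnitude falls below a (hypothetical) threshold."""
    return {name: (param.abs() > threshold).float()
            for name, param in model.named_parameters() if "weight" in name}

def apply_masks(model, masks):
    """Re-zero the pruned weights; call this after every optimizer step."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])
```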
