
Vadam

A Keras optimizer that modifies Adam to approximate variational inference by perturbing weights, following arXiv:1712.07628.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y.,
& Srivastava, A. (2018). Fast and Scalable Bayesian Deep
Learning by Weight-Perturbation in Adam.
arXiv preprint arXiv:1806.04854.
  • This optimizer supports Keras 2.3.1, since the TensorFlow 2.0 version of Adam separates gradients by sparsity, and this algorithm does not support sparse gradients according to the authors' PyTorch implementation. Here is one person's workaround to that issue.
  • The default prior precision value (Lambda in the paper) yields a completely uninformative prior that will NOT produce viable results, by the authors' own admission (Appendix K.3). The paper notes that finding the right value of Lambda is beyond its scope, so an example Hyperas script that tunes Vadam simultaneously on learning rate and prior precision is offered here to address this.
  • This version of the Vadam algorithm follows slide 11 of 15 from the 2018 ICML presentation slides, which differs slightly from the paper. In this implementation of the slides' version, only the epsilon fuzz factor is added to parameter updates, rather than perturbations drawn from a diagonal multivariate Gaussian distribution; those may be added in the future.
  • The PyTorch version of Vadam can also apply Monte Carlo sampling to parameter updates, which is not included here. However, the ablation tests in Appendix I.2 use a single Monte Carlo sample, so this simplification may not hurt results too badly. See here for more information on the presentation.
  • Unlike the Keras implementation of Adam, the PyTorch implementations of both Adam and Vadam perform bias correction. Bias correction was therefore added here as well, but it caused numerical instability, so the code is left commented out.
  • The Adagrad option is removed, since it is not in the PyTorch implementation.
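For reference, the full update from the paper's Algorithm 1 can be sketched in plain NumPy. This is an illustrative sketch, not this repository's code: the function name, defaults, and exact moment bookkeeping are assumptions based on the paper's pseudocode.

```python
import numpy as np

def vadam_step(mu, m, s, grad_fn, N, lr=0.001, prior_prec=1.0,
               beta1=0.9, beta2=0.999, t=1, rng=None):
    """One Vadam update of the variational mean mu (per the paper).

    grad_fn(theta) returns the minibatch loss gradient at theta;
    N is train_set_size and prior_prec is Lambda from the paper.
    All names and defaults here are illustrative assumptions.
    """
    if rng is None:
        rng = np.random.default_rng()
    lam = prior_prec / N
    # Perturb the weights using the current posterior scale -- the step
    # this Keras port replaces with a plain epsilon fuzz factor.
    sigma = 1.0 / np.sqrt(N * (s + lam))
    theta = mu + sigma * rng.standard_normal(mu.shape)
    g = grad_fn(theta)
    # Adam-style moment estimates; the prior term enters the first moment.
    m = beta1 * m + (1 - beta1) * (g + lam * mu)
    s = beta2 * s + (1 - beta2) * g ** 2
    # Bias correction -- the part this Keras port found numerically
    # unstable and left commented out.
    m_hat = m / (1 - beta1 ** t)
    s_hat = s / (1 - beta2 ** t)
    mu = mu - lr * m_hat / (np.sqrt(s_hat) + lam)
    return mu, m, s
```

Note how the denominator uses `sqrt(s_hat) + lam` rather than Adam's `sqrt(s_hat) + epsilon`: the prior precision plays the role of the fuzz factor, which is why an uninformative default Lambda destabilizes training.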

Usage (the only required parameter is train_set_size, though prior_prec should definitely be tuned):

import numpy as np
from keras.models import Sequential
# import Vadam from wherever this repository's optimizer is defined

X_train = np.random.random((1000, 32))
Y_train = np.random.random((1000, 10))

model = Sequential()
...

# train_set_size is the number of samples in X_train
model.compile(optimizer=Vadam(train_set_size=1000,
                              ...))

result = model.fit(X_train,
                   Y_train,
                   ...)

Only works with TensorFlow < 2.0, for the reason described above, and this version only works with Keras < 2.3.1.

This optimizer is suitable for approximating variational inference in a neural network, yielding probabilistic output with upper and lower confidence bounds on predictions.
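One common way such bounds are obtained is by sampling weight perturbations from the approximate posterior at prediction time and aggregating the outputs. A minimal NumPy sketch, with a hypothetical `predict` function standing in for a real Keras forward pass:

```python
import numpy as np

def predictive_bounds(predict, x, mu, sigma, n_samples=100, alpha=0.05):
    """Monte Carlo predictive interval from a diagonal Gaussian posterior.

    predict(weights, x) stands in for a forward pass with the given flat
    weight vector; mu and sigma parameterize the approximate posterior.
    Names and defaults are illustrative assumptions, not this repo's API.
    """
    rng = np.random.default_rng(0)
    # Draw weight samples from N(mu, sigma^2) and collect predictions.
    samples = np.array([
        predict(mu + sigma * rng.standard_normal(mu.shape), x)
        for _ in range(n_samples)
    ])
    # Empirical central (1 - alpha) interval around the predictive mean.
    lower = np.quantile(samples, alpha / 2)
    upper = np.quantile(samples, 1 - alpha / 2)
    return samples.mean(), lower, upper

# Toy "model": the prediction is a dot product of weights and input.
mean, lo, hi = predictive_bounds(
    lambda w, x: float(w @ x),
    x=np.ones(3), mu=np.zeros(3), sigma=np.full(3, 0.1))
```

The same pattern works with a Keras model by setting its weights to each sampled vector before calling `model.predict`.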

MIT License
