Learning with Random Learning Rates

Authors' implementation of "Learning with Random Learning Rates" (2018) in PyTorch.

The paper is available at arxiv:1810.01322, and a shorter blog post explaining the method is at leonardblier.github.io/alrao

Authors: Léonard Blier, Pierre Wolinski, Yann Ollivier.

Requirements

The requirements are:

  • pytorch (0.4 or higher)
  • numpy, scipy
  • tqdm

Tutorial

A tutorial on how to use Alrao with custom models is in tutorial.ipynb.

Sample script for using Alrao with convolutional models on CIFAR10

The script main_cnn.py trains convolutional neural networks on CIFAR10.

The main options are:

  --no-cuda             disable cuda
  --epochs EPOCHS       number of epochs for phase 1 (default: 50)
  --model_name MODEL_NAME
                        Model {VGG19, GoogLeNet, MobileNetV2, SENet18}
  --optimizer OPTIMIZER
                        optimizer (default: SGD) {Adam, SGD}
  --lr LR               learning rate, when used without alrao
  --use_alrao           multiple learning rates
  --minLR MINLR         log10 of the minimum LR in alrao (log_10 eta_min)
  --maxLR MAXLR         log10 of the maximum LR in alrao (log_10 eta_max)
  --nb_class NB_CLASS   number of classifiers used in Alrao (default 10)

More options are available; run python main_cnn.py --help to list them. For example, to train GoogLeNet with Alrao using learning rates in the interval (10**-5, 10), run:

python main_cnn.py --use_alrao --minLR -5 --maxLR 1 --nb_class 10 --model_name GoogLeNet

To train the same model with plain SGD and a learning rate of 10**-3, run:

python main_cnn.py --lr 0.001 --model_name GoogLeNet

The available models are VGG19, GoogLeNet, MobileNetV2, SENet18.
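The --minLR and --maxLR options give the bounds of the learning rate interval as base-10 exponents. The paper describes drawing learning rates log-uniformly from this interval. The snippet below is a minimal, framework-free sketch of that sampling step; the function name sample_log_uniform_lrs is hypothetical and is not part of this repository, whose own sampling code may differ.

```python
import random


def sample_log_uniform_lrs(min_lr_log10, max_lr_log10, n, seed=None):
    """Draw n learning rates log-uniformly from [10**min_lr_log10, 10**max_lr_log10].

    Hypothetical helper for illustration only; not the repository's API.
    """
    rng = random.Random(seed)
    # Sample the exponent uniformly, then exponentiate: this makes every
    # order of magnitude in the interval equally likely.
    return [10.0 ** rng.uniform(min_lr_log10, max_lr_log10) for _ in range(n)]


# Same interval as "--minLR -5 --maxLR 1", i.e. (10**-5, 10).
lrs = sample_log_uniform_lrs(-5, 1, n=1000, seed=0)
print(min(lrs), max(lrs))
```

Sampling the exponent rather than the learning rate itself is what spreads the draws evenly across scales; a plain uniform draw on (10**-5, 10) would almost never produce values below 10**-2.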

Sample script for using Alrao with recurrent models on PTB

The script main_rnn.py trains a recurrent neural network (an LSTM) on PTB. By default, the LSTM is trained for word prediction with a backpropagation through time (bptt) length of 35. The setup used in the paper is obtained with the following options:

python main_rnn.py --char_prediction --bptt=70

The options given in the previous section are also available, as well as LSTM-specific options (number of layers, size of the embedding, etc.). Run python main_rnn.py --help to list them.

Note: in the code, the loss is computed with the natural logarithm (nats), and not with the binary logarithm (bits) as in the paper. To compare with the values reported in the paper, divide the loss by ln 2.
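The conversion follows from log2(x) = ln(x) / ln(2). A minimal sketch, where nats_to_bits is a hypothetical helper name and not a function from this repository:

```python
import math


def nats_to_bits(loss_nats):
    """Convert a loss computed with the natural logarithm to base 2."""
    return loss_nats / math.log(2)


# A loss of ln(2) nats is exactly one bit.
print(nats_to_bits(math.log(2)))
```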

How to use Alrao on custom models

Custom models

Custom models can be used with some modifications.

First, the custom model has to be split into a pre-classifier class (e.g. PreClassif) and a classifier class (e.g. Classif) in order to be integrated into the class AlraoModel. Note that the classifier is supposed to return log-probabilities. Once done, an instance of AlraoModel can be created with:

preclassif = PreClassif(<args_of_the_preclassifier>)
alrao_model = AlraoModel(preclassif, nb_classifiers, Classif, <args_of_the_classifiers>)

The forward method of the pre-classifier may return either a single value or a tuple. If it returns a tuple (x, a, b, ...), only the first element x is passed as input to the classifiers. The classifier outputs are then combined by a model averaging method (here, the Switch class), which returns y. The forward method of AlraoModel therefore returns the tuple (y, a, b, ...), with the extra elements passed through unchanged.
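The tuple convention can be illustrated with a small framework-free sketch. It is not the AlraoModel implementation: toy_alrao_forward and the helper names are hypothetical, plain lists stand in for tensors, and a uniform average in probability space stands in for the Switch class, whose actual weighting is not reproduced here.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def average_log_probs(log_probs_per_classifier):
    """Uniform average in probability space, returned as log-probabilities.

    Assumption: uniform weights replace the weights the Switch class would use.
    """
    n = len(log_probs_per_classifier)
    n_classes = len(log_probs_per_classifier[0])
    return [
        logsumexp([lp[c] for lp in log_probs_per_classifier]) - math.log(n)
        for c in range(n_classes)
    ]


def toy_alrao_forward(preclassifier, classifiers, inputs):
    """Mimic the (x, a, b, ...) -> (y, a, b, ...) convention described above."""
    out = preclassifier(inputs)
    if isinstance(out, tuple):
        x, extras = out[0], out[1:]
    else:
        x, extras = out, ()
    y = average_log_probs([clf(x) for clf in classifiers])
    return (y,) + extras if extras else y


def preclassif(inputs):
    # Returns features plus one extra value (e.g. a recurrent hidden state).
    return [2.0 * v for v in inputs], "hidden-state"


def make_classifier(scale):
    def clf(x):
        logits = [scale * v for v in x]
        z = logsumexp(logits)
        return [l - z for l in logits]  # log-probabilities
    return clf


classifiers = [make_classifier(s) for s in (0.5, 1.0, 2.0)]
y, hidden = toy_alrao_forward(preclassif, classifiers, [0.1, -0.3, 0.2])
print(y, hidden)
```

Because each classifier returns log-probabilities, the averaged output y is again a valid vector of log-probabilities, and the extra value returned by the pre-classifier reaches the caller untouched.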

Method forwarding

To make Alrao easier to integrate into an existing project, method forwarding is provided. Suppose a model class named Model has a method f, which is called elsewhere in the code as x, y = some_model.f(a, b). Once the model has been split as described above, the code only needs to be changed from:

some_model = Model(...)
...
x, y = some_model.f(a, b)

to:

some_model = AlraoModel(...)
# if 'f' is a preclassifier method:
some_model.method_fwd_preclassifier('f')
# if 'f' is a classifier method:
#some_model.method_fwd_classifiers('f')
...
x, y = some_model.f(a, b)
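The idea behind method forwarding can be sketched in plain Python. The classes ToyPreClassifier and ToyWrapper below are hypothetical stand-ins, not the repository's AlraoModel, and only the pre-classifier case is shown; how the real class forwards classifier methods is not covered here.

```python
class ToyPreClassifier:
    def f(self, a, b):
        return a + b, a * b


class ToyWrapper:
    """Hypothetical stand-in for AlraoModel; only the forwarding idea is shown."""

    def __init__(self, preclassifier):
        self.preclassifier = preclassifier

    def method_fwd_preclassifier(self, method_name):
        # Expose the pre-classifier's bound method under the same name
        # on the wrapper, so existing call sites keep working.
        setattr(self, method_name, getattr(self.preclassifier, method_name))


some_model = ToyWrapper(ToyPreClassifier())
some_model.method_fwd_preclassifier('f')
x, y = some_model.f(2, 3)
print(x, y)
```

The call site some_model.f(a, b) stays the same after wrapping, which is the point of forwarding: only the construction of the model changes.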
