madgrad's Introduction

MADGRAD Optimization Method

A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization

Documentation available at https://madgrad.readthedocs.io/en/latest/.

pip install madgrad

Try it out! A best-of-both-worlds optimizer with the generalization performance of SGD and convergence at least as fast as Adam's, often faster. A drop-in torch.optim implementation, madgrad.MADGRAD, is provided, as well as a FairSeq-wrapped instance. For FairSeq, just import madgrad anywhere in your project files and use the --optimizer madgrad command line option, together with --weight-decay, --momentum, and optionally --madgrad_eps.
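
For PyTorch, construction and stepping look like any other torch.optim optimizer. A minimal sketch (the hyperparameter values are illustrative, not tuned recommendations):

import torch
from madgrad import MADGRAD

model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()

# Constructor arguments shown here (lr, momentum, weight_decay, eps) follow the
# documented interface; the values are illustrative, not recommendations.
optimizer = MADGRAD(model.parameters(), lr=1e-3, momentum=0.9,
                    weight_decay=0.0, eps=1e-6)

inputs, targets = torch.randn(4, 10), torch.randint(0, 2, (4,))
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()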

The madgrad.py file containing the optimizer can be dropped directly into any PyTorch project if you don't want to install via pip. If you are using fairseq, you also need the accompanying fairseq_madgrad.py file.

Things to note:

  • You may need to use a lower weight decay than you are accustomed to. Often 0.
  • You should do a full learning rate sweep, as the optimal learning rate will differ from that of SGD or Adam. The best LR values we found were 2.5e-4 for a 152-layer PreActResNet on CIFAR-10, 0.001 for ResNet-50 on ImageNet, 0.025 for IWSLT14 using transformer_iwslt_de_en, and 0.005 for RoBERTa training on BookWiki using BERT_BASE. On NLP models, gradient clipping also helped (see the sketch after this list).
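
As a sketch of the gradient clipping mentioned in the last point, using the standard PyTorch utility before the optimizer step (the model and the max-norm value of 0.1 are illustrative):

import torch
from madgrad import MADGRAD

model = torch.nn.Linear(16, 4)
optimizer = MADGRAD(model.parameters(), lr=0.025)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
# Clip the global gradient norm before stepping; 0.1 here is illustrative.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()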

Mirror MADGRAD

The mirror descent version of MADGRAD is also included as madgrad.MirrorMADGRAD. This version works extremely well, even better than MADGRAD, on large-scale transformer training. This version is recommended for any problem where the datasets are big enough that generalization gap is not an issue.

As the mirror descent version does not implicitly regularize, you can usually use weight decay values that work well with other optimizers.
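A minimal construction sketch follows; it assumes MirrorMADGRAD is importable from the package top level, matching the madgrad.MirrorMADGRAD name above, and the weight decay value is illustrative:

import torch
from madgrad import MirrorMADGRAD

model = torch.nn.Linear(32, 8)
# Conventional weight decay values are reasonable here, since the mirror
# descent variant does not regularize implicitly.
optimizer = MirrorMADGRAD(model.parameters(), lr=1e-3, weight_decay=0.01)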

Tech Report

Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization

We introduce MADGRAD, a novel optimization method in the family of AdaGrad adaptive gradient methods. MADGRAD shows excellent performance on deep learning optimization problems from multiple fields, including classification and image-to-image tasks in vision, and recurrent and bidirectionally-masked models in natural language processing. For each of these tasks, MADGRAD matches or outperforms both SGD and ADAM in test set performance, even on problems for which adaptive methods normally perform poorly.

@misc{defazio2021adaptivity,
      title={Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization}, 
      author={Aaron Defazio and Samy Jelassi},
      year={2021},
      eprint={2101.11075},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Results

Result plots: vision and NLP benchmarks (figures in the repository).

License

MADGRAD is licensed under the MIT License.

madgrad's People

Contributors

adefazio, kozistr

madgrad's Issues

Bug with state initialization

I think there may be a bug with state initialization in the optimizer. Specifically, because the gradients are on the GPU and the state tensors are initialized on the CPU, an error occurs due to tensors being on two different devices. I investigated the code, compared it to other PyTorch optimizers, and noticed a couple of things that could be causing this issue.

Typically when the state is initialized, torch.zeros_like is called with memory_format=torch.preserve_format so that the state tensor has the same format as the input tensor, which is usually a model parameter. However, since the initialization here happens in the __init__ function, the model parameters might not be on the GPU yet. This is why PyTorch optimizers usually put the initialization code inside the step function, guarded by a check for len(state) == 0.

I changed the optimizer code to follow this sort of pattern and the code runs without issue. I will point out that I am using fastai, so it is possible that this is a fastai-specific issue, but to me it seems like this could be a major issue for other users as well.
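
For reference, the lazy-initialization idiom described in this issue looks roughly like the following; this is a generic sketch of the common torch.optim pattern with a placeholder update rule, not the actual MADGRAD code:

import torch

class LazyStateSGD(torch.optim.Optimizer):
    """Toy optimizer illustrating lazy, device-matching state initialization."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                # Initialize state on first use, so it lands on the same
                # device and layout as the (possibly GPU-resident) parameter.
                if len(state) == 0:
                    state["grad_sum_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                state["grad_sum_sq"].addcmul_(p.grad, p.grad)
                p.add_(p.grad, alpha=-group["lr"])  # placeholder update rule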

Turn off optimization with lr==0.0

Is there a way to turn off optimization for a parameter group by setting lr==0.0 for that group, or via any other method? I found that setting lr==0.0 doesn't work if eps > 0. Would it be possible to make it so that any parameter group with lr==0.0 is skipped?
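
One workaround sketch (not a feature of the library): freeze the parameters you don't want trained and leave them out of the optimizer, instead of giving their group lr == 0.0:

import torch
from madgrad import MADGRAD

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# Freeze the first layer so it receives no gradients at all.
for p in model[0].parameters():
    p.requires_grad_(False)

# Only hand trainable parameters to the optimizer, sidestepping lr == 0.0.
optimizer = MADGRAD([p for p in model.parameters() if p.requires_grad], lr=1e-3)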

Does it work for Keras?

I am trying to use it for an experiment in Keras and I'm stuck on the 'params' argument. What is it exactly, and what can its value be?

Optimizers from 1.1 incompatible with 1.2

If I save a MADGRAD optimizer from v1.1 as part of a PyTorch model, it is not compatible with v1.2. I get the following error (trainer.py is in our own code):

  File "/sailhome/horatio/stanza/stanza/models/constituency/trainer.py", line 780, in train_model_one_epoch
    optimizer.step()
  File "/u/nlp/anaconda/main/anaconda3/envs/stanza-1.2/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
    return wrapped(*args, **kwargs)
  File "/u/nlp/anaconda/main/anaconda3/envs/stanza-1.2/lib/python3.7/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/u/nlp/anaconda/main/anaconda3/envs/stanza-1.2/lib/python3.7/site-packages/madgrad/madgrad.py", line 102, in step
    decouple_decay = group["decouple_decay"]

Perhaps group.get("decouple_decay", reasonable_default) would make old models from 1.1 compatible with 1.2.
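
The suggested guard would look roughly like this inside the per-group loop of step(); the False default here is an assumption and would need to match the constructor's actual default for decouple_decay:

# Read the newer hyperparameter defensively so optimizer state saved by
# v1.1 (which lacks the key) still loads and steps under v1.2.
decouple_decay = group.get("decouple_decay", False)  # default value assumed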

Why not AdamW-style weight decay?

Hello,

While translating your optimizer to Flax (here), I noticed that you are using a traditional weight decay, where you add the weight decay to the gradient (here in your implementation):

grad += weight_decay * parameters

Rather than an AdamW-style weight decay (which, I believe, is now the default for most optimizers), where you would subtract the weight decay times the learning rate just before returning the parameters:

updated_parameters -= learning_rate * weight_decay * parameters

Is there a particular reason for that decision?
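
For reference, the two styles being contrasted, written as a generic sketch around a plain SGD update (this is not the MADGRAD update rule):

import torch

def sgd_step_coupled(param: torch.Tensor, grad: torch.Tensor,
                     lr: float, weight_decay: float) -> torch.Tensor:
    # Traditional / L2-style: decay is folded into the gradient, so it also
    # passes through any adaptive scaling applied to the gradient later.
    grad = grad + weight_decay * param
    return param - lr * grad

def sgd_step_decoupled(param: torch.Tensor, grad: torch.Tensor,
                       lr: float, weight_decay: float) -> torch.Tensor:
    # AdamW-style: decay is applied directly to the parameters, separately
    # from the gradient-based update.
    param = param - lr * grad
    return param - lr * weight_decay * param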

How about CIFAR-100?

I have tuned many trials, but it is still much worse than SGD with momentum (SGDM) on CIFAR-100.

The paper does not include CIFAR-100 experiments. How does MADGRAD perform on CIFAR-100?

Speech

Great work!

Planning to run a feed-forward transformer network in the speech domain overnight with MADGRAD.
Any heads-up?

Compatible with Keras?

I'm wondering if this is compatible with Keras.

Also, I'm a bit confused about using this optimizer; is it just a drop-in replacement for Adam? For example:

import madgrad

opt = Madgrad(lr=0.001)
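
For comparison, the README above describes a drop-in torch.optim implementation rather than a Keras optimizer, so a sketch of the intended PyTorch usage looks like the following (the class name and required params argument differ from the snippet in the question; the learning rate is illustrative):

import torch
from madgrad import MADGRAD

model = torch.nn.Linear(10, 2)

# Swapping out Adam: note the optimal MADGRAD learning rate usually differs,
# so a learning rate sweep is still advised.
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = MADGRAD(model.parameters(), lr=1e-3)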
