
GradientAccumulator

Seamless gradient accumulation for TensorFlow 2

GradientAccumulator enables gradient accumulation (GA) by overloading the train_step method of any tf.keras.Model, so that the weights are updated correctly according to a user-specified number of accumulation steps. In theory, GA allows arbitrarily large effective batch sizes with the same memory consumption as a regular mini-batch, at the cost of increased runtime. To reduce runtime, mixed precision is supported. Because batch normalization is not natively compatible with GA, adaptive gradient clipping is supported as an alternative.
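To see why averaging gradients over several micro-batches can stand in for one large batch, consider the framework-free sketch below. It assumes a toy one-parameter linear model with a mean-squared-error loss and micro-batches of equal size; the function names are illustrative and not part of this package.

```python
# Illustrative sketch (not the package's code): gradient accumulation for a
# toy model y = w * x with a mean-squared-error loss, in plain Python.

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Average the gradients of `accum_steps` equally sized micro-batches."""
    if len(xs) % accum_steps:
        raise ValueError("this sketch assumes equally sized micro-batches")
    size = len(xs) // accum_steps
    total = 0.0
    for i in range(accum_steps):
        chunk = slice(i * size, (i + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk])
    return total / accum_steps
```

With equal-sized micro-batches and a mean-reduced loss, `accumulated_grad` matches the full-batch `grad_mse` up to floating-point rounding. Unequal micro-batches, or layers that update at every step (such as batch normalization), break this equivalence.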

The package is compatible with, and has been tested against, TF >= 2.2 and Python >= 3.6 (tested with 3.6-3.10), and works cross-platform (Ubuntu, Windows, macOS).

Install

Stable release from PyPI:

pip install gradient-accumulator

Or from source:

pip install git+https://github.com/andreped/GradientAccumulator

Usage

from gradient_accumulator.GAModelWrapper import GAModelWrapper
from tensorflow.keras.models import Model

model = Model(...)  # build your model as usual
# wrap the model so that the weights are updated once every 4 batches
model = GAModelWrapper(accum_steps=4, inputs=model.input, outputs=model.output)

Then simply use the model as you normally would!
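Conceptually, the wrapped train_step accumulates gradients and applies them only once every accum_steps batches. The plain-Python toy below mimics that control flow with a single scalar weight and vanilla SGD; the class and its attributes are hypothetical and do not reflect the package's actual internals.

```python
class ToyAccumulatingTrainer:
    """Hypothetical, framework-free illustration of accumulate-then-apply.

    Gradients from each micro-batch are summed into a buffer; every
    `accum_steps` calls, the averaged gradient is applied with plain SGD
    and the buffer is reset.
    """

    def __init__(self, weight, accum_steps, learning_rate):
        self.weight = weight
        self.accum_steps = accum_steps
        self.learning_rate = learning_rate
        self.buffer = 0.0      # running sum of micro-batch gradients
        self.step = 0          # number of micro-batches seen so far
        self.num_updates = 0   # number of optimizer updates applied

    def train_step(self, grad):
        self.buffer += grad
        self.step += 1
        if self.step % self.accum_steps == 0:
            # apply the averaged gradient, then start a new accumulation cycle
            self.weight -= self.learning_rate * self.buffer / self.accum_steps
            self.buffer = 0.0
            self.num_updates += 1
```

After five calls with accum_steps=4, exactly one update has been applied and the fifth gradient is still waiting in the buffer.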

Mixed precision

Experimental support for mixed precision is also available:

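from tensorflow.keras import mixed_precision
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from gradient_accumulator.GAModelWrapper import GAModelWrapper

# set the policy before building the model, so that its layers use it
mixed_precision.set_global_policy('mixed_float16')
model = Model(...)
model = GAModelWrapper(accum_steps=4, mixed_precision=True, inputs=model.input, outputs=model.output)

# epsilon is raised from the Adam default (1e-7) to 1e-4 here
opt = Adam(1e-3, epsilon=1e-4)
# loss scaling helps avoid float16 gradient underflow
opt = mixed_precision.LossScaleOptimizer(opt)
model.compile(optimizer=opt, loss=...)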

If using TPUs, use bfloat16 instead of float16, like so:

mixed_precision.set_global_policy('mixed_bfloat16')

There is also an example of how to use gradient accumulation with mixed precision here.
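Why the LossScaleOptimizer is needed with float16 can be illustrated without TensorFlow. The sketch below round-trips values through IEEE 754 half precision using Python's struct module (format code 'e'). The gradient value and the fixed scale factor are arbitrary assumptions; LossScaleOptimizer adjusts its scale dynamically.

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8          # a small gradient value (assumed), fine in float32
scale = 2.0 ** 15    # an illustrative, fixed loss-scale factor (assumed)

# Without scaling, the value is below float16's smallest subnormal and becomes 0.
unscaled = to_float16(grad)
# With scaling, it survives the float16 round trip and is unscaled afterwards.
scaled = to_float16(grad * scale) / scale
```

bfloat16 keeps the exponent range of float32, which is why loss scaling is generally not needed when training with mixed_bfloat16 on TPUs.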

Adaptive gradient clipping

Adaptive gradient clipping (AGC) is also supported, based on this implementation:

model = GAModelWrapper(accum_steps=4, use_agc=True, clip_factor=0.01, eps=1e-3, inputs=model.input, outputs=model.output)

The values of clip_factor and eps shown here are the defaults.
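Adaptive gradient clipping rescales a gradient whenever its norm becomes too large relative to the norm of the corresponding weights. The sketch below applies this rule unit-wise to a weight matrix stored as nested lists, treating each row as one unit. It follows the general AGC formulation and is an illustration under those assumptions, not a copy of the package's implementation, which operates on TensorFlow tensors.

```python
import math


def agc_clip(weights, grads, clip_factor=0.01, eps=1e-3):
    """Unit-wise adaptive gradient clipping (rows = units).

    For each row i, if ||g_i|| > clip_factor * max(||w_i||, eps),
    g_i is rescaled so that its norm equals that threshold.
    """
    clipped = []
    for w_row, g_row in zip(weights, grads):
        w_norm = max(math.sqrt(sum(w * w for w in w_row)), eps)
        g_norm = math.sqrt(sum(g * g for g in g_row))
        max_norm = clip_factor * w_norm
        if g_norm > max_norm:
            # shrink the whole row to the maximum allowed norm
            g_row = [g * max_norm / g_norm for g in g_row]
        clipped.append(list(g_row))
    return clipped
```

The eps floor keeps units with near-zero weights (for example, freshly zero-initialized ones) from having their gradients clipped to zero.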

Model format

The SavedModel format is recommended when using this implementation, because with the model wrapper, the HDF5 format only works with TF <= 2.6. With those older TF versions, both formats work out of the box, while SavedModel works with all TF 2.x versions.

Disclaimer

In theory, batch training and gradient accumulation should give identical results. In practice, however, a slight difference may be observed. One cause is operations (layers, optimizers, etc.) that update at every step, such as batch normalization (BN). Using BN with GA is not recommended, as BN would update too frequently. However, you could try adjusting the momentum of BN (see here).
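One way to reason about adjusting the BN momentum: Keras updates a BN moving statistic as `running * momentum + batch_value * (1 - momentum)`, and with GA this happens once per micro-batch. The sketch below picks a per-micro-batch momentum `m' = m ** (1 / accum_steps)`, so that old statistics decay by the same total factor over accum_steps updates as in a single update with momentum m. This is an illustrative heuristic, not necessarily what the linked discussion recommends, and the function names are hypothetical.

```python
def ema_update(running, value, momentum):
    """Keras-style moving average: running * momentum + value * (1 - momentum)."""
    return running * momentum + value * (1.0 - momentum)


def per_microbatch_momentum(momentum, accum_steps):
    """Momentum that, applied `accum_steps` times, decays old statistics
    by the same total factor as `momentum` applied once (heuristic)."""
    return momentum ** (1.0 / accum_steps)
```

When every micro-batch yields the same statistic, the two schedules agree exactly. In general they do not, because each update still sees only micro-batch statistics.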

A small difference was also observed when using adaptive optimizers, which I believe may be due to how frequently they are updated. Nonetheless, the difference was quite small, and batch training can be approximated well using our GA implementation, as rigorously tested here.

TODOs:

  • Add multi-GPU support

Acknowledgements

The gradient accumulator model wrapper is based on the implementation presented in this thread on Stack Overflow.

The adaptive gradient clipping method is based on the implementation by @sayakpaul.

This repository serves as an open solution for everyone to use, until TF/Keras integrates a proper solution into their framework(s).

Troubleshooting

Overloading the train_step method of tf.keras.Model became possible in TF 2.2; hence, this code is compatible with TF >= 2.2.
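If you are unsure whether your installed TF version is recent enough, a check like the one below can help. Version strings must be compared numerically: as plain strings, "2.10" sorts before "2.2". The helper name is hypothetical, and it only inspects the major and minor components.

```python
def tf_supports_train_step_override(version):
    """Return True if a TensorFlow version string such as "2.10.0" is >= 2.2."""
    # compare (major, minor) as integers, not as strings
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (2, 2)
```

For example, after `import tensorflow as tf`, call `tf_supports_train_step_override(tf.__version__)`.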

Also note that each TF version supports only certain Python versions. If you have trouble getting TF to work, try a different TF version or Python version.

For TF 1, I suggest using the AccumOptimizer implementation in the H2G-Net repository instead, which wraps the optimizer rather than overloading the train_step of the Model itself (a feature that requires TF >= 2.2).

How to cite

If you use this package in your research, please cite the following reference:

@software{andre_pedersen_2022_7023582,
  author       = {André Pedersen and
                  David Bouget},
  title        = {andreped/GradientAccumulator: v0.2.1},
  month        = aug,
  year         = 2022,
  publisher    = {Zenodo},
  version      = {v0.2.1},
  doi          = {10.5281/zenodo.7023582},
  url          = {https://doi.org/10.5281/zenodo.7023582}}

