Momentum ResNets: Drop-in replacement for any ResNet with reduced memory footprint

This repository hosts Python code for Momentum ResNets.

See the documentation, our ICML 2021 paper and a 5 min presentation.

Model

Official library for using Momentum Residual Neural Networks [1]. These models extend any residual architecture (for instance, they also work with Transformers) to a larger class of deep learning models that consume less memory. They can be initialized with the same weights as a pretrained ResNet and are promising in fine-tuning applications.

Installation

pip

To install momentumnet, you first need to install its dependencies:

$ pip install numpy matplotlib torch

Then install momentumnet with pip:

$ pip install momentumnet

or to get the latest version of the code:

$ pip install git+https://github.com/michaelsdr/momentumnet.git#egg=momentumnet

If you do not have admin privileges on the computer, use the --user flag with pip. To upgrade, use the --upgrade flag provided by pip.

check

To check that the installation worked, run:

$ python -c 'import momentumnet'

and it should not give any error message.
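
For a slightly more detailed check, the following minimal sketch reports which of the packages named above can be found on the import path. The helper name is_installed is hypothetical and is not part of momentumnet; it relies only on the standard library.

```python
import importlib.util


def is_installed(package_name):
    """Return True if the top-level package can be found on the import path."""
    # is_installed is a hypothetical helper, not part of momentumnet.
    return importlib.util.find_spec(package_name) is not None


# Report the status of momentumnet and the dependencies installed above.
for name in ("numpy", "matplotlib", "torch", "momentumnet"):
    status = "found" if is_installed(name) else "MISSING"
    print(f"{name}: {status}")
```

Any package reported as MISSING can be installed with the pip commands from the previous section.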

Quickstart

The main class is MomentumNet. It creates a Momentum ResNet whose forward equations can be reversed in closed form, enabling training without standard, memory-consuming backpropagation. This trades extra computation for lower memory use.

To get started, you can create a toy Momentum ResNet by specifying the functions f for the forward pass and the value of the momentum term, gamma.

>>> from torch import nn
>>> from momentumnet import MomentumNet
>>> hidden = 8
>>> d = 500
>>> function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
>>> mresnet = MomentumNet([function,] * 10, gamma=0.9)
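
To illustrate what "reversed in closed form" means, here is a minimal pure-Python sketch of the momentum update rule from the paper, v <- gamma * v + (1 - gamma) * f(x), x <- x + v, together with its exact inverse. It is an illustration under stated assumptions, not the library's implementation: the function f below is a hypothetical stand-in for a residual block, and the helper names momentum_forward and momentum_inverse are made up for this example.

```python
import math


def f(x):
    # Hypothetical residual function standing in for a small network.
    return [math.tanh(xi) for xi in x]


def momentum_forward(x, v, gamma, n_layers):
    """Apply n_layers momentum residual updates to the state (x, v)."""
    for _ in range(n_layers):
        fx = f(x)
        v = [gamma * vi + (1 - gamma) * fi for vi, fi in zip(v, fx)]
        x = [xi + vi for xi, vi in zip(x, v)]
    return x, v


def momentum_inverse(x, v, gamma, n_layers):
    """Recover the initial (x, v) from the final state, one layer at a time."""
    for _ in range(n_layers):
        x = [xi - vi for xi, vi in zip(x, v)]  # undo x <- x + v
        fx = f(x)
        v = [(vi - (1 - gamma) * fi) / gamma for vi, fi in zip(v, fx)]  # undo the v update
    return x, v


x0 = [0.5, -1.0, 2.0]
v0 = [0.0, 0.0, 0.0]
xN, vN = momentum_forward(x0, v0, gamma=0.9, n_layers=10)
x_rec, v_rec = momentum_inverse(xN, vN, gamma=0.9, n_layers=10)

# Reconstruction error, zero up to floating-point rounding.
max_err = max(abs(a - b) for a, b in zip(x0 + v0, x_rec + v_rec))
print(max_err)
```

Because every intermediate state can be recomputed from the final pair (x, v), nothing needs to be stored during the forward pass. The inverse divides by gamma, so this sketch requires gamma to be nonzero.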

Momentum ResNets are a drop-in replacement for ResNets

We can transform a ResNet into a MomentumNet with the same parameters in two lines of code. For instance, the following code instantiates a Momentum ResNet with the weights of a ResNet-101 pretrained on ImageNet. We set "use_backprop" to False so that activations are not saved during the forward pass, which reduces memory consumption.

>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> from torchvision.models import resnet101
>>> resnet = resnet101(pretrained=True)
>>> mresnet101 = transform_to_momentumnet(resnet, gamma=0.9, use_backprop=False)

Importantly, this method also works with PyTorch's Transformer module, by specifying the residual layers to be turned into their Momentum version.

>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> transformer = torch.nn.Transformer(num_encoder_layers=6, num_decoder_layers=6)
>>> mtransformer = transform_to_momentumnet(transformer, sub_layers=["encoder.layers", "decoder.layers"], gamma=0.9,
...                                          use_backprop=False, keep_first_layer=False)

This instantiates a Momentum Transformer with the same weights as the original Transformer.

Memory savings when applying Momentum ResNets to Transformers

Here is a short tutorial showing the memory gains when using Momentum Transformers.
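
As a rough back-of-the-envelope illustration of where the savings come from, the sketch below compares how many bytes must be kept for the backward pass in two idealized cases. It assumes that standard backpropagation stores the input of every residual block, that a reversible Momentum network keeps only the final state and velocity, and that values are 4-byte floats; activations inside each block are ignored. The function stored_activation_bytes is hypothetical and does not measure actual memory usage.

```python
def stored_activation_bytes(depth, batch_size, features, reversible, bytes_per_value=4):
    """Rough count of bytes kept for the backward pass (idealized, see assumptions above)."""
    tensor_bytes = batch_size * features * bytes_per_value
    if reversible:
        # Only the final state x and the final velocity v are kept.
        return 2 * tensor_bytes
    # Standard backprop keeps the input of every residual block.
    return depth * tensor_bytes


standard = stored_activation_bytes(depth=10, batch_size=32, features=500, reversible=False)
momentum = stored_activation_bytes(depth=10, batch_size=32, features=500, reversible=True)
print(standard, momentum)  # 640000 128000
```

In this idealized count the standard cost grows linearly with depth, while the reversible cost stays constant.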

Dependencies

These are the dependencies to use momentumnet:

  • numpy (>=1.8)
  • matplotlib (>=1.3)
  • torch (>=1.9)
  • memory_profiler
  • vit_pytorch

Cite

If you use this code in your project, please cite:

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré
Momentum Residual Neural Networks
Proceedings of the 38th International Conference on Machine Learning, PMLR 139:9276-9287
https://arxiv.org/abs/2102.07870
