ReZero for Deep Neural Networks

ReZero is All You Need: Fast Convergence at Large Depth; arXiv, March 2020.

Thomas Bachlechner*, Bodhisattwa Prasad Majumder*, Huanru Henry Mao*, Garrison W. Cottrell, Julian McAuley (* denotes equal contributions)

This repository contains the ReZero-Transformer implementation from the paper. It matches PyTorch's Transformer and can easily be used as a drop-in replacement.

Abstract

Deep networks have enabled significant performance gains across domains, but they often suffer from vanishing/exploding gradients. This is especially true for Transformer architectures, where depth beyond 12 layers is difficult to train without large datasets and computational budgets. In general, we find that inefficient signal propagation impedes learning in deep networks. In Transformers, multi-head self-attention is the main cause of this poor signal propagation. To facilitate deep signal propagation, we propose ReZero, a simple change to the architecture that initializes an arbitrary layer as the identity map, using a single additional learned parameter per layer. We apply this technique to language modeling and find that we can easily train ReZero-Transformer networks over a hundred layers. When applied to 12-layer Transformers, ReZero converges 56% faster on enwiki8. ReZero applies beyond Transformers to other residual networks, enabling 1,500% faster convergence for deep fully connected networks and 32% faster convergence for a ResNet-56 trained on CIFAR-10.
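The identity-map initialization described in the abstract can be illustrated with a minimal sketch. In the paper, each layer computes x_{i+1} = x_i + alpha_i * F(x_i), with the learned scalar alpha_i initialized to zero. The class name ReZeroBlock, the attribute name resweight, and the small feed-forward sublayer below are assumptions chosen for illustration; this is not the package's own code.

```python
import torch
import torch.nn as nn


class ReZeroBlock(nn.Module):
    """Illustrative residual block: x + alpha * F(x), with alpha starting at 0."""

    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        # The single additional learned parameter per layer, initialized to zero.
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # At initialization resweight == 0, so the block is the identity map.
        return x + self.resweight * self.sublayer(x)


# Hypothetical sublayer F; any shape-preserving module would do.
block = ReZeroBlock(nn.Sequential(nn.Linear(16, 16), nn.ReLU()))
x = torch.rand(4, 16)
y = block(x)

# The freshly initialized block passes its input through unchanged.
identity_at_init = torch.allclose(y, x)
```

Because every block starts as the identity, a stack of any depth initially passes its input through unchanged, and each layer's contribution grows only as training moves its resweight away from zero.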

Installation

Simply install from pip:

pip install rezero

PyTorch 1.4 or greater is required.
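One way to confirm the version requirement before installing is a short check like the following. It is a sketch that assumes torch.__version__ begins with numeric major and minor components (for example "1.4.0" or "2.1.0+cu118"); the variable names are illustrative.

```python
import torch

# Take the leading "major.minor" components of the version string.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])

# True when the installed PyTorch is 1.4 or newer.
meets_requirement = (major, minor) >= (1, 4)
print(f"PyTorch {torch.__version__}: requirement met = {meets_requirement}")
```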

Usage

We provide custom ReZero Transformer layers (RZTX).

For example, this will create a Transformer encoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXEncoderLayer

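# A single ReZero encoder layer, stacked six times by PyTorch's TransformerEncoder.
encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Input shape: (sequence length, batch size, d_model)
src = torch.rand(10, 32, 512)
out = transformer_encoder(src)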

This will create a Transformer decoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXDecoderLayer

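# A single ReZero decoder layer, stacked six times by PyTorch's TransformerDecoder.
decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
# memory: encoder output, shape (source length, batch size, d_model)
memory = torch.rand(10, 32, 512)
# tgt: decoder input, shape (target length, batch size, d_model)
tgt = torch.rand(20, 32, 512)
out = transformer_decoder(tgt, memory)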

Make sure the norm argument is left as None so that LayerNorm is not used in the Transformer.
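A small guard can make the norm setting explicit. The sketch below uses PyTorch's stock nn.TransformerEncoderLayer purely as a stand-in so that it runs without rezero installed; in practice an RZTXEncoderLayer instance would be passed instead. The smaller dimensions and the has_final_norm name are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Stand-in layer so the sketch runs without rezero installed;
# in practice pass an RZTXEncoderLayer instance here instead.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4)

# norm already defaults to None; passing it explicitly documents the intent.
encoder = nn.TransformerEncoder(layer, num_layers=2, norm=None)

# Check that no final LayerNorm was attached to the stack.
has_final_norm = encoder.norm is not None

# Input shape: (sequence length, batch size, d_model)
out = encoder(torch.rand(10, 8, 64))
```

The same check applies to nn.TransformerDecoder, which also stores its optional final normalization in a norm attribute.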

See https://pytorch.org/docs/master/nn.html#torch.nn.Transformer for details on how to integrate custom Transformer layers into PyTorch.

Tutorials

  1. Training a 128-layer ReZero Transformer on WikiText-2 language modeling
  2. Training a 10,000-layer ReZero neural network on CIFAR-10 data

Watch this space for more tutorials.

Citation

If you find rezero useful for your research, please cite our paper:

@inproceedings{BacMajMaoCotMcA20,
    title = "ReZero is All You Need: Fast Convergence at Large Depth",
    author = "Bachlechner, Thomas and
      Majumder, Bodhisattwa Prasad and
      Mao, Huanru Henry and
      Cottrell, Garrison W. and
      McAuley, Julian",
    booktitle = "arXiv",
    year = "2020",
    url = "https://arxiv.org/abs/2003.04887"
}
