LinearTransformer

PyTorch code for reproducing the experiments in the following papers:

[1] Transformers learn to implement preconditioned gradient descent for in-context learning. Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra
[2] Linear attention is (maybe) all you need (to understand Transformer optimization). Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Suvrit Sra, Ali Jadbabaie

'simple demonstration.ipynb':

  • Training a 3 layer Linear Transformer with SGD/Adam, covariates have identity covariance
  • Plotting test loss
  • Displaying matrices at end of training + distance to identity (similar to Figure 4 of [1])
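
The "distance to identity" in the last bullet can be read as a scale-invariant measure of how close a learned matrix is to a multiple of the identity. The sketch below assumes the normalized Frobenius distance min over α of ||M − αI||_F / ||M||_F; this definition is an assumption, so check [1] and the notebook for the exact one. The function name `dist_to_identity` is hypothetical, and NumPy is used for brevity rather than PyTorch.

```python
import numpy as np

def dist_to_identity(M):
    """Assumed metric: min over alpha of ||M - alpha*I||_F / ||M||_F."""
    M = np.asarray(M, dtype=float)
    d = M.shape[0]
    # Expanding ||M - alpha*I||_F^2 shows the minimizer is alpha = tr(M) / d.
    alpha = np.trace(M) / d
    # np.linalg.norm on a 2-D array defaults to the Frobenius norm.
    return np.linalg.norm(M - alpha * np.eye(d)) / np.linalg.norm(M)

# A scaled identity has distance 0; a traceless matrix has distance 1.
print(dist_to_identity(2.0 * np.eye(3)))
print(dist_to_identity([[0.0, 1.0], [1.0, 0.0]]))
```

Under this assumed definition the value lies between 0 and 1 and does not change when M is rescaled.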

'rotation demonstration-Adam.ipynb':

  • Training a 3 layer Linear Transformer with Adam, covariates have non-identity covariance (Adam requires about 100x more steps to converge compared to the identity covariance case)
  • Plotting test loss
  • Displaying matrices at end of training + distance to identity (similar to Figure 4 of [1])

'rotation demonstration-Adam-p0.ipynb' is similar to 'rotation demonstration-Adam.ipynb', but enforces that the top-left block of the P matrix is 0.

'variable_L_exp.ipynb':

  • Compares an n-layer linear Transformer against n-step Gradient Descent / Preconditioned Gradient Descent, for n = 1, 2, 3, 4, at fixed context length N = 20
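
As a rough illustration of the n-step baselines named above, here is a minimal sketch of plain and preconditioned gradient descent on an in-context least-squares problem. It assumes the loss (1/2N) * ||Xw - y||^2, a start at w = 0, and a fixed preconditioner `A`. The function name, step size, and data setup are hypothetical and not taken from the notebook; NumPy is used for brevity.

```python
import numpy as np

def n_step_gd(X, y, x_query, n_steps, eta=1.0, A=None):
    """Predict the label of x_query after n_steps of (preconditioned) GD from w = 0.

    X: (N, d) in-context covariates, y: (N,) labels.
    A: optional (d, d) preconditioner; None means plain gradient descent.
    """
    N, d = X.shape
    A = np.eye(d) if A is None else A
    w = np.zeros(d)
    for _ in range(n_steps):
        grad = X.T @ (X @ w - y) / N   # gradient of (1/2N) * ||Xw - y||^2
        w = w - eta * (A @ grad)
    return x_query @ w

rng = np.random.default_rng(0)
N, d = 20, 5
X = rng.standard_normal((N, d))
w_star = rng.standard_normal(d)
y = X @ w_star                          # noiseless labels
x_query = rng.standard_normal(d)

# Prediction error of plain GD for n = 1, 2, 3, 4 steps.
for n in (1, 2, 3, 4):
    pred = n_step_gd(X, y, x_query, n, eta=0.1)
    print(n, abs(pred - x_query @ w_star))

# Preconditioning with the inverse empirical covariance is a Newton step,
# which solves this noiseless least-squares problem in a single step.
A_newton = np.linalg.inv(X.T @ X / N)
newton_pred = n_step_gd(X, y, x_query, 1, eta=1.0, A=A_newton)
```

The gap between the two baselines is what makes the choice of preconditioner matter when comparing against a trained n-layer model.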

'variable_N_exp.ipynb':

  • Compares a 3-layer linear Transformer against 3-step Gradient Descent / Preconditioned Gradient Descent, for context lengths N = {2, 4, 6, ..., 20}

'linear_transformer.py' contains definition of the Linear Transformer model, along with some other handy functions.
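
For orientation, the linear attention update studied in [1] has the form Z_next = Z + (1/n) * P Z M (Z^T Q Z), where M masks out the query column and the prediction is read from the bottom-right entry with a sign flip. The sketch below illustrates that formula in NumPy; it is not the code in 'linear_transformer.py', and the names `linear_attention_layer` and `predict` are hypothetical.

```python
import numpy as np

def linear_attention_layer(Z, P, Q):
    """One residual linear-attention update: Z + (1/n) * P Z M (Z^T Q Z).

    Z: (d+1, n+1) prompt matrix. Columns 0..n-1 hold (x_i, y_i) and the
    last column holds (x_query, 0).
    """
    n = Z.shape[1] - 1
    M = np.eye(n + 1)
    M[n, n] = 0.0                      # mask out the query column
    return Z + (P @ Z @ M @ (Z.T @ Q @ Z)) / n

def predict(Z0, layers):
    """Apply each (P, Q) pair in turn and read off the prediction."""
    Z = Z0
    for P, Q in layers:
        Z = linear_attention_layer(Z, P, Q)
    return -Z[-1, -1]

rng = np.random.default_rng(0)
n, d = 20, 5
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d)
x_query = rng.standard_normal(d)

# Build the prompt matrix: covariates on top, labels in the last row.
Z0 = np.zeros((d + 1, n + 1))
Z0[:d, :n] = X.T
Z0[d, :n] = y
Z0[:d, n] = x_query

# With this P (top-left block = 0) and Q, one layer reproduces one gradient
# descent step (from w = 0, step size 1) on the in-context least-squares loss.
P = np.zeros((d + 1, d + 1))
P[d, d] = 1.0
Q = np.zeros((d + 1, d + 1))
Q[:d, :d] = -np.eye(d)

one_layer = predict(Z0, [(P, Q)])
one_gd_step = x_query @ (X.T @ y / n)
print(one_layer, one_gd_step)
```

The two printed values agree up to floating-point error, which is the sense in which a single layer can implement a gradient step.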

Quick Start

Setting: training a 3-layer linear transformer with Adam/SGD (with clipping); covariates have normal covariance

  1. Run train.ipynb - training linear transformer
  2. Run plot_loss.ipynb - generates loss vs iteration plot
  3. Run plot_stochastic_noise.ipynb - generates stochastic gradient noise histogram
  4. Run plot_condition_number.ipynb - generates robust condition number plot
  5. Run plot_smoothness_vs_gradnorm.ipynb - generates smoothness vs gradient norm plot
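
The setting above trains with clipping. As a minimal sketch of what that can mean, the snippet below rescales gradients by their global L2 norm. This is an assumption: the notebooks may clip differently, and the function name `clip_by_global_norm` and the threshold are hypothetical. In PyTorch, `torch.nn.utils.clip_grad_norm_` performs this kind of clipping in place.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their joint L2 norm is <= max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total <= max_norm:
        # Already within the threshold: return an unchanged copy.
        return [list(vec) for vec in grads]
    scale = max_norm / total
    return [[g * scale for g in vec] for vec in grads]

# The joint norm of these two vectors is 5, so both are scaled by 1/5.
clipped = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(clipped)
```

Clipping by the joint norm preserves the direction of the full gradient, which per-coordinate clipping does not.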

Hyperparameters

mode: method of generating samples (['normal', 'sphere', 'gamma'])

alg: Optimization algorithm (['sgd', 'adam'])

toclip: True to use the clipping algorithm, False otherwise

lr: learning rate

sd: random seed

max_iters: maximum number of iterations

n_layer: number of layers of linear transformer

N: number of in-context samples
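
One way to keep these hyperparameters together is a small validated config object. The sketch below is illustrative only: the class name `TrainConfig` and all default values are assumptions (in particular `max_iters=1000` and `sd=0` are placeholders), and the notebooks may simply set these as plain variables.

```python
from dataclasses import dataclass

MODES = ("normal", "sphere", "gamma")
ALGS = ("sgd", "adam")

@dataclass
class TrainConfig:
    mode: str = "normal"     # method of generating samples
    alg: str = "adam"        # optimization algorithm
    toclip: bool = True      # whether to use clipping
    lr: float = 0.02         # learning rate
    sd: int = 0              # random seed (placeholder default)
    max_iters: int = 1000    # maximum number of iterations (placeholder default)
    n_layer: int = 3         # number of layers of the linear transformer
    N: int = 20              # number of in-context samples

    def __post_init__(self):
        # Fail early on values outside the documented choices.
        if self.mode not in MODES:
            raise ValueError(f"mode must be one of {MODES}, got {self.mode!r}")
        if self.alg not in ALGS:
            raise ValueError(f"alg must be one of {ALGS}, got {self.alg!r}")
        if self.lr <= 0 or self.max_iters <= 0 or self.n_layer <= 0 or self.N <= 0:
            raise ValueError("lr, max_iters, n_layer and N must be positive")

cfg = TrainConfig(mode="sphere", alg="sgd", lr=5.0)
print(cfg)
```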

Learning Rates

Setting (n_layer=3)      SGDM (with clipping)    Adam (with clipping)
mode='normal', N=5       0.01                    0.005
mode='normal', N=20      0.02                    0.02
mode='sphere', N=20      5                       0.1
mode='gamma', N=20       0.02                    0.02
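
The table can be encoded as a lookup so a script picks the listed learning rate for a given setting. This is a sketch: the names `LEARNING_RATES` and `lookup_lr` are hypothetical, and mapping the "SGDM" column to alg='sgd' is an assumption.

```python
# Learning rates for n_layer=3 with clipping, keyed by (mode, N).
LEARNING_RATES = {
    ("normal", 5): {"sgd": 0.01, "adam": 0.005},
    ("normal", 20): {"sgd": 0.02, "adam": 0.02},
    ("sphere", 20): {"sgd": 5.0, "adam": 0.1},
    ("gamma", 20): {"sgd": 0.02, "adam": 0.02},
}

def lookup_lr(mode, N, alg):
    """Return the listed learning rate, or raise KeyError if the setting is not in the table."""
    try:
        return LEARNING_RATES[(mode, N)][alg]
    except KeyError:
        raise KeyError(
            f"no learning rate listed for mode={mode!r}, N={N}, alg={alg!r}"
        ) from None

print(lookup_lr("sphere", 20, "sgd"))
```

Settings outside the table raise an error rather than falling back to a guess, since no other combinations are listed.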
