Ring Attention - Pytorch

Explorations into Ring Attention, from Liu et al. at Berkeley AI.

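It shards the data along the sequence dimension (instead of the batch dimension) and applies ring reduce to the processing of the tiles of the attention matrix, flash attention style: each device keeps its own chunk of queries while the key / value chunks are passed around the ring of devices.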
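To make the idea concrete, here is a minimal single-process simulation. It assumes non-causal attention, equal-sized chunks, and a plain Python loop standing in for the sends between GPUs. `ring_attention_sim` is a hypothetical helper, not part of this library. It only shows how each rank folds in one key / value block per ring step using a running max and running sum (the online softmax trick from flash attention), so the result matches full attention.

```python
import torch

def ring_attention_sim(q, k, v, world_size):
    # Hypothetical single-process simulation of non-causal ring attention.
    # q, k, v: (seq, dim); seq must be divisible by world_size.
    scale = q.shape[-1] ** -0.5
    q_chunks = q.chunk(world_size, dim = 0)
    kv_chunks = list(zip(k.chunk(world_size, dim = 0), v.chunk(world_size, dim = 0)))

    outputs = []
    for rank in range(world_size):
        qr = q_chunks[rank]                                  # this rank's queries never move
        row_max = torch.full((qr.shape[0], 1), float('-inf'))
        row_sum = torch.zeros(qr.shape[0], 1)
        acc = torch.zeros(qr.shape[0], v.shape[-1])          # unnormalized weighted sum of values

        for step in range(world_size):
            # at ring step `step`, this rank holds the kv block that started on rank (rank - step)
            kc, vc = kv_chunks[(rank - step) % world_size]
            scores = (qr @ kc.t()) * scale

            # online softmax: rescale earlier partial results to the new running max
            new_max = torch.maximum(row_max, scores.amax(dim = -1, keepdim = True))
            correction = torch.exp(row_max - new_max)
            p = torch.exp(scores - new_max)

            row_sum = row_sum * correction + p.sum(dim = -1, keepdim = True)
            acc = acc * correction + p @ vc
            row_max = new_max

        outputs.append(acc / row_sum)                        # normalize once, after the last ring pass

    return torch.cat(outputs, dim = 0)

torch.manual_seed(0)
q, k, v = (torch.randn(16, 8) for _ in range(3))
reference = torch.softmax((q @ k.t()) * 8 ** -0.5, dim = -1) @ v
assert torch.allclose(ring_attention_sim(q, k, v, world_size = 4), reference, atol = 1e-5)
```

No rank ever materializes the full attention matrix. Each one only holds a (chunk × chunk) tile of scores at a time, which is what lets the sequence length grow with the number of devices.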

I believe this is being used for the 1 to 10 million token context of the latest Gemini, or at least some form of it; the other possibility would be unpublished improvements on top of RMT (Recurrent Memory Transformer).

The repository also contains the logic for Striped Attention, a follow-up paper that permutes the sequence for better workload balancing in autoregressive (causal) transformers.
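As a sketch of the permutation Striped Attention relies on, assuming the sequence length divides evenly by the number of ranks: instead of each rank holding a contiguous block of tokens, rank r holds tokens r, r + world_size, r + 2 · world_size, and so on. `stripe` and `unstripe` are hypothetical names for illustration, not this library's functions.

```python
import torch

def stripe(x, world_size):
    # x: (batch, seq, ...); seq must be divisible by world_size.
    # After striping, contiguous chunk r holds tokens r, r + world_size, r + 2 * world_size, ...
    b, n, *rest = x.shape
    x = x.reshape(b, n // world_size, world_size, *rest)
    return x.transpose(1, 2).reshape(b, n, *rest)

def unstripe(x, world_size):
    # inverse permutation: restores the original token order
    b, n, *rest = x.shape
    x = x.reshape(b, world_size, n // world_size, *rest)
    return x.transpose(1, 2).reshape(b, n, *rest)

tokens = torch.arange(8).reshape(1, 8)
print(stripe(tokens, world_size = 2))   # tensor([[0, 2, 4, 6, 1, 3, 5, 7]])
```

With contiguous sharding under a causal mask, the first rank's queries can attend to almost nothing while the last rank does the most work. Striping spreads every rank's tokens across the whole sequence, so each rank gets a roughly equal share of the unmasked attention.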

Appreciation

  • A16Z Open Source AI Grant Program for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

Install

$ pip install ring-attention-pytorch

Usage

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
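    auto_shard_seq = True,   # shard the sequence across the ring's ranks automatically
    ring_attn = True,        # use ring attention
    ring_seq_size = 512      # sequence length held by each rank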
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)

assert attended.shape == tokens.shape

Test

$ python assert.py

Todo

  • make it work with derived causal mask based on rank and chunk sizes

  • modify flash attention to output intermediates and figure out backwards with recompute and ring passes

  • functions for splitting the sequence evenly among ranks, either within attention function, or in the external ring transformer wrapper

  • basic test case with two processes and check for equivalent output and gradients

  • testing

    • make sure key padding mask works
    • make sure causal mask works
    • rotary embeddings, with proper key/value offset depending on ring rank
  • striped attention

    • add the permuting logic before and after transformer
    • add causal masking logic - account for sub-bucketing by flash attention
  • fix issue with ring attention when flash buckets > 1

  • move flash attention back to key / value column traversal on outer loop and save on ring communication

    • backwards
    • forwards
  • fix rotary positions for striped ring attention when flash buckets > 1

  • allow for variable ring passes per layer, for local -> global attention in ring transformer as one goes up the layers.

  • when doing ring passes, alternate between designated send and receive buffers

  • instead of max ring passes, be able to specify lookback in terms of sequence length, and derive the number of flash attention buckets + ring passes from that

  • ability to have ring size < world size, sharding the batch and sequence, and doing ring reduce with the correct set of ranks

  • add flash attention kernel version in the presence of cuda

    • for forwards, use modified Triton flash attention forwards that outputs row sums, maxes, and exponentiated weighted sum
    • for backwards, use Tri's flash attention kernels, accumulate dq, dk, dv across rings
    • refactor to have naive ring+flash attention work with (batch, seq, head, dim)
    • handle key padding mask for forwards by translating mask to bias
    • figure out how Tri handles key padding mask for backwards
    • scale output of flash attention forwards on the last ring pass reduce
    • verify backwards works on an A100 on RunPod
    • dk, dv need to be float32, while kv needs to be float16. see if both can be reinterpreted as int before being stacked and ring passed all in one go, then reinterpreted back to float32 and float16
    • prevent an unnecessary tl.load on the first ring pass
    • cuda backwards pass must have same dq, dk, dv as naive
  • fix naive flash attention backwards

  • validate cuda causal and striped ring attention works

  • find a machine with 8 GPUs and test with a quarter million tokens first

  • think about how to craft a special Dataset that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training

  • add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl

  • use batch_isend_irecv when the key padding mask also needs ring exchange, but not a big priority

  • figure out how to pytest distributed pytorch

  • use sdp context manager to validate when it is possible to use ring_flash_attn_cuda, otherwise assert out
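For the todo items about a causal mask derived from rank and chunk sizes (and rotary positions offset by ring rank), here is a minimal sketch. It assumes plain contiguous sharding and ignores the extra sub-bucketing done inside flash attention. `chunk_positions` and `ring_causal_mask` are hypothetical helpers.

```python
import torch

def chunk_positions(rank, chunk_size):
    # absolute token positions held by `rank` under contiguous sequence sharding
    return rank * chunk_size + torch.arange(chunk_size)

def ring_causal_mask(rank, step, world_size, chunk_size):
    # the kv block seen at ring step `step` originated on rank (rank - step) % world_size
    kv_rank = (rank - step) % world_size
    q_pos = chunk_positions(rank, chunk_size)
    k_pos = chunk_positions(kv_rank, chunk_size)
    return q_pos[:, None] >= k_pos[None, :]   # True where attention is allowed
```

When the key / value block comes from a later rank, the mask is all False, so that tile (and its compute) could be skipped. The same absolute positions are what rotary embeddings would need as their per-rank offsets.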
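For the item about ring passing float32 dk, dv together with float16 kv in one go, a sketch under the following assumption: a bit-level reinterpretation (`Tensor.view(dtype)`, not a value cast) into a shared byte buffer, rather than the int type the todo mentions. `pack` and `unpack` are hypothetical, and the actual send / receive (for example with `torch.distributed`) is left out.

```python
import torch

def pack(*tensors):
    # reinterpret each tensor's raw bytes as uint8 so mixed dtypes can share one buffer
    metas = [(t.shape, t.dtype) for t in tensors]
    flat = torch.cat([t.contiguous().reshape(-1).view(torch.uint8) for t in tensors])
    return flat, metas

def unpack(flat, metas):
    # slice the byte buffer back apart and reinterpret each slice as its original dtype
    out, offset = [], 0
    for shape, dtype in metas:
        nbytes = shape.numel() * torch.empty((), dtype = dtype).element_size()
        out.append(flat[offset:offset + nbytes].view(dtype).reshape(shape))
        offset += nbytes
    return out
```

Because only the bits are reinterpreted, the round trip is exact, and a single buffer means one communication call per ring pass instead of one per tensor.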

Citations

@article{Liu2023RingAW,
    title    = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author   = {Hao Liu and Matei Zaharia and Pieter Abbeel},
    journal  = {ArXiv},
    year     = {2023},
    volume   = {abs/2310.01889},
    url      = {https://api.semanticscholar.org/CorpusID:263608461}
}
@article{Brandon2023StripedAF,
    title   = {Striped Attention: Faster Ring Attention for Causal Transformers},
    author  = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.09431},
    url     = {https://api.semanticscholar.org/CorpusID:265220849}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}

Contributors

lucidrains
