
Comments (9)

lhao499 avatar lhao499 commented on May 25, 2024 1

Good question. We compared against JAX code rather than Triton code. Triton is more ideal and should be faster, since it allows one to control data movement directly. We are planning to add a Triton-based implementation, but there is no concrete ETA yet.

from ringattention.

lhao499 avatar lhao499 commented on May 25, 2024

Just FYI, if you are interested in optimizing speed, a minimal-effort way would be to replace blockwise_compute_attn with a jax-triton call to the Triton FlashAttention kernel, while keeping the other blockwise operations, including blockwise_compute_ffn and blockwise_cross_entropy, as they are.
This should provide both the memory savings of the Blockwise Parallel Transformer and the speed gain of FlashAttention.
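For concreteness, here is a rough sketch of that split in JAX. It is illustrative only: flash_attention, blockwise_ffn, and layer are hypothetical names and shapes, not the repo's API, and the attention call is stubbed with plain JAX where a real swap would launch the Triton FlashAttention kernel via jax-triton. The FFN stays in JAX and is scanned over token blocks with rematerialization, which is where the blockwise memory saving comes from.

```python
import jax
import jax.numpy as jnp

def flash_attention(q, k, v):
    # Stand-in for a fused kernel call (e.g. a jax-triton launch of the
    # Triton FlashAttention kernel); plain JAX here so the sketch runs.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

def blockwise_ffn(x, w1, w2, block_size):
    # Feed-forward applied one block of tokens at a time, with rematerialization,
    # so only one block of FFN activations is live at any point.
    n_blocks = x.shape[0] // block_size
    blocks = x.reshape(n_blocks, block_size, x.shape[-1])

    @jax.checkpoint  # recompute block activations in the backward pass
    def ffn_block(carry, blk):
        return carry, jax.nn.gelu(blk @ w1) @ w2

    _, out = jax.lax.scan(ffn_block, None, blocks)
    return out.reshape(x.shape[0], -1)

def layer(x, params, block_size=512):
    q, k, v = (x @ params[n] for n in ("wq", "wk", "wv"))
    x = x + flash_attention(q, k, v)                                     # kernel-backed attention
    return x + blockwise_ffn(x, params["w1"], params["w2"], block_size)  # blockwise FFN in plain JAX
```

Under jit the compiler may fuse some of this on its own, but the scan over blocks bounds peak FFN activation memory regardless of what it chooses.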

from ringattention.

conceptofmind avatar conceptofmind commented on May 25, 2024

Just FYI, if you are interested in optimizing speed, a minimal-effort way would be to replace blockwise_compute_attn with a jax-triton call to the Triton FlashAttention kernel, while keeping the other blockwise operations, including blockwise_compute_ffn and blockwise_cross_entropy, as they are. This should provide both the memory savings of the Blockwise Parallel Transformer and the speed gain of FlashAttention.

Hi @lhao499 ,

Thank you for your profound research on BPT and CoH!

I am currently working on a PyTorch rewrite of BPT.

Do you mind clarifying the part where you say you can replace blockwise_compute_attn with Triton Flash Attention? Do you mean you can do a direct one-to-one replacement by just swapping blockwise_compute_attn with Flash Attention? Or do you mean that the Triton version of Flash Attention needs to be rewritten in a way where it follows the same structure as blockwise_compute_attn?

I will likely open up a new issue soon to further discuss the PyTorch reimplementation.

Thank you again for all of your great work!

Best,

Enrico

from ringattention.

lhao499 avatar lhao499 commented on May 25, 2024

Hi Enrico, thanks for your interest.
Yes, I think blockwise_compute_attn can be swapped with Triton flash attention; nanoGPT has Triton flash attention integrated, and there are also two other Triton implementations 1 2. For the remaining blockwise operations, including blockwise_compute_ffn and blockwise_cross_entropy, a simpler way would be to continue using JAX/PyTorch. The reason is that it takes some effort to fuse the blockwise FFN with flash attention for best performance. In addition, with the blockwise parallel formulation the compiler can do some automatic fusion (albeit not necessarily the most optimal one).
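As an illustration of keeping blockwise_cross_entropy in plain framework code, here is a hedged JAX sketch (names, shapes, and the chunking scheme are illustrative assumptions, not the repo's implementation): the LM-head logits are computed and reduced one chunk of positions at a time, so a full (seq_len, vocab) logits matrix is never materialized.

```python
import jax
import jax.numpy as jnp

def blockwise_cross_entropy(hidden, embedding, labels, block_size=1024):
    # Token-level cross-entropy computed one chunk of positions at a time,
    # so a (block_size, vocab) logits chunk is the largest intermediate.
    seq_len = hidden.shape[0]
    hidden_blocks = hidden.reshape(seq_len // block_size, block_size, -1)
    label_blocks = labels.reshape(seq_len // block_size, block_size)

    @jax.checkpoint  # recompute chunk logits in the backward pass instead of storing them
    def chunk_loss(total, inputs):
        h, y = inputs
        logits = h @ embedding.T                      # (block_size, vocab)
        logp = jax.nn.log_softmax(logits, axis=-1)
        nll = -jnp.take_along_axis(logp, y[:, None], axis=-1).squeeze(-1)
        return total + nll.sum(), None

    total, _ = jax.lax.scan(
        chunk_loss, jnp.zeros((), hidden.dtype), (hidden_blocks, label_blocks)
    )
    return total / seq_len
```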

from ringattention.

conceptofmind avatar conceptofmind commented on May 25, 2024

Thank you for the additional clarity. It is good to know that Flash Attention can be used as a drop-in replacement without having to do any rewrite.

from ringattention.

nalzok avatar nalzok commented on May 25, 2024

Hi @lhao499, congratulations on this great piece of work!

We compared with Jax code instead of Triton code.

Can you point me to the FlashAttention implementation in JAX in this repo? It appears to me that you only implemented the memory-efficient transformer, but not FlashAttention.

https://github.com/lhao499/blockwise-parallel-transformer/blob/2f42ba1bc73e91eb6315759706029b1a67b54e6d/README.md?plain=1#L87

there are also two other Triton implementations 1 2.

In case it's relevant, here is a third one: https://github.com/jax-ml/jax-triton/blob/main/examples/fused_attention.py

from ringattention.

lhao499 avatar lhao499 commented on May 25, 2024

FlashAttention uses the same method as memory-efficient attention to reduce memory cost, but is implemented with low-level kernels to achieve a further speedup. In this repo, the code is not implemented with low-level kernels, as we mainly focus on reducing memory cost; using low-level kernels should be straightforward with Triton or Pallas.

If you would like to use low-level kernels, you can use jax-triton's Triton or Pallas API on GPUs, or use the Pallas API with Mosaic lowering [1][2] on TPUs.
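For reference, this is roughly what the Pallas calling convention looks like; it is only the minimal elementwise example shape (assumed names, not a FlashAttention kernel), and a real attention kernel would additionally tile Q/K/V with BlockSpecs and keep running softmax statistics inside the kernel body.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_add_kernel(x_ref, y_ref, o_ref):
    # Inside the kernel, refs are read and written with array-style indexing.
    o_ref[...] = x_ref[...] * 2.0 + y_ref[...]

def scale_add(x, y):
    return pl.pallas_call(
        scale_add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(scale_add(x, x))  # lowers via Triton on GPU and via Mosaic on TPU
```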

from ringattention.

ZeldaHuang avatar ZeldaHuang commented on May 25, 2024

Hi @lhao499, I saw you compared BPT with memory-efficient attention, and mentioned that FlashAttention uses the same method as memory-efficient attention.
However, in Appendix B.5 of the FlashAttention paper, the authors compare their implementation with memory-efficient attention and point out that FlashAttention has a 2-4x speedup and a smaller total memory requirement.
So I'm curious about the real speed/memory comparison between BPT and FlashAttention.

I also found that the structure of FlashAttention seems to conflict with BPT.
Here is the pseudocode of FlashAttention:
[screenshot: FlashAttention pseudocode, outer loop over K,V blocks]
compared with BPT:
[screenshot: BPT pseudocode, outer loop over Q blocks]
The outer loop of BPT scans the Q blocks, like memory-efficient attention, so it can compute the blockwise FFN after the inner loop is done. The outer loop of FlashAttention scans the K,V blocks; it can reduce the total memory requirement by incrementally updating the output instead of storing copies of the output, as mentioned in Appendix B.5:

FlashAttention instead incrementally updates the output (Algorithm 1 line 12) after processing each block, so only one copy of the output is needed (instead of 𝐾 copies for 𝐾 blocks). This means that FlashAttention has smaller total memory requirement compared to Rabe and Staats [66].
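To make the scanning-order point concrete, here is a plain-JAX sketch of the Q-outer / K,V-inner structure under discussion (single head, no masking; names and block sizes are illustrative, not the repo's blockwise_compute_attn): the outer scan walks over Q blocks, and for each one an inner scan over K,V blocks keeps a running max and normalizer, so the output block is updated incrementally and only one block of scores is ever live.

```python
import jax
import jax.numpy as jnp

def blockwise_attention(q, k, v, q_block=256, kv_block=256):
    # Single-head attention: outer scan over Q blocks, inner scan over K,V blocks
    # with a running (output, normalizer, max) accumulator (online softmax).
    d = q.shape[-1]
    q_blocks = q.reshape(-1, q_block, d) / jnp.sqrt(d)
    k_blocks = k.reshape(-1, kv_block, d)
    v_blocks = v.reshape(-1, kv_block, d)

    def per_q_block(_, qb):
        init = (
            jnp.zeros((q_block, d), q.dtype),         # unnormalized output accumulator
            jnp.zeros((q_block,), q.dtype),           # running softmax normalizer
            jnp.full((q_block,), -jnp.inf, q.dtype),  # running row max
        )

        def per_kv_block(carry, kv):
            acc, denom, row_max = carry
            kb, vb = kv
            scores = qb @ kb.T                                   # (q_block, kv_block)
            new_max = jnp.maximum(row_max, scores.max(axis=-1))
            correction = jnp.exp(row_max - new_max)              # rescale old accumulator
            p = jnp.exp(scores - new_max[:, None])
            acc = acc * correction[:, None] + p @ vb
            denom = denom * correction + p.sum(axis=-1)
            return (acc, denom, new_max), None

        (acc, denom, _), _ = jax.lax.scan(per_kv_block, init, (k_blocks, v_blocks))
        return None, acc / denom[:, None]

    _, out = jax.lax.scan(per_q_block, None, q_blocks)
    return out.reshape(q.shape)
```

With Q in the outer loop (as in memory-efficient attention, BPT, and later FlashAttention-2), each output block is finalized once its inner K,V scan completes, which is what lets BPT apply the blockwise FFN to that block right afterwards.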

from ringattention.

lhao499 avatar lhao499 commented on May 25, 2024

Regarding scanning order:
I'm not exactly sure why FlashAttention chose to scan K,V blocks in the outer loop, but FlashAttention-2 chose to scan Q blocks in the outer loop, as done in memory-efficient attention and BPT. I will need to take a further look.

Regarding speed:
FlashAttention provides an efficient CUDA implementation of memory-efficient attention. At the time of implementing BPT, there were no public kernel APIs for TPU and we were mainly using TPUs, so we compared BPT with memory-efficient attention.
BPT is compatible with the FlashAttention kernel optimization. Recently, JAX has added Pallas for kernel APIs; it would be cool to try it out.

from ringattention.
