
Comments (12)

tridao commented on August 27, 2024

We compared attention time (softmax(QK^T)V) vs scan time, without the linear projection.
The dimensions are different, e.g. in a 1.3B model Transformers would typically have Q, K, V of hidden dimensions 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096 to be fair / favorable to attention.
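
For context, a minimal sketch of what this kind of attention-only timing could look like; the 32-head × 64-dim split, the use of `F.scaled_dot_product_attention`, and the CUDA-event timing are assumptions for illustration, not the authors' benchmark script:

```python
# Rough sketch: time only softmax(QK^T)V with Q, K, V of total dim 2048
# (assumed split into 32 heads of 64 dims), no linear projections.
import torch
import torch.nn.functional as F

batch, length, n_heads, head_dim = 1, 2048, 32, 64
q, k, v = (torch.randn(batch, n_heads, length, head_dim,
                       device="cuda", dtype=torch.bfloat16) for _ in range(3))

for _ in range(10):  # warmup
    F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.cuda.synchronize()

start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)
end.record()
torch.cuda.synchronize()
print(f"attention fwd: {start.elapsed_time(end) / 100:.3f} ms")

# The matching scan timing would use an input of dimension 4096 (expand=2).
```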


apoorv2904 commented on August 27, 2024

@tridao selective_scan_fn(u, delta, A, B, C, D) resulted in a speedup, but it's still significantly slower for N=16.

[screenshot of benchmark timings]


albertfgu commented on August 27, 2024

We decided to leave those linear projections out because they are orthogonal to the main "sequence mixing mechanism" (attention vs scan) that is of interest to benchmark. You're right that the comparisons become slightly harder to control (e.g. what model dimension to use is fair?), but we chose a setting that seemed reasonable to us. No matter what, the timings will only be off by a small constant factor with any other "reasonable" setting of dimensions, which is dwarfed by the linear vs quadratic complexity.


tridao commented on August 27, 2024

Q, K, V are bf16 for attention.
u, delta, B, C, z are bf16, A and D are fp32 for scan.


xiayuqing0622 commented on August 27, 2024

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

I wrote a simple script to compare these two components (scan and FlashAttention-2 with causal masking) and tested it on an A100. As instructed, the input dim of the scan is 4096 and the input dim of flash attention is 2048 (32 heads × 64 head dim). However, the scan is much slower than FlashAttention-2 (fwd: scan 0.25 ms vs. flash2 0.14 ms; fwd+bwd: scan 1.25 ms vs. flash2 0.59 ms). Did I get any settings wrong?

```python
import torch
import time

test_bwd=False
batch, length, dim, d_state =1, 2048, 2048, 16
from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn
u = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
A = torch.randn(dim*2, d_state).to("cuda").requires_grad_(True)
B = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
C = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
D = torch.randn(dim*2).to("cuda").requires_grad_(True)
z = torch.randn(batch, dim*2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta_bias = torch.randn(dim*2).to("cuda").requires_grad_(True)
doutssm = torch.randn(batch, dim*2, length).to("cuda").to(torch.bfloat16)
ssm = SelectiveScanFn.apply

for i in range(10):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
print(time.time() - start)

from flash_attn import flash_attn_func

dim_head = 64
n_heads = dim//dim_head
q = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
k = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
v = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
dout = torch.randn(batch, length, n_heads,dim_head).to("cuda").to(torch.bfloat16)

for i in range(10):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
print(time.time() - start)
```


albertfgu commented on August 27, 2024

Please format your code with triple backticks followed by "python": ``` python

The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double check our script to see if anything is different than yours.


xiayuqing0622 commented on August 27, 2024

> We compared attention time (softmax(QK^T)V) vs scan time, without the linear projection. The dimensions are different, e.g. in a 1.3B model Transformers would typically have Q, K, V of hidden dimensions 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096 to be fair / favorable to attention.

And what data type did you use? When I try to run the scan in fp16, it always raises this error:
```
Traceback (most recent call last):
  File "/home/yuqing/mamba/run.py", line 29, in
    y = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, True)
RuntimeError: Expected weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```
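
For later readers: this dtype check is on A (the kernel expects it in fp32), which is consistent with the reply below. A minimal sketch of an input setup that runs with half-precision activations, assuming the public `selective_scan_fn` wrapper and the same shapes as the script above:

```python
# Sketch: half-precision u, delta, B, C, z; A and D kept in fp32.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, length, d_state = 1, 4096, 2048, 16
half = torch.float16  # bf16 behaves the same way here

u = torch.randn(batch, dim, length, device="cuda", dtype=half)
delta = torch.randn(batch, dim, length, device="cuda", dtype=half)
B = torch.randn(batch, 1, d_state, length, device="cuda", dtype=half)
C = torch.randn(batch, 1, d_state, length, device="cuda", dtype=half)
z = torch.randn(batch, dim, length, device="cuda", dtype=half)
A = -torch.rand(dim, d_state, device="cuda")  # fp32
D = torch.ones(dim, device="cuda")            # fp32
delta_bias = torch.randn(dim, device="cuda")  # fp32

y = selective_scan_fn(u, delta, A, B, C, D, z=z,
                      delta_bias=delta_bias, delta_softplus=True)
```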


xiayuqing0622 commented on August 27, 2024

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

it works now, thank you!


xiayuqing0622 commented on August 27, 2024

> Please format your code with triple backticks followed by "python": ``` python
>
> The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double check our script to see if anything is different than yours.

Sorry for the formatting issue. I've re-edited the code above. I also tested with D=1024: for fwd it's scan 0.13 ms vs. flash 0.08 ms, and for fwd+bwd it's scan 0.71 ms vs. flash 0.35 ms.


apoorv2904 commented on August 27, 2024

Hi @tridao and @albertfgu, first of all, thank you for releasing both the FlashAttention (v1 and v2) and Mamba source code, including the CUDA kernels!

I too had trouble reproducing the benchmarks, in particular against FlashAttention v2. I tried several settings (D=768, 1024, 2048), and for N/d_state=16, flash attention was significantly faster than the scan. Only at N=4 do I start to see the curves reported in the paper. In particular, for N=16 the scan is about 2x slower.

Following are the times in ms that I see.
[screenshot: table of timings in ms]

It would be immensely useful if you could spare some time to review the Mamba benchmark below or provide a few more details to reproduce it. Thanks @xiayuqing0622 for the starting code.

Environment:
- A100 80 GB
- PyTorch 2.1 / CUDA 11.8

```python
def benchmark_mamba(batch, head, length, dim_head, d_state):
    from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn
    from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda
    from einops import rearrange, repeat

    d_model = dim_head * head
    expand = 2
    d_inner = d_model * expand
    device = "cuda"

    # S4D real initialization
    A = repeat(
        torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32

    x = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    z = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta_bias = torch.randn(d_inner).to("cuda").requires_grad_(True)
    A = -torch.exp(A_log.float())  # (d_inner, d_state)
    B = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    C = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    D = torch.ones(d_inner, device=device)  # Keep in fp32
    delta_softplus = True

    ms = triton.testing.do_bench(
        lambda: selective_scan_cuda.fwd(
            x, delta, A, B, C, D, z, delta_bias, delta_softplus
        ),
        warmup=100,
    )
    return ms
```

The full code is below, but please feel free to ignore the rest. Here is the code:

```python
import itertools
from math import sqrt

import pandas
import torch
from tqdm import tqdm
import triton

from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func


def get_inputs(B, H, L, E=64, ret_padding_mask=False, dtype=torch.float32):
    q = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    k = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    v = torch.rand((B, H, L, E), device="cuda", dtype=dtype)

    input_lengths = torch.randint(1, L, (B,), device=q.device).long()
    input_lengths[-1] = L
    padding_mask = torch.zeros((B, L), dtype=q.dtype, device=q.device)
    padding_mask[
        (
            torch.arange(padding_mask.shape[0], device=padding_mask.device),
            input_lengths - 1,
        )
    ] = 1
    padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
    if not ret_padding_mask:
        padding_mask = None
    return (q, k, v), padding_mask
    
def flash_attn_forward(queries, keys, values, padding_mask=None):
    qkv = torch.stack([queries, keys, values], dim=2)
    qkv = qkv.permute(0, 3, 2, 1, 4)
    B, T, _, H, D = qkv.shape
    scale = 1.0 / sqrt(D)

    if padding_mask is not None:
        # unpad_input expects True to correspond to valid indices and False to invalid
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, ~padding_mask)
        packed_res = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_q_lens,
            max_s,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = pad_input(packed_res, indices, B, T)
        res = res.transpose(1, 2)
    else:
        res = flash_attn_qkvpacked_func(
            qkv,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = res.transpose(1, 2)  # B x T x H x D -> B x H x T x D
    return res

    
def benchmark_flash(q, k, v, padding_mask):
    dim_E = q.shape[-1]
    H = q.shape[1]
    E = dim_E * H
    ms = triton.testing.do_bench(
        lambda: flash_attn_forward(q, k, v, padding_mask=padding_mask), warmup=100
    )
    return ms


if __name__ == "__main__":
    batch_sizes = [16]
    heads = [12, 16, 32]
    time_steps = [1000, 1600, 3200, 6400]
    get_padding_masks = [True, False]
    d_states = [2, 4, 8, 16]
    dtypes = [torch.bfloat16]
    E = 64

    results = []

    for B, H, L, pm, dtype in tqdm(
        itertools.product(batch_sizes, heads, time_steps, get_padding_masks, dtypes)
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )
        ms = benchmark_flash(q, k, v, padding_mask)
        results.append(
            {
                "name": "flash",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )

    for B, H, L, pm, d_state, dtype in tqdm(
        itertools.product(
            batch_sizes, heads, time_steps, get_padding_masks, d_states, dtypes
        )
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )

        ms = benchmark_mamba(B, H, L, E, d_state)
        results.append(
            {
                "name": f"mamba-{d_state}",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )

    df = pandas.DataFrame(results)
    piv = df.pivot(
        columns="name",
        values="ms",
        index=["dtype", "padding", "batch_size", "nheads", "seq_len"],
    )
    print(piv.sort_index().round(3))
```


tridao commented on August 27, 2024

Try selective_scan_fn(u, delta, A, B, C, D) (no z, delta_bias, delta_softplus) to see if that makes a difference?
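
For concreteness, a minimal sketch of the two call variants being compared; the shapes and dtypes below are assumed to match the scripts earlier in the thread, and the timing harness is omitted:

```python
# Sketch: selective_scan_fn with and without the optional fused extras.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, length, d_state = 1, 4096, 2048, 16
dt = torch.bfloat16
u, delta, z = (torch.randn(batch, dim, length, device="cuda", dtype=dt)
               for _ in range(3))
B, C = (torch.randn(batch, 1, d_state, length, device="cuda", dtype=dt)
        for _ in range(2))
A = -torch.rand(dim, d_state, device="cuda")  # fp32
D = torch.ones(dim, device="cuda")            # fp32
delta_bias = torch.randn(dim, device="cuda")  # fp32

# Full variant: fused gating (z), delta bias, and softplus on delta.
y_full = selective_scan_fn(u, delta, A, B, C, D, z=z,
                           delta_bias=delta_bias, delta_softplus=True)

# Stripped-down variant suggested here: the scan alone.
y_scan = selective_scan_fn(u, delta, A, B, C, D)
```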


llmexperiment commented on August 27, 2024

> @tridao selective_scan_fn(u, delta, A, B, C, D) resulted in a speedup, but it's still significantly slower for N=16.
>
> [screenshot of benchmark timings]

Hi @apoorv2904, were you able to reproduce the results? If so, could you please share how you did it?

