Comments (17)

yzhangcs commented on August 27, 2024

Since you set expand_k=0.5, the key_dim becomes 32 * 0.5 = 16, which with 2 heads gives a head_k_dim of 8. Triton matmuls do not support dimensions that small.
As I said before, head_dim < 64 is also not recommended: smaller head dims get padded, which can waste a large share of FLOPs.
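A minimal sketch of the arithmetic above, assuming (as the comment states) that GLA derives key_dim = hidden_size * expand_k and value_dim = hidden_size * expand_v, then splits them evenly across heads. The helper names are hypothetical and not part of flash-linear-attention; the floor of 16 reflects Triton's requirement that every tl.dot operand dimension be at least 16, which is exactly what the AssertionError later in this thread reports.

```python
# Hypothetical helpers; not part of the flash-linear-attention API.
TRITON_MIN_DOT_DIM = 16  # tl.dot asserts every operand dimension is >= 16


def gla_head_dims(hidden_size: int, num_heads: int,
                  expand_k: float = 0.5, expand_v: float = 1.0) -> tuple[int, int]:
    """Return (head_k_dim, head_v_dim), assuming key/value dims are
    hidden_size * expand_{k,v} split evenly across heads."""
    key_dim = int(hidden_size * expand_k)
    value_dim = int(hidden_size * expand_v)
    return key_dim // num_heads, value_dim // num_heads


def fits_triton_dot(hidden_size: int, num_heads: int, **expand: float) -> bool:
    """True if both per-head dims meet Triton's minimum dot-product size."""
    head_k_dim, head_v_dim = gla_head_dims(hidden_size, num_heads, **expand)
    return min(head_k_dim, head_v_dim) >= TRITON_MIN_DOT_DIM


# hidden_size=32, 2 heads, expand_k=0.5: key_dim 16 -> head_k_dim 8 (too small).
for hidden, heads in [(32, 1), (32, 2), (64, 2), (64, 4)]:
    dk, dv = gla_head_dims(hidden, heads)
    status = "ok" if fits_triton_dot(hidden, heads) else "too small for tl.dot"
    print(f"hidden={hidden} heads={heads} head_k_dim={dk} head_v_dim={dv}: {status}")
```

Raising expand_k (for example to 1.0) or using fewer heads are the two knobs this arithmetic exposes for keeping head_k_dim at or above 16.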

from flash-linear-attention.

sustcsonglin commented on August 27, 2024

Hi, what is your head dimension?

yxchng commented on August 27, 2024

Is a head dim of 16 not supported?

I tried changing it to 32 and it works. However, it seems much slower than Mamba and Flash Attention. My GPU is an RTX 3090, and my input size is (1,19384,32). It is a hierarchical network in which the number of tokens halves and the channel size doubles at each stage.

What am I doing wrong?

sustcsonglin commented on August 27, 2024

Your head dim is too small. Could you try fused_recurrent mode?

yxchng commented on August 27, 2024

fused_recurrent is even slower.

Does flash linear attention only have an advantage when the channel dim is large? Throughout my network, the channel dims range only from 32 to 512 (32, 64, 128, 256, 512). Currently, it is 2x slower than flash attention.

yxchng commented on August 27, 2024

I also tried a simple linear attention like the one below, and it is much faster than GLA too. I'm not sure why.

import torch
import torch.nn as nn


class LinearAttention(nn.Module):
    """Non-causal linear attention with an elu(x) + 1 feature map; v is the input itself.

    Args:
        dim (int): Number of input channels.
        input_resolution: Stored for interface compatibility; unused in forward.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query and key. Default: True
    """

    def __init__(self, dim, input_resolution, num_heads, qkv_bias=True, **kwargs):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.elu = nn.ELU()

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C)
        """
        b, n, c = x.shape
        num_heads = self.num_heads
        head_dim = c // num_heads

        qk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)
        q, k, v = qk[0], qk[1], x
        # q, k, v: (b, n, c)

        # Positive feature map phi(x) = elu(x) + 1
        q = self.elu(q) + 1.0
        k = self.elu(k) + 1.0
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)

        # Normalizer 1 / (q . mean(k)): shape (b, heads, n, 1)
        z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        # (k^T v) / n: shape (b, heads, head_dim, head_dim); no n x n matrix is formed
        kv = (k.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))
        x = q @ kv * z

        x = x.transpose(1, 2).reshape(b, n, c)

        return x
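Why a layer like this can be fast: with a positive feature map phi, normalized linear attention can be evaluated as phi(q) (phi(k)^T v) instead of (phi(q) phi(k)^T) v, so the cost grows linearly in N rather than quadratically. The pure-Python sketch below (hypothetical helper names, one head, tiny random inputs) checks that both orderings agree, using the same elu(x) + 1 feature map and the same normalizer as the class above. The 1/N scaling in the class cancels between numerator and normalizer, so it is left out here.

```python
import math
import random


def elu_plus_one(x: float) -> float:
    """Feature map phi(x) = elu(x) + 1, which is always positive."""
    return x + 1.0 if x > 0 else math.exp(x)


def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def quadratic_attention(q, k, v):
    """Builds the full N x N score matrix phi(q) phi(k)^T, then row-normalizes."""
    scores = matmul(q, transpose(k))
    out = matmul(scores, v)
    return [[x / sum(row) for x in o] for row, o in zip(scores, out)]


def linear_attention(q, k, v):
    """Forms phi(k)^T v once (d x d), so no N x N matrix is ever built."""
    kv = matmul(transpose(k), v)                    # d_k x d_v
    k_sum = [sum(col) for col in zip(*k)]           # length d_k
    out = matmul(q, kv)                             # N x d_v
    norms = [sum(qi * ki for qi, ki in zip(row, k_sum)) for row in q]
    return [[x / z for x in o] for o, z in zip(out, norms)]


random.seed(0)
N, D = 6, 4
q = [[elu_plus_one(random.gauss(0, 1)) for _ in range(D)] for _ in range(N)]
k = [[elu_plus_one(random.gauss(0, 1)) for _ in range(D)] for _ in range(N)]
v = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]

a = quadratic_attention(q, k, v)
b = linear_attention(q, k, v)
max_err = max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```

This sketch is non-causal and ungated, like the class above; it says nothing about the chunked, gated kernels GLA uses.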

sustcsonglin commented on August 27, 2024

Hello, could you give me some code snippets so that I can test the speed?

yxchng commented on August 27, 2024

@sustcsonglin

code:

import torch
import flash_attn
from fla.models import GLAConfig
from fla.layers.gla import GatedLinearAttention
from mamba_ssm import Mamba

HH = [1,2,4,8,16,32]
CC = [32,64,128,256,512,1024]
TT = [197912,197912,197912,18432,18432,18432]


with torch.no_grad():
    for i in range(len(HH)):
        H = HH[i]
        C = CC[i]
        T = TT[i]
        print(H, C, T)

        scale = (C // H) ** -0.5
        x = torch.randn(1,T,C).cuda()


        if T == 197912:
            # Boundaries of consecutive 1024-token segments: 0, 1024, ..., 198656 (195 entries).
            # Note: the last boundary exceeds T = 197912, so the final segment runs past the end of x.
            cu_seqlens = torch.arange(0, 194 * 1024 + 1, 1024, dtype=torch.int32).cuda()
        else:
            # Boundaries of 18 segments of 1024 tokens: 0, 1024, ..., 18432.
            cu_seqlens = torch.arange(0, 18 * 1024 + 1, 1024, dtype=torch.int32).cuda()

        qkv_proj = torch.nn.Linear(C,C*3).cuda()

        total = 0
        count = 0
        for i in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            qkv = qkv_proj(x)
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=1024,
                dropout_p=0.0,
                softmax_scale=scale,
            ).reshape(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            total +=  elap
            count += 1

        print('flash', total / count)

        gla_attn = GatedLinearAttention(
            hidden_size=C,
            expand_k=0.5,
            expand_v=1.0,
            num_heads=H,
            gate_fn="swish",
            mode='fused_chunk',
            fuse_norm=True,
        ).cuda()

        total = 0
        count = 0
        for i in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            feat, _, _ = gla_attn(x)
            feat = feat.view(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            total +=  elap
            count += 1

        print('gla', total / count)

        mamba = Mamba(
            d_model=C, # Model dimension d_model
            d_state=1,  # SSM state expansion factor
            d_conv=4,    # Local convolution width
            expand=2,    # Block expansion factor
        ).to("cuda")

        total = 0
        count = 0
        for i in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            y = mamba(x)
            feat = y.view(-1, C)  # reshape Mamba's own output, not the stale GLA result
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            total +=  elap
            count += 1

        print('mamba', total / count)

results:

1 32 197912
flash 1.9978342366218567
gla 30.399498472213747
mamba 1.7825075244903565
2 64 197912
flash 2.9101670598983764
gla 14.390999011993408
mamba 1.9755417597293854
4 128 197912
flash 4.381767692565918
gla 16.998656005859374
mamba 4.925788168907165
8 256 18432
flash 0.9251532834768296
gla 2.8886323380470276
mamba 1.3630463945865632
16 512 18432
flash 2.648555555343628
gla 5.333882894515991
mamba 4.275425262451172
32 1024 18432
flash 8.642979860305786
gla 14.024908800125122
mamba 15.80884991645813

sustcsonglin commented on August 27, 2024

Why don't you use half precision for GLA?

yxchng commented on August 27, 2024

I tried casting x (and the layer) to half precision as follows:

        gla_attn = GatedLinearAttention(
            hidden_size=C,
            expand_k=0.5,
            expand_v=1.0,
            num_heads=H,
            gate_fn="swish",
            mode='fused_chunk',
            fuse_norm=True,
        ).half().cuda()

        total = 0
        count = 0
        for i in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            feat, _, _ = gla_attn(x.half())
            feat = feat.view(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            total +=  elap
            count += 1

        print('gla', total / count)

Output on RTX 3090:

1 32 197912
flash 2.0588236618041993
gla 30.708438749313355
mamba 1.8067046213150024
2 64 197912
flash 2.9214208006858824
gla 11.118141412734985
mamba 1.9925299274921418
4 128 197912
flash 4.3777637910842895
gla 12.742543354034424
mamba 4.89646080493927
8 256 18432
flash 0.8793087983131409
gla 9.139538043737412
mamba 1.3001830399036407
16 512 18432
flash 2.3843328285217287
gla 8.640337998867034
mamba 3.9147417974472045
32 1024 18432
flash 8.205137915611267
gla 5.838223357200622
mamba 15.277168626785278

It seems GLA is faster than flash only in the H=32, C=1024, T=18432 configuration.
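The observation can be checked directly against the fp16 numbers above. The sketch below is a hypothetical parser for this exact log layout (an "H C T" line followed by flash/gla/mamba lines, times in ms); it computes the GLA-to-flash time ratio per configuration, where a ratio below 1 means GLA was faster.

```python
# The fp16 results logged above, copied verbatim.
LOG = """\
1 32 197912
flash 2.0588236618041993
gla 30.708438749313355
mamba 1.8067046213150024
2 64 197912
flash 2.9214208006858824
gla 11.118141412734985
mamba 1.9925299274921418
4 128 197912
flash 4.3777637910842895
gla 12.742543354034424
mamba 4.89646080493927
8 256 18432
flash 0.8793087983131409
gla 9.139538043737412
mamba 1.3001830399036407
16 512 18432
flash 2.3843328285217287
gla 8.640337998867034
mamba 3.9147417974472045
32 1024 18432
flash 8.205137915611267
gla 5.838223357200622
mamba 15.277168626785278
"""


def parse(log: str) -> dict:
    """Map (H, C, T) -> {"flash": ms, "gla": ms, "mamba": ms}."""
    results, key = {}, None
    for line in log.splitlines():
        parts = line.split()
        if len(parts) == 3:          # configuration header: H C T
            key = tuple(int(p) for p in parts)
            results[key] = {}
        elif len(parts) == 2:        # "<method> <milliseconds>"
            results[key][parts[0]] = float(parts[1])
    return results


results = parse(LOG)
for (h, c, t), ms in results.items():
    print(f"H={h:<3} C={c:<5} T={t:<7} gla/flash = {ms['gla'] / ms['flash']:5.2f}")

faster = [cfg for cfg, ms in results.items() if ms["gla"] < ms["flash"]]
print("GLA faster than flash for:", faster)  # -> [(32, 1024, 18432)]
```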

yzhangcs commented on August 27, 2024

@yxchng Could you report the throughput on your GPU when running this script?
https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/benchmark_training_throughput.py

yxchng commented on August 27, 2024
python3 benchmarks/benchmark_training_throughput.py --batch_size 1
Initializing retnet model from the config:
RetNetConfig {
  "attn_mode": "fused_chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 2,
  "feature_map": null,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 2,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "retnet",
  "norm_eps": 1e-06,
  "num_heads": 8,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.40.0",
  "use_cache": true,
  "use_output_gate": true,
  "use_short_conv": false,
  "vocab_size": 32000
}

RetNetForCausalLM(
  (model): RetNetModel(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-23): 24 x RetNetBlock(
        (attn_norm): RMSNorm(2048, eps=1e-06)
        (attn): MultiScaleRetention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=4096, bias=False)
          (g_proj): Linear(in_features=2048, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2048, bias=False)
          (g_norm_swish_gate): FusedRMSNormSwishGate(512, eps=1e-06)
          (rotary): RotaryEmbedding()
        )
        (mlp_norm): RMSNorm(2048, eps=1e-06)
        (mlp): RetNetMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=2816, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
      )
    )
    (norm): RMSNorm(2048, eps=1e-06)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1351827456 (1.26GiB)
Allocated memory after initialization: 2.52GiB
Max memory allocated: 12.95GiB: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00,  1.40it/s]
Thoughput:    4978.12 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.43it/s]

sustcsonglin commented on August 27, 2024

Can you try adding some warmup steps before timing? GLA's autotuning has to sweep kernel configurations for each new seq_len and model_dim, so the first calls are slow.
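A minimal timing sketch in the spirit of this suggestion: run some untimed warmup calls so that one-off work (such as autotuning) happens outside the measurement, then average the timed calls. The helper name bench and its sync hook are hypothetical. With PyTorch on a GPU one would pass torch.cuda.synchronize as sync, because CUDA kernels launch asynchronously and a host-side clock would otherwise mostly measure launch overhead.

```python
import time
from typing import Callable, Optional


def bench(fn: Callable[[], object], warmup: int = 100, iters: int = 200,
          sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock time of fn() in milliseconds.

    The first `warmup` calls run but are not timed. `sync` (e.g.
    torch.cuda.synchronize) is called right before starting and right
    after stopping the clock so pending device work is fully counted.
    """
    if iters <= 0:
        raise ValueError("iters must be positive")
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) * 1000.0 / iters


# Usage: a stand-in "kernel" whose first call is slow, mimicking one-off autotuning.
calls = {"n": 0}


def fake_kernel():
    calls["n"] += 1
    if calls["n"] == 1:
        time.sleep(0.05)  # simulated autotune cost, paid only on the first call


mean_ms = bench(fake_kernel, warmup=5, iters=20)
print(f"mean: {mean_ms:.4f} ms over 20 timed calls")
```

Timing the whole loop with one synchronization at each end, rather than synchronizing after every call, keeps the measurement closer to steady-state throughput; per-call CUDA events, as in the script above, are the alternative when individual latencies matter.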

yxchng commented on August 27, 2024

I modified the code to add warmup as follows (iterations 0 to 100 are discarded as warmup; only the remaining 199 are timed):

code:

import torch
import flash_attn
from fla.models import GLAConfig
from fla.layers.gla import GatedLinearAttention
from mamba_ssm import Mamba

HH = [1,2,4,8,16,32]
CC = [32,64,128,256,512,1024]
TT = [197912,197912,197912,18432,18432,18432]


with torch.no_grad():
    for i in range(len(HH)):
        H = HH[i]
        C = CC[i]
        T = TT[i]
        print(H, C, T)

        scale = (C // H) ** -0.5
        x = torch.randn(1,T,C).cuda()


        if T == 197912:
            # Boundaries of consecutive 1024-token segments: 0, 1024, ..., 198656 (195 entries).
            # Note: the last boundary exceeds T = 197912, so the final segment runs past the end of x.
            cu_seqlens = torch.arange(0, 194 * 1024 + 1, 1024, dtype=torch.int32).cuda()
        else:
            # Boundaries of 18 segments of 1024 tokens: 0, 1024, ..., 18432.
            cu_seqlens = torch.arange(0, 18 * 1024 + 1, 1024, dtype=torch.int32).cuda()

        qkv_proj = torch.nn.Linear(C,C*3).cuda()

        total = 0
        count = 0
        for i in range(300):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            qkv = qkv_proj(x)
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=1024,
                dropout_p=0.0,
                softmax_scale=scale,
            ).reshape(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if i > 100:
                total +=  elap
                count += 1

        print('flash', total / count)

        gla_attn = GatedLinearAttention(
            hidden_size=C,
            expand_k=0.5,
            expand_v=1.0,
            num_heads=H,
            gate_fn="swish",
            mode='fused_chunk',
            fuse_norm=True,
        ).half().cuda()

        total = 0
        count = 0
        for i in range(300):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            feat, _, _ = gla_attn(x.half())
            feat = feat.view(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if i > 100:
                total +=  elap
                count += 1

        print('gla', total / count)

        mamba = Mamba(
            d_model=C, # Model dimension d_model
            d_state=1,  # SSM state expansion factor
            d_conv=4,    # Local convolution width
            expand=2,    # Block expansion factor
        ).to("cuda")

        total = 0
        count = 0
        for i in range(300):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            y = mamba(x)
            feat = y.view(-1, C)  # reshape Mamba's own output, not the stale GLA result
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if i > 100:
                total +=  elap
                count += 1

        print('mamba', total / count)

results:

1 32 197912
flash 1.4752803897138815
gla 10.771980899063188
mamba 1.3009894331495966
2 64 197912
flash 2.881798413530666
gla 11.33516543594437
mamba 2.0023882712551098
4 128 197912
flash 4.402350919330539
gla 12.739239247000997
mamba 4.929762396980171
8 256 18432
flash 0.9214096051364687
gla 2.5598970801387
mamba 1.362964578010329
16 512 18432
flash 2.412821833212771
gla 2.9099712982848662
mamba 3.944458279777412
32 1024 18432
flash 8.064108043459791
gla 5.63640474434474
mamba 13.925597262741933

yxchng commented on August 27, 2024

Another question (I actually asked it above): just to confirm, is a head dim of 16 not supported?

yzhangcs commented on August 27, 2024

@yxchng A head dim of 16 should be OK, I think. Our kernels load tiles of at least 64 x 64, and inputs smaller than this block size are padded.
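To make the padding cost concrete: if a kernel tile along the head dimension is at least 64 wide, as described above, a smaller head dim is rounded up to the next multiple of the tile and only part of each tile does useful work. The sketch below assumes a fixed tile width of 64 along a single dimension; real Triton block sizes are chosen per kernel and by autotuning, so the numbers are a rough illustration, not a measurement. The function names are hypothetical.

```python
def padded_dim(d: int, block: int = 64) -> int:
    """Round d up to the next multiple of the (assumed) tile width."""
    return -(-d // block) * block  # ceiling division without floats


def useful_fraction(d: int, block: int = 64) -> float:
    """Share of the padded tile dimension that holds real data."""
    return d / padded_dim(d, block)


for d in (8, 16, 32, 64, 128):
    print(f"head_dim={d:<4} padded to {padded_dim(d):<4} useful={useful_fraction(d):.1%}")
```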

yxchng commented on August 27, 2024

@yzhangcs But I am getting AssertionError('All values in both first input shape ([constexpr[16], constexpr[8]]) and second input shape ([constexpr[8], constexpr[16]]) must be >= 16!') when I change HH in the above code from HH = [1,2,4,8,16,32] to HH = [2,4,8,16,32,64], i.e. from head dim 32 to head dim 16. How can I resolve this error?
