
metal-flash-attention's Introduction

Metal FlashAttention

A faster alternative to Metal Performance Shaders, a reference implementation of modern GPU algorithms, and a step toward defragmenting the AI ecosystem.

Algorithms:

  • Attention
    • Dense (90.5% ALU)
    • Block-Sparse
  • GEMM
    • FP16 (93.3% ALU)
    • FP32 (87.2% ALU)
    • Fused Biases

Usage

Programming language support (MFA vs. MPSGraph vs. PyTorch):

  • CPU C++ (metal-cpp)
  • GPU C++ (Indirect Command Buffers)
  • Swift (iPadOS, Playgrounds)
  • Swift (macOS, Xcode)
  • Predecessor to Swift (not tested)

Usage:

  • Download Xcode 14.2 from the Apple developer tools archive
    • Copy into /Applications/Xcode 14.2.app, side by side with the existing Xcode installation /Applications/Xcode.app
  • Run the Swift script to compile libMetalFlashAttention.metallib
    • Enter this repository from Terminal and type swift build.swift
  • Read the API specification
  • Generate Metal shader variants at runtime (see the sketch below)
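
The exact function names and constants are defined in the API specification; as a minimal, hedged sketch (the metallib path below is an assumption), loading the compiled library from Swift could look like this:

import Foundation
import Metal

// Minimal sketch: load libMetalFlashAttention.metallib and inspect it.
// The path is an assumption; point it at wherever the metallib was built.
let device = MTLCreateSystemDefaultDevice()!
let url = URL(fileURLWithPath: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: url)

// List the shader functions exposed by the library; the names to use for
// GEMM and attention come from the API specification, not from this sketch.
print(library.functionNames)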

Alternatively:

  • Download the newest version of Xcode
  • Fetch the Metal library from GitHub releases
  • Run the unit tests from this repository

Performance

SGEMM, every square matrix from 1–1536: max GFLOPS achieved

HGEMM, every square matrix from 1–2048: max GFLOPS achieved

GEMM

Scaling by square size:

  • Matrix M: every even integer
  • Matrix N: every even integer
  • Matrix K: every even integer
  • For 2x batched, every multiple of 4
  • For very large square matrices, granularity varies
Function Constant | Value
----------------- | -----
M_splits          | 2
N_splits          | 2
M_simd            | Block M / M_splits
N_simd            | Block N / N_splits
K_simd            | Block K

Precision | Block M | Block N | Block K
--------- | ------- | ------- | -------
Float32   | 32      | 32      | 32
Float32   | 48      | 48      | 24
Float16   | 32      | 32      | 32
Float16   | 48      | 48      | 32

Size Start | Size End | Duplicate Commands/Encoder | Trials
---------- | -------- | -------------------------- | ------
1          | 190      | 256                        | 16
192        | 254      | 128                        | 16
256        | 382      | 64                         | 16
384        | 510      | 32                         | 16
512        | 766      | 16                         | 16
768        | 1022     | 8                          | 16
1024       | 1534     | 4                          | 16
1536       | 2048     | 2                          | 16
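
As a hedged illustration of how a GEMM variant might be specialized at runtime with the constants above: the constant names come from the table, while the data types (ushort), the metallib path, and the kernel name "hgemm" are assumptions to be checked against the API specification.

import Foundation
import Metal

// Sketch only: specialize a Float16 GEMM variant with the 48x48x32 block size above.
let device = MTLCreateSystemDefaultDevice()!
let library = try! device.makeLibrary(
  URL: URL(fileURLWithPath: "libMetalFlashAttention.metallib"))

let constants = MTLFunctionConstantValues()
var blockM: UInt16 = 48, blockN: UInt16 = 48, blockK: UInt16 = 32
var mSplits: UInt16 = 2, nSplits: UInt16 = 2
var mSimd = blockM / mSplits   // M_simd = Block M / M_splits
var nSimd = blockN / nSplits   // N_simd = Block N / N_splits
var kSimd = blockK             // K_simd = Block K
constants.setConstantValue(&mSplits, type: .ushort, withName: "M_splits")
constants.setConstantValue(&nSplits, type: .ushort, withName: "N_splits")
constants.setConstantValue(&mSimd, type: .ushort, withName: "M_simd")
constants.setConstantValue(&nSimd, type: .ushort, withName: "N_simd")
constants.setConstantValue(&kSimd, type: .ushort, withName: "K_simd")

// "hgemm" is an assumed function name; consult the API specification for the real one.
let function = try! library.makeFunction(name: "hgemm", constantValues: constants)
let pipeline = try! device.makeComputePipelineState(function: function)
_ = pipeline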

Float32 Utilization (NN)

Float32 Utilization (NT)

Float32 Utilization (NT, Large)

Float16 Utilization (NN)

Float16 Utilization (NT, 2x Batched)

Float16 Utilization (NTN, 2x Batched, Bias)

Attention

Setup:

  • Sequence dimension:
    • R = rows (output sequence length)
    • C = columns (input sequence length)
    • R = C
  • Masking:
    • Only MFA supports block-sparse masks.
    • For "scaling by sparsity", sparse block size equals GEMM block size.

Scaling by sequence length:

  • Masking:
    • No mask
    • Dense Mask: triangular mask
    • Sparse Mask: triangular mask, summarized by block-sparse mask
  • Sequence length:
    • Small sequences: every multiple of 4
    • Large sequences: every multiple of 64
    • Causal mask: every even integer
  • Head size: 64
  • Head count:
    • Small sequences: 10
    • Large sequences: 5
    • Causal mask: 10

Scaling by head size:

  • Masking: dense, no mask
  • Sequence length 4096
  • Head size:
    • ≤64: every integer
    • >64: every roundUpToPowerOf2(D/64) integers
  • Head count: 8
Function Constant | Value
----------------- | -----
Q_trans           |
K_trans           |
V_trans           |
O_trans           |
R_splits          | TBD
R_simd            | Block R / R_splits
C_simd            | Block C
D_simd            | $$8 \times \lceil D / 8 \rceil$$
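
The derived rows of this table can be computed on the host; the helper below is purely illustrative (its name and parameters are not part of the MFA API) and simply restates the formulas above.

// Hypothetical helper: derive per-simdgroup block dimensions from chosen block sizes.
func attentionSimdDimensions(blockR: Int, rSplits: Int, blockC: Int, headSize: Int)
    -> (rSimd: Int, cSimd: Int, dSimd: Int) {
  let rSimd = blockR / rSplits          // R_simd = Block R / R_splits
  let cSimd = blockC                    // C_simd = Block C
  let dSimd = 8 * ((headSize + 7) / 8)  // D_simd = 8 * ceil(D / 8)
  return (rSimd, cSimd, dSimd)
}

// Example: a head size of 40 is already a multiple of 8, so D_simd stays 40,
// while a head size of 50 would pad up to 56.
print(attentionSimdDimensions(blockR: 32, rSplits: 2, blockC: 32, headSize: 40))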

Float32 Sequence Scaling (Small)

FlashAttention (F32, H=10, D=64)

Float16 Sequence Scaling (Small)

Dense: Stable Diffusion XL outermost attention layer @ 512x512 (sequence length = 1024)

FlashAttention (F16, H=10, D=64)

Float16 Sequence Scaling (Large)

Dense: Stable Diffusion 2 outermost attention layer @ 512x512 (sequence length = 4096)

FlashAttention (F16, H=5, D=64)

Float32 Sequence Scaling (Causal Mask)

FlashAttention (F32, H=10, D=64)

Float16 Sequence Scaling (Causal Mask)

FlashAttention (F16, H=10, D=64)

Float16 Head Scaling

Dense: Stable Diffusion 1 outermost attention layer @ 512x512 (head size = 40)

FlashAttention (F16, R=C=4096, H=8)

Roadmap

Releases:

  • v0.1.0-alpha
    • Initial release, only non-batched GEMM without fused transposes
  • v0.2.0-alpha
    • Fused transposes for A and B
    • Batched GEMM
  • v1.0.0
    • Attention: dense and block-sparse
  • v1.0.1
    • GEMM: fused biases

Prospective Future Goals:

  • Tune the existing GEMM and Attention kernels for new A17/M3 hardware
  • Kahan block-summation with double-single accumulate, in a manner portable to other vendors

metal-flash-attention's People

Contributors

ivarflakstad, liuliu, philipturner


metal-flash-attention's Issues

Weird performance when using shared memory in GEMV

I tried to optimize GEMV with shared memory to speed up I/O; theoretically, a GEMV that stages data in SRAM should have better bandwidth. But here comes a weird performance result.

Device: M2 Ultra, 128 GB
Kernel cost measured from GPUStartTime and GPUEndTime.

  1. First, I built an Xcode Metal project containing the original GEMV (your code) and an SRAM GEMV (my code). The SRAM GEMV was about 30% faster than the original, which is the improvement I hoped for:
// transA = false, transB = true.
gemv [1,2048] @ [4096,2048]  0.098 ms (original) --> 0.068 ms (sram)
gemv [1,2048] @ [11001,2048] 0.271 ms (original) --> 0.195 ms (sram)
  2. Second, I added my SRAM GEMV kernel to your project (because I want to combine them into one metallib) and called both from my other C++/Objective-C project. Then came the strange thing:
// The original is now much faster than in the Xcode project test, and even faster than the SRAM GEMV.
gemv [1,2048] @ [4096,2048]  0.040 ms (original) vs. 0.047 ms (sram)
gemv [1,2048] @ [11001,2048] 0.175 ms (original) vs. 0.173 ms (sram)
  3. My kernel code is below. Dispatch parameters: warpPerBlock = 4, grid size = {UP_ROUND(K, warpPerBlock), 1, 1}, threadgroup size = {32 * warpPerBlock, 1, 1}.
// ONLY supports M = 1, transA = false, transB = true for now.
// Relies on the function constants declared elsewhere in the MFA sources
// (M, N, K, batched, use_bias, use_activation, fused_activation).
#include <metal_stdlib>
using namespace metal;

template <typename T, int Align>
void _gemv_sram_impl(device T *A [[buffer(0)]],
                     device T *B [[buffer(1)]],
                     device T *C [[buffer(2)]],
                     device void *D [[buffer(3), function_constant(use_activation)]],

                     threadgroup T *threadgroup_block [[threadgroup(0)]],
                     constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
                     constant uint *activation_type [[buffer(13), function_constant(fused_activation)]],
                     uint3 gid [[threadgroup_position_in_grid]],
                     ushort warp_num [[dispatch_simdgroups_per_threadgroup]],
                     ushort sidx [[simdgroup_index_in_threadgroup]],
                     ushort lane_id [[thread_index_in_simdgroup]])
{
    // Each simdgroup produces one output element of C.
    if (gid.x * warp_num + sidx >= N || gid.y >= M) return;
    if (batched) {
        // TODO: Re-compute every inner loop iteration for FP64 accumulate.
        ulong3 offsets = matrix_offsets[gid.z].xyz;
        A = (device T*)((device uchar*)A + offsets[0]);
        B = (device T*)((device uchar*)B + offsets[1]);
        C = (device T*)((device uchar*)C + offsets[2]);
    }

    B += gid.x * warp_num * K;
    C += gid.y * N + gid.x * warp_num + sidx;

    T acc_sum = 0;
    device vec<T, Align> *Aalign = (device vec<T, Align> *)A;
    device vec<T, Align> *Balign = (device vec<T, Align> *)B;

    // Stage the A vector in threadgroup memory (SRAM), cooperatively across simdgroups.
    threadgroup vec<T, Align> *smem = (threadgroup vec<T, Align> *)threadgroup_block;
    for (uint k = sidx * 32 + lane_id; k < K / Align; k += 32 * warp_num) {
        smem[k] = Aalign[k];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each simdgroup reduces its own row of B against the staged A vector.
    for (uint k = lane_id; k < K / Align; k += 32) {
        device vec<T, Align> *BalignSIMD = Balign + K / Align * sidx;
        for (uint i = 0; i < Align; ++i) {
            acc_sum += smem[k][i] * BalignSIMD[k][i];
        }
    }
    T all_sum = simd_sum(acc_sum);
    if (lane_id == 0) {
        // Handle the tail elements that do not fill a full Align-wide vector.
        device T *BWarp = B + sidx * K;
        for (uint k = Align * (K / Align); k < K; ++k) {
            all_sum += A[k] * BWarp[k];
        }
        if (use_bias) {
            // not supported yet
        }
        if (fused_activation) {
            // not supported yet
        }
        *C = all_sum;
    }
}

Questions:

  1. Why is there a difference between the Xcode testbed and the metallib call? The SRAM GEMV performs basically the same in both, but the original GEMV is much faster when called through the metallib.
  2. Is there any compiler optimization I missed for the SRAM GEMV?
  3. Given my code and the situation described above, could you give me some advice on the potential cause of the performance gap?

Thank you for your help!

Accuracy issues due to attention_matrix accumulated at half-precision & softmax_scale (alpha) applied after qk

An accuracy issue arises during integration with the SSD-1B model. q and k can be large enough that q*k exceeds the half-precision range. This is normally fine because the scale is usually applied to q, or to both q and k, as new_q = sqrt(scale) * q and new_k = sqrt(scale) * k. However, the MFA attention kernel applies alpha only after q * k is computed, which causes a NaN issue.
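
For reference, the workaround relies on the identity below (a standard rescaling, not specific to MFA), with alpha = 1/sqrt(D), i.e. alpha = 1/8 for D = 64. Folding sqrt(alpha) into Q and K shrinks every intermediate dot product before it is stored at half precision (largest finite Float16 value ≈ 65504), without changing the result:

$$\operatorname{softmax}\!\left(\alpha\, Q K^{\top}\right) V \;=\; \operatorname{softmax}\!\left((\sqrt{\alpha}\, Q)(\sqrt{\alpha}\, K)^{\top}\right) V$$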

This can be reproduced with the tensors extracted from the SSD-1B computation and the following s4nnc code:

import NNC

let graph = DynamicGraph()

graph.withNoGrad {
  graph.openStore("/Users/liu/Desktop/reprod_tensor.sqlite3") {
    guard let _q = $0.read("q"), let _k = $0.read("k"), let _v = $0.read("v") else { return }
    let q = graph.variable(Tensor<Float16>(from: _q).toGPU(0))
    let k = graph.variable(Tensor<Float16>(from: _k).toGPU(0))
    let v = graph.variable(Tensor<Float16>(from: _v).toGPU(0))
    let scaledDotProductAttention = ScaledDotProductAttention(scale: 1.0 / Float(64).squareRoot())
    let out = scaledDotProductAttention(inputs: q, k, v)[0].as(of: Float16.self)
    debugPrint(out)
    let q2 = (1.0 / Float(64).squareRoot()) * q
    let scaledDotProductAttention2 = ScaledDotProductAttention(scale: 1)
    let out2 = scaledDotProductAttention2(inputs: q2, k, v)[0].as(of: Float16.self)
    debugPrint(out2)
  }
}

The reprod_tensor.sqlite3 is attached here.
reprod_tensor.split.sqlite3.zip
reprod_tensor.split.sqlite3.z01.zip

(Please rename the sqlite3.z01.zip file to sqlite3.z01 to work around the GitHub file size limitation.)

How to use this flash-attention in Python code?

Hi, thank you for implementing flash-attention on MPS so that it can run on a Mac.
But there is no documentation on how to use it from Python or PyTorch code.

I want to use it to speed up Stable Diffusion inference on a Mac. I know that running Stable Diffusion on an M2 Mac usually means converting the PyTorch weights to Core ML, which runs but does not let you edit any code.

Would you show me how to use this flash-attention in a Stable Diffusion project?

Undefined symbols error

Hi, when I try to compile the GEMM kernel, I get an error:

Undefined symbol(s) for architecture 'air64':
  '@air.simdgroup_async_copy_2d.p1i8.p3i8', referenced from:
  _Z10_gemm_implIfEvPU9MTLdeviceT_S2_S2_PU14MTLthreadgroupS0_Dv3_jtt in program_sourc

I've made sure to install Xcode 14.2 and follow your instructions, but the compiler doesn't seem to know what that asm instruction is. I'm currently on an M1 Pro MacBook Pro.

Guidelines for modifying H3 with metal-flash-attention

Hello Philip,

Great project! It is something I have been waiting on for some time now.

Can you give me some guidance on how I can replace the current flash attention mechanism in H3 with metal-flash-attention?

Thanks in advance !

M3 Performance

Hi, thanks a lot for this really cool library.

I've taken it for a spin on an M3 and saw that GEMM seemed to perform better with MPS.
Do you have any suggestions as to why? Is there any way we could help?

ETA on Dense FlashAttention?

Holy shit, this project is amazing.
Dense flash attention would be ultra useful for running language models and Stable Diffusion. What's the ETA, and what challenges are you facing? Maybe I can help.

Also, would you want to extend this project into a high-level framework like Triton for Metal?

Also, folks from Apple are looking at your project closely!! Amazing work!

simdgroup_async issues - Xcode Version 15.0.1 (15A507) / M3 Max 14.1.2 (23B2091)

metal_config in the toolchain doesn't mention HAVE_SIMDGROUP_FUTURE; it seems the headers referred to in dougallj/applegpu#28 were removed altogether from newer versions of Xcode (I'm using 15.0.1 / 15A507).

I was able to find matching strings from the above GitHub issue inside libapplegpu-nt.dylib in the 15.0.1 toolchain:

objdump --disassemble --demangle /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/metal/macos/lib/libapplegpu-nt.dylib > /tmp/symbols.txt

127b704: 08 91 16 91 add x8, x8, #1444 ; literal pool for: "air.simdgroup_async_copy_1d"
127b708: 60 f6 04 f9 str x0, [x19, #2536]
127b70c: 68 e2 04 f9 str x8, [x19, #2496]
127b710: 28 08 00 b0 adrp x8, 261 ; 0x1380000
127b714: 08 01 17 91 add x8, x8, #1472 ; literal pool for: "air.simdgroup_async_copy_2d"

I didn't pursue this further, though, to see whether things can still be patched up or whether the functions are still usable and correct.

Is there interest / would there be positive reception to a PR using alternative read & write mechanisms in lieu of simdgroup_async?

Thanks!

`bfloat16` support

Hi!

I was wondering if you would be interested in adding bf16 support to MFA or at least the GEMM kernels? For mlx Apple defined a custom type: https://github.com/ml-explore/mlx/blob/76c919b4ecf0cccaa1cfef214d12be0ad71485cc/mlx/backend/metal/kernels/bf16.h (MIT licensed), so I understand supporting this is not easy and maybe not even desirable because it's not a native type and performance is not great anyway.

Btw, I came here via huggingface/candle. It uses libMFA for matmul and FA.
