def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True
):
    """Blockwise (chunked) flash-attention forward pass.

    Splits the query sequence into ``num_q_chunk`` chunks and the key/value
    sequence into ``num_k_chunk`` chunks, runs ``_flash_attn_forward`` on each
    (q_chunk, k_chunk) pair, and merges the partial outputs and log-sum-exp
    statistics with ``update_out_and_lse``.

    Args:
        q, k, v: attention inputs, laid out as (batch, seqlen, num_head,
            head_dim) — the final squeeze/transpose below relies on dim 2
            being the head count so the returned lse is (batch, num_head,
            seqlen).
        q_chunk_size, k_chunk_size: chunk lengths; each must evenly divide
            the corresponding sequence length.
        softmax_scale: scaling applied to q @ k^T inside the kernel.
        dropout_p: attention dropout probability.
        causal: apply a causal mask.
            NOTE(review): the chunk-level skip below (skip j > i, in-kernel
            causal mask only at j == i) is only correct when
            q_chunk_size == k_chunk_size — confirm callers guarantee this.
        return_softmax: forwarded to the kernel (only honored when
            dropout_p > 0, matching the kernel's requirement).

    Returns:
        (out, lse): ``out`` has q's shape in float32; ``lse`` is
        (batch, num_head, seqlen) in float32.
    """
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0
    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    # Fix: q is (batch, seqlen, num_head, head_dim); the original unpacking
    # swapped the last two names, which made the lse buffer below read as
    # per-head-dim when it is actually per-head. Shapes are unchanged.
    batch, seqlen, num_head, head_dim = q.shape
    # Accumulate in float32 regardless of the input dtype.
    block_out = torch.empty(q.shape, dtype=torch.float32, device=q.device)
    block_lse = torch.empty((batch, seqlen, num_head, 1), dtype=torch.float32, device=q.device)
    for i in range(num_q_chunk):
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size]
        out_i = None
        lse_i = None
        # Reverse iteration order preserved from the original; the
        # log-sum-exp merge is mathematically order-independent.
        for j in range(num_k_chunk - 1, -1, -1):
            if j > i and causal:
                # Chunk lies entirely in the masked-out future: skip it.
                continue
            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size]
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size]
            out_ij, _, _, _, _, lse_ij, _, _ = _flash_attn_forward(
                q_i,
                k_j,
                v_j,
                dropout_p,
                softmax_scale,
                # Only the diagonal chunk needs the in-kernel causal mask;
                # chunks with j < i are fully visible.
                causal=causal and j == i,
                return_softmax=return_softmax and dropout_p > 0
            )
            out_i, lse_i = update_out_and_lse(out_i, lse_i, out_ij, lse_ij)
        block_out[:, i * q_chunk_size: (i + 1) * q_chunk_size] = out_i
        block_lse[:, i * q_chunk_size: (i + 1) * q_chunk_size] = lse_i
    # (b, s, h, 1) -> (b, h, s): match flash-attn's softmax_lse layout.
    return block_out, block_lse.squeeze(dim=-1).transpose(-1, -2)
def blockwise_flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    q_chunk_size,
    k_chunk_size,
    softmax_lse,
    dq,
    dk,
    dv,
    softmax_scale,
    dropout_p,
    causal=True,
    rng_state=None
):
    """Blockwise flash-attention backward pass.

    Mirrors ``blockwise_flash_attn_forward``: iterates over
    (q_chunk, k_chunk) pairs, calls ``_flash_attn_backward`` into reusable
    scratch buffers, and accumulates the per-pair gradients into dq/dk/dv.

    Args:
        dout: gradient of the forward output, same layout as q.
        q, k, v, out: forward inputs and output — assumed (batch, seqlen,
            num_head, head_dim), the flash-attn layout; the lse slice below
            indexes (batch, num_head, seqlen). TODO confirm.
        q_chunk_size, k_chunk_size: chunk lengths; each must evenly divide
            the corresponding sequence length.
        softmax_lse: log-sum-exp from the forward, (batch, num_head, seqlen).
        dq, dk, dv: gradient accumulators, mutated in place via ``+=`` —
            callers must pass them pre-zeroed.
        softmax_scale, dropout_p, causal, rng_state: forwarded to the kernel.
            NOTE(review): a single rng_state is reused for every chunk pair,
            which cannot reproduce per-chunk dropout masks — confirm this
            path is only exercised with dropout_p == 0.
    """
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0
    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    # Scratch buffers reused across chunk pairs; _flash_attn_backward writes
    # its results into them and we accumulate afterwards.
    temp_dq_buffer = torch.empty(q[:, :q_chunk_size].shape, dtype=q.dtype, device=q.device)
    temp_dk_buffer = torch.empty(k[:, :k_chunk_size].shape, dtype=k.dtype, device=k.device)
    temp_dv_buffer = torch.empty(v[:, :k_chunk_size].shape, dtype=v.dtype, device=v.device)
    for i in range(num_q_chunk):
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        dout_i = dout[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        out_i = out[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        softmax_lse_i = softmax_lse[:, :, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        for j in range(num_k_chunk):
            # Fix: skip fully-masked future chunks *before* slicing and
            # copying k/v — the original made the .contiguous() copies
            # first and then discarded them on the skip.
            if j > i and causal:
                continue
            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()
            _flash_attn_backward(
                dout_i,
                q_i,
                k_j,
                v_j,
                out_i,
                softmax_lse_i,
                temp_dq_buffer,
                temp_dk_buffer,
                temp_dv_buffer,
                dropout_p,
                softmax_scale,
                # Only the diagonal chunk is causally masked, matching the
                # forward pass.
                causal=causal and j == i,
                rng_state=rng_state,
            )
            # Accumulate this chunk pair's gradients into the global buffers.
            dq[:, i * q_chunk_size: (i + 1) * q_chunk_size] += temp_dq_buffer
            dk[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dk_buffer
            dv[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dv_buffer
# TODO: substitute these blockwise versions for _flash_attn_forward inside
# ring_flash_attn_forward and _flash_attn_backward inside ring_flash_attn_backward.
##############################
# forward:
##############################
out: max 2.896484375, mean 0.0203094482421875
lse: max 10.417832374572754, mean 9.204237937927246
out diff:
[0] max 0.00048828125, mean 8.881092071533203e-06
[1] max 0.0001220703125, mean 7.450580596923828e-06
[2] max 0.0001220703125, mean 5.9604644775390625e-06
[3] max 6.103515625e-05, mean 5.066394805908203e-06
[4] max 6.103515625e-05, mean 4.5299530029296875e-06
[5] max 6.103515625e-05, mean 4.112720489501953e-06
[6] max 6.103515625e-05, mean 3.814697265625e-06
[7] max 6.103515625e-05, mean 3.516674041748047e-06
lse diff:
[0] max 9.5367431640625e-07, mean 1.645181413323371e-07
[1] max 9.5367431640625e-07, mean 2.641230878452916e-07
[2] max 1.9073486328125e-06, mean 3.0044466825529526e-07
[3] max 1.9073486328125e-06, mean 3.3890827921823075e-07
[4] max 1.9073486328125e-06, mean 3.8137659430503845e-07
[5] max 1.9073486328125e-06, mean 4.0913002408160537e-07
[6] max 1.9073486328125e-06, mean 4.272908142866072e-07
[7] max 1.9073486328125e-06, mean 4.6798959374427795e-07
##############################
# backward:
##############################
load_dq:
[0] max 2.783203125, mean 0.052520751953125
[1] max 0.3310546875, mean 0.02398681640625
[2] max 0.2083740234375, mean 0.0184478759765625
[3] max 0.1162109375, mean 0.0155792236328125
[4] max 0.13330078125, mean 0.01374053955078125
[5] max 0.1204833984375, mean 0.01241302490234375
[6] max 0.11260986328125, mean 0.0114288330078125
[7] max 0.0775146484375, mean 0.01064300537109375
dq diff:
[0] max 0.005859375, mean 7.49826431274414e-05
[1] max 0.186279296875, mean 0.01239776611328125
[2] max 0.1973876953125, mean 0.01953125
[3] max 0.235107421875, mean 0.0253143310546875
[4] max 0.30615234375, mean 0.0301361083984375
[5] max 0.52392578125, mean 0.03436279296875
[6] max 0.56689453125, mean 0.038177490234375
[7] max 0.3955078125, mean 0.041748046875
load_dk:
[0] max 2.654296875, mean 0.05340576171875
[1] max 0.256591796875, mean 0.021697998046875
[2] max 0.169921875, mean 0.01535797119140625
[3] max 0.13330078125, mean 0.0116729736328125
[4] max 0.09124755859375, mean 0.0090484619140625
[5] max 0.1158447265625, mean 0.006908416748046875
[6] max 0.050384521484375, mean 0.00492095947265625
[7] max 0.03936767578125, mean 0.002498626708984375
dk diff:
[0] max 0.253173828125, mean 0.03192138671875
[1] max 0.16845703125, mean 0.0232696533203125
[2] max 0.130126953125, mean 0.017364501953125
[3] max 0.1097412109375, mean 0.012786865234375
[4] max 0.10797119140625, mean 0.00893402099609375
[5] max 0.049530029296875, mean 0.005580902099609375
[6] max 0.039337158203125, mean 0.002498626708984375
[7] max 1.52587890625e-05, mean 3.5762786865234375e-07
load_dv:
[0] max 5.89453125, mean 0.05450439453125
[1] max 0.1951904296875, mean 0.021484375
[2] max 0.11883544921875, mean 0.01525115966796875
[3] max 0.10003662109375, mean 0.01158905029296875
[4] max 0.07550048828125, mean 0.00901031494140625
[5] max 0.06658935546875, mean 0.006816864013671875
[6] max 0.041015625, mean 0.00492095947265625
[7] max 0.041961669921875, mean 0.002475738525390625
dv diff:
[0] max 0.3232421875, mean 0.042572021484375
[1] max 0.21240234375, mean 0.03094482421875
[2] max 0.1527099609375, mean 0.0223236083984375
[3] max 0.1075439453125, mean 0.015625
[4] max 0.08245849609375, mean 0.010223388671875
[5] max 0.0447998046875, mean 0.005950927734375
[6] max 0.0419921875, mean 0.002475738525390625
[7] max 3.0517578125e-05, mean 3.5762786865234375e-07