├── README.md
└── flash_attn_gqa.py

/README.md:
--------------------------------------------------------------------------------
# Direct Addressing GQA

This repository implements **Direct Addressing GQA** to reduce memory usage in Transformer models by avoiding the extra key-value (KV) memory that `repeat_kv` allocates when it expands KV heads to match the number of query heads.

This approach is based on Triton's fused attention tutorial.

## Features

- **Memory Efficiency**: Each query head addresses its shared KV head directly, so the KV tensors are never repeated in memory.
- **Fused Attention**: Builds on Triton's fused attention kernel for performance.

The entry point is `attention(q, k, v, causal, sm_scale)`, where `q` has `H` heads and `k`/`v` have `H // KV_GROUP` heads; see `test_op` in `flash_attn_gqa.py` for an end-to-end example.

## TODO
- [ ] Support `N_CTX < 128`. The kernel currently requires `N_CTX` to be a multiple of 128 because of a limitation inherited from the fused-attention tutorial; as a workaround, right-pad sequences to a multiple of 128.

--------------------------------------------------------------------------------
/flash_attn_gqa.py:
--------------------------------------------------------------------------------
'''
Based on the Triton fused-attention tutorial,
adapted for grouped-query attention (GQA).
'''

import pytest
import torch

import triton
import triton.language as tl


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i

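
# A minimal sketch of the intended call pattern (it mirrors `test_op` further
# below; the shapes and the sm_scale value are illustrative, not a fixed API):
# `q` carries H query heads while `k`/`v` carry only H // KV_GROUP shared KV
# heads, and `attention` (defined near the bottom of this file) is called on
# the un-repeated k/v:
#
#   q = torch.randn(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16, device="cuda", requires_grad=True)
#   k = torch.randn(Z, H // KV_GROUP, N_CTX, HEAD_DIM, dtype=torch.float16, device="cuda", requires_grad=True)
#   v = torch.randn(Z, H // KV_GROUP, N_CTX, HEAD_DIM, dtype=torch.float16, device="cuda", requires_grad=True)
#   out = attention(q, k, v, True, 1.0 / HEAD_DIM ** 0.5)  # (causal, sm_scale)
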
# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64, 128]\
    for BN in [32, 64]\
    for s in ([1] if is_hip() else [3, 4, 7])\
    for w in [4, 8]\
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True


@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              Z, H, N_CTX,  #
              HEAD_DIM: tl.constexpr,  #
              KV_GROUP: tl.constexpr,
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H  # batch index
    off_h = off_hz % H  # query head index
    off_kvh = off_h // KV_GROUP  # shared KV head for this query head (direct addressing)
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    groupvk_offset = off_z.to(tl.int64) * stride_kz + off_kvh.to(tl.int64) * stride_kh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + groupvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + groupvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/ln(2) = log2(e)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)  # m_i now holds the (base-2) log-sum-exp rather than the running max
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         Delta,  #
                         Z, H, N_CTX,  #
                         BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  #
                         ):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1: tl.constexpr,  #
                   BLOCK_N1: tl.constexpr,  #
                   HEAD_DIM: tl.constexpr,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)  # q's ctx positions
    offs_n = start_n + tl.arange(0, BLOCK_N1)  # k's ctx positions
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)  # (HEAD_DIM, BLOCK_M1)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)  # (BLOCK_N1, BLOCK_M1): qk, transposed
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # compute dLSE
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2: tl.constexpr,  #
                 BLOCK_N2: tl.constexpr,  #
                 HEAD_DIM: tl.constexpr,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        qk = tl.dot(q, kT)
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,  #
              stride_kz, stride_kh,
              H, N_CTX,  #
              BLOCK_M1: tl.constexpr,  #
              BLOCK_N1: tl.constexpr,  #
              BLOCK_M2: tl.constexpr,  #
              BLOCK_N2: tl.constexpr,  #
              BLK_SLICE_FACTOR: tl.constexpr,  #
              HEAD_DIM: tl.constexpr,
              KV_GROUP: tl.constexpr,):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    kv_adj = (stride_kh * ((bhid % H) // KV_GROUP) + stride_kz * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    # NOTE: dK/dV are indexed per *query* head (adj), not per shared KV head:
    # if they used the same direct (kv_adj) addressing as K/V, programs for
    # different query heads in a group would race on the same buffer. The
    # per-query-head gradients are summed over each group in _attention.backward().
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(dk, dv,  #
                            Q, k, v, sm_scale,  #
                            DO,  #
                            M, D,  #
                            stride_tok, stride_d,  #
                            H, N_CTX,  #
                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                            start_n, start_m, num_steps,  #
                            MASK=True  #
                            )

    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
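    # (The k/v tile loaded above is reused unchanged here; only the query rows
    # advance, covering the remaining rows that attend to this key block
    # without any masking.)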
    dk, dv = _attn_bwd_dkdv(  #
        dk, dv,  #
        Q, k, v, sm_scale,  #
        DO,  #
        M, D,  #
        stride_tok, stride_d,  #
        H, N_CTX,  #
        BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
        start_n, start_m, num_steps,  #
        MASK=False  #
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    # dv += tl.load(dv_ptrs)

    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    # prev_dk = tl.load(dk_ptrs)
    # dk += prev_dk
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m = tl.load(M + offs_m)
    m = m[:, None]

    # Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
                      MASK=True  #
                      )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
                      MASK=False  #
                      )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    tl.store(dq_ptrs, dq)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
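        # NOTE (GQA): q may have more heads than k/v; only the head dimensions
        # are checked here, and kv_group = q.shape[1] // k.shape[1] is derived below.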
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        kv_group = q.shape[1] // k.shape[1]

        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            KV_GROUP=kv_group,
            STAGE=stage,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, M = ctx.saved_tensors
        kv_group = q.shape[1] // k.shape[1]
        assert do.is_contiguous()
        # k.stride() and v.stride() are intentionally not part of this check:
        # with GQA they differ from q's strides.
        assert q.stride() == o.stride() == do.stride()
        dq = torch.empty_like(q)
        dk_size = (k.size(0), k.size(1) * kv_group, k.size(2), k.size(3))
        dv_size = (v.size(0), v.size(1) * kv_group, v.size(2), v.size(3))
        dk = k.new_zeros(dk_size)  # torch.empty_like(k)
        dv = v.new_zeros(dv_size)  # torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        PRE_BLOCK = 128
        NUM_WARPS, NUM_STAGES = 4, 5
        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        arg_k = k
        arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
        PRE_BLOCK = 128
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)

        _attn_bwd_preprocess[pre_grid](
            o, do,  #
            delta,  #
            BATCH, N_HEAD, N_CTX,  #
            BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
        )

        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
        _attn_bwd[grid](
            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
            M, delta,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1),
            N_HEAD, N_CTX,  #
            BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
            BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            HEAD_DIM=ctx.HEAD_DIM,  #
            KV_GROUP=kv_group,
            num_warps=NUM_WARPS,  #
            num_stages=NUM_STAGES  #
        )
        dk_size = (k.size(0), k.size(1), kv_group, k.size(2), k.size(3))
        dv_size = (v.size(0), v.size(1), kv_group, v.size(2), v.size(3))
        return dq, dk.view(dk_size).sum(dim=2), dv.view(dv_size).sum(dim=2), None, None


attention = _attention.apply


@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, KV_GROUP", [(32, 32, 1024, 64, 2)])
@pytest.mark.parametrize("causal", [True])
def test_op(Z, H, N_CTX, HEAD_DIM, KV_GROUP, causal, dtype=torch.float16):
    from transformers.models.llama.modeling_llama import repeat_kv

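    # repeat_kv expands k/v from (Z, KV_NUM, N_CTX, HEAD_DIM) to (Z, H, N_CTX, HEAD_DIM)
    # for the eager reference computation only; the Triton path consumes the
    # un-repeated k and v directly.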
    KV_NUM = H // KV_GROUP

    torch.manual_seed(20)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, KV_NUM, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, KV_NUM, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    repeat_k = repeat_kv(k, KV_GROUP)
    repeat_v = repeat_kv(v, KV_GROUP)

    sm_scale = 0.5
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, repeat_k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, repeat_v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale).half()
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)

    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    rtol = 0.0
    # Relative tolerance workaround for known hardware limitation of MI200 GPU.
    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
        rtol = 1e-2
    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)


try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ["fwd", "bwd"]:
    for causal in [True, False]:
        if mode == "bwd" and not causal:
            continue
        configs.append(
            triton.testing.Benchmark(
                x_names=["N_CTX"],
                x_vals=[2**i for i in range(10, 15)],
                line_arg="provider",
                line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
                (["flash"] if HAS_FLASH else []),
                line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
                (["Flash-2"] if HAS_FLASH else []),
                styles=[("red", "-"), ("blue", "-"), ("green", "-")],
                ylabel="ms",
                plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}",
                args={
                    "H": N_HEADS,
                    "BATCH": BATCH,
                    "HEAD_DIM": HEAD_DIM,
                    "mode": mode,
                    "causal": causal,
                },
            ))


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
    assert mode in ["fwd", "bwd"]
    warmup = 25
    rep = 100
    dtype = torch.float16
    if "triton" in provider:
        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        if mode == "fwd" and "fp8" in provider:
            q = q.to(torch.float8_e5m2)
            k = k.to(torch.float8_e5m2)
            v = v.permute(0, 1, 3, 2).contiguous()
            v = v.permute(0, 1, 3, 2)
            v = v.to(torch.float8_e5m2)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9


if __name__ == "__main__":
    # only works on post-Ampere GPUs right now
    bench_flash_attention.run(save_path=".", print_data=True)

--------------------------------------------------------------------------------