├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── benchmark ├── bench_flash_mla.py └── visualize.py ├── csrc ├── flash_api.cpp └── kernels │ ├── config.h │ ├── get_mla_metadata.cu │ ├── get_mla_metadata.h │ ├── mla_combine.cu │ ├── mla_combine.h │ ├── params.h │ ├── splitkv_mla.cu │ ├── splitkv_mla.h │ ├── traits.h │ └── utils.h ├── docs ├── 20250422-new-kernel-deep-dive.md └── assets │ └── MLA Kernel Sched.drawio.svg ├── flash_mla ├── __init__.py └── flash_mla_interface.py ├── setup.py └── tests └── test_flash_mla.py /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | *.so 3 | *.egg-info/ 4 | __pycache__/ 5 | dist/ 6 | *perf.csv 7 | *.png 8 | /.vscode 9 | compile_commands.json 10 | .cache 11 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "csrc/cutlass"] 2 | path = csrc/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlashMLA 2 | 3 | ## Performance Update (2025.04.22) 4 | 5 | We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 6 | 7 | Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). 8 | 9 | The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. 10 | 11 | ## Introduction 12 | 13 | FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. 
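As a rough illustration of the criterion in the performance update above, the sketch below checks whether a decoding workload falls into the compute-intensive regime (number of q heads times number of q tokens per request at least 64). The helper name `is_compute_intensive` and its parameters are hypothetical and are not part of the `flash_mla` package; only the threshold of 64 comes from the text.

```python
def is_compute_intensive(num_q_heads: int, q_tokens_per_request: int = 1,
                         threshold: int = 64) -> bool:
    """Return True when num_q_heads * q_tokens_per_request >= threshold.

    Hypothetical helper, not part of the flash_mla package. The default
    threshold of 64 is the figure quoted in the performance update.
    """
    if num_q_heads <= 0 or q_tokens_per_request <= 0:
        raise ValueError("both arguments must be positive")
    return num_q_heads * q_tokens_per_request >= threshold


# 128 query heads, one query token per request (MTP disabled): 128 >= 64.
print(is_compute_intensive(128, 1))
# 16 query heads with two query tokens per request: 32 < 64.
print(is_compute_intensive(16, 2))
```

A check like this could help decide between the new kernel and the b31bfe7 version recommended for memory-bound cases, though the actual crossover point may depend on the hardware and sequence lengths.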
14 | 15 | Currently released: 16 | - BF16, FP16 17 | - Paged kvcache with block size of 64 18 | 19 | ## Requirements 20 | 21 | - Hopper GPUs 22 | - CUDA 12.3 and above 23 | - **But we highly recommend 12.8 or above for the best performance** 24 | - PyTorch 2.0 and above 25 | 26 | ## Quick start 27 | 28 | ### Install 29 | 30 | ```bash 31 | python setup.py install 32 | ``` 33 | 34 | ### Benchmark 35 | 36 | ```bash 37 | python tests/test_flash_mla.py 38 | ``` 39 | 40 | It achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5, using CUDA 12.8. 41 | 42 | Note: For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. 43 | 44 | ### Usage 45 | 46 | ```python 47 | from flash_mla import get_mla_metadata, flash_mla_with_kvcache 48 | 49 | tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) 50 | 51 | for i in range(num_layers): 52 | ... 53 | o_i, lse_i = flash_mla_with_kvcache( 54 | q_i, kvcache_i, block_table, cache_seqlens, dv, 55 | tile_scheduler_metadata, num_splits, causal=True, 56 | ) 57 | ... 58 | ``` 59 | 60 | ## Acknowledgement 61 | 62 | FlashMLA is inspired by the [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. 63 | 64 | ## Community Support 65 | 66 | ### MetaX 67 | For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). 68 | 69 | The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) 70 | 71 | 72 | ### Moore Threads 73 | For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). 74 | 75 | The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). 
76 | 77 | 78 | ### Hygon DCU 79 | For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/). 80 | 81 | The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). 82 | 83 | 84 | ### Intellifusion 85 | For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). 86 | 87 | The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). 88 | 89 | 90 | ### Iluvatar Corex 91 | For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). 92 | 93 | The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) 94 | 95 | 96 | ### AMD Instinct 97 | For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html). 
98 | 99 | The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py) 100 | 101 | ## Citation 102 | 103 | ```bibtex 104 | @misc{flashmla2025, 105 | title={FlashMLA: Efficient MLA decoding kernels}, 106 | author={Jiashi Li and Shengyu Liu}, 107 | year={2025}, 108 | publisher = {GitHub}, 109 | howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /benchmark/bench_flash_mla.py: -------------------------------------------------------------------------------- 1 | # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a 2 | import argparse 3 | import math 4 | import random 5 | 6 | import flashinfer 7 | import torch 8 | import triton 9 | import triton.language as tl 10 | 11 | # pip install flashinfer-python 12 | from flash_mla import flash_mla_with_kvcache, get_mla_metadata 13 | 14 | 15 | def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): 16 | query = query.float() 17 | key = key.float() 18 | value = value.float() 19 | key = key.repeat_interleave(h_q // h_kv, dim=0) 20 | value = value.repeat_interleave(h_q // h_kv, dim=0) 21 | attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) 22 | if is_causal: 23 | s_q = query.shape[-2] 24 | s_k = key.shape[-2] 25 | attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) 26 | temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) 27 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 28 | attn_bias = attn_bias.to(query.dtype) 29 | attn_weight += attn_bias 30 | lse = attn_weight.logsumexp(dim=-1) 31 | attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) 32 | return attn_weight @ value, lse 33 | 34 | 35 | @torch.inference_mode() 36 | def run_torch_mla(q, block_table, 
blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 37 | for i in range(b): 38 | blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") 39 | blocked_v = blocked_k[..., :dv] 40 | 41 | def ref_mla(): 42 | out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) 43 | lse = torch.empty(b, h_q, s_q, dtype=torch.float32) 44 | for i in range(b): 45 | begin = i * max_seqlen_pad 46 | end = begin + cache_seqlens[i] 47 | O, LSE = scaled_dot_product_attention( 48 | q[i].transpose(0, 1), 49 | blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), 50 | blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), 51 | h_q, h_kv, 52 | is_causal=causal, 53 | ) 54 | out[i] = O.transpose(0, 1) 55 | lse[i] = LSE 56 | return out, lse 57 | 58 | out_torch, lse_torch = ref_mla() 59 | t = triton.testing.do_bench(ref_mla) 60 | return out_torch, lse_torch, t 61 | 62 | @torch.inference_mode() 63 | def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 64 | for i in range(b): 65 | blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") 66 | blocked_v = blocked_k[..., :dv] 67 | 68 | tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) 69 | 70 | def flash_mla(): 71 | return flash_mla_with_kvcache( 72 | q, blocked_k, block_table, cache_seqlens, dv, 73 | tile_scheduler_metadata, num_splits, causal=causal, 74 | ) 75 | 76 | out_flash, lse_flash = flash_mla() 77 | t = triton.testing.do_bench(flash_mla) 78 | return out_flash, lse_flash, t 79 | 80 | 81 | @torch.inference_mode() 82 | def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 83 | 84 | for i in range(b): 85 | blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") 86 | 87 | assert d > dv, "mla with rope dim should be 
larger than no rope dim" 88 | q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() 89 | blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() 90 | 91 | 92 | kv_indptr = [0] 93 | kv_indices = [] 94 | for i in range(b): 95 | seq_len = cache_seqlens[i] 96 | assert seq_len > 0 97 | num_blocks = (seq_len + block_size - 1) // block_size 98 | kv_indices.extend(block_table[i, :num_blocks]) 99 | kv_indptr.append(kv_indptr[-1] + num_blocks) 100 | for seq_len in cache_seqlens[1:]: 101 | kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) 102 | 103 | q_indptr = torch.arange(0, b + 1).int() * s_q 104 | kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) 105 | kv_indices = torch.tensor(kv_indices, dtype=torch.int32) 106 | 107 | mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( 108 | torch.empty(128 * 1024 * 1024, dtype=torch.int8), 109 | backend="fa3" 110 | ) 111 | mla_wrapper.plan( 112 | q_indptr, 113 | kv_indptr, 114 | kv_indices, 115 | cache_seqlens, 116 | h_q, 117 | dv, 118 | d-dv, 119 | block_size, 120 | causal, 121 | 1 / math.sqrt(d), 122 | q.dtype, 123 | blocked_k.dtype, 124 | ) 125 | 126 | def flash_infer(): 127 | output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True) 128 | return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) 129 | 130 | out_flash, lse_flash = flash_infer() 131 | t = triton.testing.do_bench(flash_infer) 132 | return out_flash, lse_flash, t 133 | 134 | 135 | @triton.jit 136 | def _mla_attn_kernel( 137 | Q_nope, 138 | Q_pe, 139 | Kv_c_cache, 140 | K_pe_cache, 141 | Req_to_tokens, 142 | B_seq_len, 143 | O, 144 | sm_scale, 145 | stride_q_nope_bs, 146 | stride_q_nope_h, 147 | stride_q_pe_bs, 148 | stride_q_pe_h, 149 | stride_kv_c_bs, 150 | stride_k_pe_bs, 151 | stride_req_to_tokens_bs, 152 | stride_o_b, 153 | stride_o_h, 154 | stride_o_s, 155 | BLOCK_H: tl.constexpr, 156 | BLOCK_N: 
tl.constexpr, 157 | NUM_KV_SPLITS: tl.constexpr, 158 | PAGE_SIZE: tl.constexpr, 159 | HEAD_DIM_CKV: tl.constexpr, 160 | HEAD_DIM_KPE: tl.constexpr, 161 | ): 162 | cur_batch = tl.program_id(1) 163 | cur_head_id = tl.program_id(0) 164 | split_kv_id = tl.program_id(2) 165 | 166 | cur_batch_seq_len = tl.load(B_seq_len + cur_batch) 167 | 168 | offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) 169 | cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) 170 | offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] 171 | q_nope = tl.load(Q_nope + offs_q_nope) 172 | 173 | offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) 174 | offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] 175 | q_pe = tl.load(Q_pe + offs_q_pe) 176 | 177 | e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") 178 | e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) 179 | acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) 180 | 181 | kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) 182 | split_kv_start = kv_len_per_split * split_kv_id 183 | split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) 184 | 185 | for start_n in range(split_kv_start, split_kv_end, BLOCK_N): 186 | offs_n = start_n + tl.arange(0, BLOCK_N) 187 | kv_page_number = tl.load( 188 | Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, 189 | mask=offs_n < split_kv_end, 190 | other=0, 191 | ) 192 | kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE 193 | offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] 194 | k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) 195 | 196 | qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) 197 | 198 | offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] 199 | k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) 200 | 201 | qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) 202 | qk *= 
sm_scale 203 | 204 | qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) 205 | 206 | v_c = tl.trans(k_c) 207 | 208 | n_e_max = tl.maximum(tl.max(qk, 1), e_max) 209 | re_scale = tl.exp(e_max - n_e_max) 210 | p = tl.exp(qk - n_e_max[:, None]) 211 | acc *= re_scale[:, None] 212 | acc += tl.dot(p.to(v_c.dtype), v_c) 213 | 214 | e_sum = e_sum * re_scale + tl.sum(p, 1) 215 | e_max = n_e_max 216 | offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] 217 | tl.store(O + offs_o, acc / e_sum[:, None]) 218 | offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV 219 | tl.store(O + offs_o_1, e_max + tl.log(e_sum)) 220 | 221 | 222 | def _mla_attn( 223 | q_nope, 224 | q_pe, 225 | kv_c_cache, 226 | k_pe_cache, 227 | attn_logits, 228 | req_to_tokens, 229 | b_seq_len, 230 | num_kv_splits, 231 | sm_scale, 232 | page_size, 233 | ): 234 | batch_size, head_num = q_nope.shape[0], q_nope.shape[1] 235 | head_dim_ckv = q_nope.shape[-1] 236 | head_dim_kpe = q_pe.shape[-1] 237 | 238 | BLOCK_H = 16 239 | BLOCK_N = 64 240 | grid = ( 241 | triton.cdiv(head_num, BLOCK_H), 242 | batch_size, 243 | num_kv_splits, 244 | ) 245 | _mla_attn_kernel[grid]( 246 | q_nope, 247 | q_pe, 248 | kv_c_cache, 249 | k_pe_cache, 250 | req_to_tokens, 251 | b_seq_len, 252 | attn_logits, 253 | sm_scale, 254 | # stride 255 | q_nope.stride(0), 256 | q_nope.stride(1), 257 | q_pe.stride(0), 258 | q_pe.stride(1), 259 | kv_c_cache.stride(-2), 260 | k_pe_cache.stride(-2), 261 | req_to_tokens.stride(0), 262 | attn_logits.stride(0), 263 | attn_logits.stride(1), 264 | attn_logits.stride(2), 265 | BLOCK_H=BLOCK_H, 266 | BLOCK_N=BLOCK_N, 267 | NUM_KV_SPLITS=num_kv_splits, 268 | PAGE_SIZE=page_size, 269 | HEAD_DIM_CKV=head_dim_ckv, 270 | HEAD_DIM_KPE=head_dim_kpe, 271 | ) 272 | 273 | @triton.jit 274 | def _mla_softmax_reducev_kernel( 275 | Logits, 276 | B_seq_len, 277 | O, 278 | stride_l_b, 279 | stride_l_h, 280 | 
stride_l_s, 281 | stride_o_b, 282 | stride_o_h, 283 | NUM_KV_SPLITS: tl.constexpr, 284 | HEAD_DIM_CKV: tl.constexpr, 285 | ): 286 | cur_batch = tl.program_id(0) 287 | cur_head = tl.program_id(1) 288 | cur_batch_seq_len = tl.load(B_seq_len + cur_batch) 289 | 290 | offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) 291 | 292 | e_sum = 0.0 293 | e_max = -float("inf") 294 | acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) 295 | 296 | offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv 297 | offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV 298 | 299 | for split_kv_id in range(0, NUM_KV_SPLITS): 300 | kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) 301 | split_kv_start = kv_len_per_split * split_kv_id 302 | split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) 303 | 304 | if split_kv_end > split_kv_start: 305 | logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) 306 | logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) 307 | 308 | n_e_max = tl.maximum(logits_1, e_max) 309 | old_scale = tl.exp(e_max - n_e_max) 310 | acc *= old_scale 311 | exp_logic = tl.exp(logits_1 - n_e_max) 312 | acc += exp_logic * logits 313 | 314 | e_sum = e_sum * old_scale + exp_logic 315 | e_max = n_e_max 316 | 317 | tl.store( 318 | O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, 319 | acc / e_sum, 320 | ) 321 | 322 | 323 | def _mla_softmax_reducev( 324 | logits, 325 | o, 326 | b_seq_len, 327 | num_kv_splits, 328 | ): 329 | batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] 330 | grid = (batch_size, head_num) 331 | _mla_softmax_reducev_kernel[grid]( 332 | logits, 333 | b_seq_len, 334 | o, 335 | logits.stride(0), 336 | logits.stride(1), 337 | logits.stride(2), 338 | o.stride(0), 339 | o.stride(1), 340 | NUM_KV_SPLITS=num_kv_splits, 341 | HEAD_DIM_CKV=head_dim_ckv, 342 | num_warps=4, 343 | num_stages=2, 344 | ) 345 | 346 | def mla_decode_triton( 347 | q_nope, 348 | q_pe, 
349 | kv_c_cache, 350 | k_pe_cache, 351 | o, 352 | req_to_tokens, 353 | b_seq_len, 354 | attn_logits, 355 | num_kv_splits, 356 | sm_scale, 357 | page_size, 358 | ): 359 | assert num_kv_splits == attn_logits.shape[2] 360 | _mla_attn( 361 | q_nope, 362 | q_pe, 363 | kv_c_cache, 364 | k_pe_cache, 365 | attn_logits, 366 | req_to_tokens, 367 | b_seq_len, 368 | num_kv_splits, 369 | sm_scale, 370 | page_size, 371 | ) 372 | _mla_softmax_reducev( 373 | attn_logits, 374 | o, 375 | b_seq_len, 376 | num_kv_splits, 377 | ) 378 | 379 | 380 | @torch.inference_mode() 381 | def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 382 | 383 | for i in range(b): 384 | blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") 385 | blocked_v = blocked_k[..., :dv] 386 | 387 | assert d > dv, "mla with rope dim should be larger than no rope dim" 388 | q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() 389 | blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() 390 | 391 | def flash_mla_triton(): 392 | num_kv_splits = 32 393 | o = torch.empty([b * s_q, h_q, dv]) 394 | attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) 395 | mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size) 396 | return o.view([b, s_q, h_q, dv]) 397 | 398 | out_flash = flash_mla_triton() 399 | t = triton.testing.do_bench(flash_mla_triton) 400 | return out_flash, None, t 401 | 402 | 403 | FUNC_TABLE = { 404 | "torch": run_torch_mla, 405 | "flash_mla": run_flash_mla, 406 | "flash_infer": run_flash_infer, 407 | "flash_mla_triton": run_flash_mla_triton, 408 | } 409 | 410 | def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 411 | print(f"comparing 
{baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") 412 | device = torch.device("cuda:0") 413 | torch.set_default_dtype(dtype) 414 | torch.set_default_device(device) 415 | torch.cuda.set_device(device) 416 | torch.manual_seed(0) 417 | random.seed(0) 418 | assert baseline in FUNC_TABLE 419 | assert target in FUNC_TABLE 420 | baseline_func = FUNC_TABLE[baseline] 421 | target_func = FUNC_TABLE[target] 422 | 423 | total_seqlens = cache_seqlens.sum().item() 424 | mean_seqlens = cache_seqlens.float().mean().int().item() 425 | max_seqlen = cache_seqlens.max().item() 426 | max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 427 | # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") 428 | 429 | q = torch.randn(b, s_q, h_q, d) 430 | block_size = 64 431 | block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) 432 | blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) 433 | 434 | out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) 435 | out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) 436 | 437 | torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" 438 | if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]: 439 | # flash_infer has a different lse return value 440 | # flash_mla_triton doesn't return lse 441 | torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" 442 | 443 | FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 444 | bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) 445 | print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / 
perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") 446 | print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") 447 | return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b 448 | 449 | 450 | def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): 451 | print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") 452 | torch.set_default_dtype(dtype) 453 | device = torch.device("cuda:0") 454 | torch.set_default_device(device) 455 | torch.cuda.set_device(device) 456 | torch.manual_seed(0) 457 | random.seed(0) 458 | assert target in FUNC_TABLE 459 | target_func = FUNC_TABLE[target] 460 | 461 | total_seqlens = cache_seqlens.sum().item() 462 | mean_seqlens = cache_seqlens.float().mean().int().item() 463 | max_seqlen = cache_seqlens.max().item() 464 | max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 465 | # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") 466 | 467 | q = torch.randn(b, s_q, h_q, d) 468 | block_size = 64 469 | block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) 470 | blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) 471 | 472 | out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) 473 | 474 | FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 475 | bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) 476 | print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") 477 | return bytes / 10 ** 6 / perf_b 478 | 479 | 480 | available_targets = [ 481 | "torch", 482 | "flash_mla", 483 | "flash_infer", 484 | "flash_mla_triton", 485 | ] 486 | 487 | shape_configs = [ 488 | {"b": batch, "s_q": 1, "cache_seqlens": 
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16} 489 | for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128] 490 | ] 491 | 492 | 493 | def get_args(): 494 | parser = argparse.ArgumentParser() 495 | parser.add_argument("--baseline", type=str, default="torch") 496 | parser.add_argument("--target", type=str, default="flash_mla") 497 | parser.add_argument("--all", action="store_true") 498 | parser.add_argument("--one", action="store_true") 499 | parser.add_argument("--compare", action="store_true") 500 | args = parser.parse_args() 501 | return args 502 | 503 | 504 | if __name__ == "__main__": 505 | args = get_args() 506 | benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target 507 | with open(f"{benchmark_type}_perf.csv", "w") as fout: 508 | fout.write("name,batch,seqlen,head,bw\n") 509 | for shape in shape_configs: 510 | if args.all: 511 | for target in available_targets: 512 | perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) 513 | fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') 514 | elif args.compare: 515 | perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) 516 | fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') 517 | fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') 518 | elif args.one: 519 | perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], 
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) 520 | fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') -------------------------------------------------------------------------------- /benchmark/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Visualize benchmark results') 9 | parser.add_argument('--file', type=str, default='all_perf.csv', 10 | help='Path to the CSV file with benchmark results (default: all_perf.csv)') 11 | return parser.parse_args() 12 | 13 | args = parse_args() 14 | file_path = args.file 15 | 16 | df = pd.read_csv(file_path) 17 | 18 | names = df['name'].unique() 19 | 20 | for name in names: 21 | subset = df[df['name'] == name] 22 | plt.plot(subset['seqlen'], subset['bw'], label=name) 23 | 24 | plt.title('bandwidth') 25 | plt.xlabel('seqlen') 26 | plt.ylabel('bw (GB/s)') 27 | plt.legend() 28 | 29 | plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png') -------------------------------------------------------------------------------- /csrc/flash_api.cpp: -------------------------------------------------------------------------------- 1 | // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp 2 | /****************************************************************************** 3 | * Copyright (c) 2024, Tri Dao. 
4 | ******************************************************************************/ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include "kernels/config.h" 14 | #include "kernels/get_mla_metadata.h" 15 | #include "kernels/mla_combine.h" 16 | #include "kernels/params.h" 17 | #include "kernels/splitkv_mla.h" 18 | 19 | #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") 20 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 21 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 22 | 23 | std::vector 24 | get_mla_metadata( 25 | at::Tensor &seqlens_k, 26 | const int num_heads_per_head_k, 27 | const int num_heads_k 28 | ) { 29 | CHECK_DEVICE(seqlens_k); 30 | TORCH_CHECK(seqlens_k.is_contiguous()); 31 | TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); 32 | 33 | int batch_size = seqlens_k.size(0); 34 | int *seqlens_k_ptr = seqlens_k.data_ptr(); 35 | auto options = seqlens_k.options(); 36 | 37 | auto dprops = at::cuda::getCurrentDeviceProperties(); 38 | int sm_count = dprops->multiProcessorCount; 39 | int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); 40 | 41 | auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); 42 | auto num_splits = torch::empty({batch_size + 1}, options); 43 | int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); 44 | int *num_splits_ptr = num_splits.data_ptr(); 45 | 46 | at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; 47 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 48 | Mla_metadata_params params = {}; 49 | params.seqlens_k_ptr = seqlens_k_ptr; 50 | params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; 51 | params.num_splits_ptr = num_splits_ptr; 52 | params.batch_size = batch_size; 53 | params.block_size_n = 
Config::PAGE_BLOCK_SIZE; 54 | params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; 55 | params.num_sm_parts = num_sm_parts; 56 | run_get_mla_metadata_kernel(params, stream); 57 | 58 | return {tile_scheduler_metadata, num_splits}; 59 | } 60 | 61 | std::vector 62 | mha_fwd_kvcache_mla( 63 | at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size 64 | const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size 65 | const int head_size_v, 66 | const at::Tensor &seqlens_k, // batch_size 67 | const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq 68 | const float softmax_scale, 69 | bool is_causal, 70 | const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize 71 | const at::Tensor &num_splits // batch_size + 1 72 | ) { 73 | // Check the architecture 74 | auto dprops = at::cuda::getCurrentDeviceProperties(); 75 | bool is_sm90 = dprops->major == 9 && dprops->minor == 0; 76 | TORCH_CHECK(is_sm90); 77 | 78 | // Check data types 79 | auto q_dtype = q.dtype(); 80 | TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); 81 | TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); 82 | TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); 83 | TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); 84 | TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); 85 | TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); 86 | 87 | // Check device 88 | CHECK_DEVICE(q); 89 | CHECK_DEVICE(kcache); 90 | CHECK_DEVICE(seqlens_k); 91 | CHECK_DEVICE(block_table); 92 | CHECK_DEVICE(tile_scheduler_metadata); 93 | CHECK_DEVICE(num_splits); 94 | 95 | // Check layout 96 | TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 97 | TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor 
must have contiguous last dimension"); 98 | CHECK_CONTIGUOUS(seqlens_k); 99 | TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); 100 | CHECK_CONTIGUOUS(tile_scheduler_metadata); 101 | CHECK_CONTIGUOUS(num_splits); 102 | 103 | const auto sizes = q.sizes(); 104 | const int batch_size = sizes[0]; 105 | const int seqlen_q_ori = sizes[1]; 106 | const int num_heads_q = sizes[2]; 107 | const int head_size_k = sizes[3]; 108 | TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); 109 | TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported"); 110 | 111 | const int max_num_blocks_per_seq = block_table.size(1); 112 | const int num_blocks = kcache.size(0); 113 | const int page_block_size = kcache.size(1); 114 | const int num_heads_k = kcache.size(2); 115 | TORCH_CHECK(batch_size > 0, "batch size must be positive"); 116 | TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); 117 | 118 | if (seqlen_q_ori == 1) { is_causal = false; } 119 | 120 | const int num_q_heads_per_hk = num_heads_q / num_heads_k; 121 | const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; 122 | const int num_heads = num_heads_k; 123 | q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) 124 | .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); 125 | 126 | CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); 127 | CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); 128 | CHECK_SHAPE(seqlens_k, batch_size); 129 | CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); 130 | TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); 131 | CHECK_SHAPE(num_splits, batch_size+1); 132 | 133 | at::cuda::CUDAGuard device_guard{(char)q.get_device()}; 134 | 135 | auto opts = q.options(); 136 | at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, 
head_size_v}, opts); 137 | at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); 138 | CHECK_CONTIGUOUS(softmax_lse); 139 | 140 | Flash_fwd_mla_params params = {}; 141 | // Set the sizes. 142 | params.b = batch_size; 143 | params.s_q = seqlen_q_ori; 144 | params.q_seq_per_hk = q_seq_per_hk; 145 | params.seqlens_k_ptr = seqlens_k.data_ptr(); 146 | params.h_q = num_heads_q; 147 | params.h_k = num_heads_k; 148 | params.num_blocks = num_blocks; 149 | params.q_head_per_hk = num_q_heads_per_hk; 150 | params.is_causal = is_causal; 151 | params.d = head_size_k; 152 | params.d_v = head_size_v; 153 | params.scale_softmax = softmax_scale; 154 | params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); 155 | // Set the pointers and strides. 156 | params.q_ptr = q.data_ptr(); 157 | params.k_ptr = kcache.data_ptr(); 158 | params.o_ptr = out.data_ptr(); 159 | params.softmax_lse_ptr = softmax_lse.data_ptr(); 160 | // All strides are in elements, not bytes. 
161 | params.q_batch_stride = q.stride(0); 162 | params.k_batch_stride = kcache.stride(0); 163 | params.o_batch_stride = out.stride(0); 164 | params.q_row_stride = q.stride(-3); 165 | params.k_row_stride = kcache.stride(-3); 166 | params.o_row_stride = out.stride(-3); 167 | params.q_head_stride = q.stride(-2); 168 | params.k_head_stride = kcache.stride(-2); 169 | params.o_head_stride = out.stride(-2); 170 | 171 | params.block_table = block_table.data_ptr(); 172 | params.block_table_batch_stride = block_table.stride(0); 173 | params.page_block_size = page_block_size; 174 | 175 | params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); 176 | params.num_sm_parts = tile_scheduler_metadata.size(0); 177 | params.num_splits_ptr = num_splits.data_ptr(); 178 | 179 | const int total_num_splits = batch_size + params.num_sm_parts; 180 | at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); 181 | at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); 182 | CHECK_CONTIGUOUS(softmax_lse_accum); 183 | CHECK_CONTIGUOUS(out_accum); 184 | params.total_num_splits = total_num_splits; 185 | params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); 186 | params.oaccum_ptr = out_accum.data_ptr(); 187 | 188 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 189 | TORCH_CHECK(head_size_k == 576); 190 | if (q_dtype == torch::kBFloat16) { 191 | run_flash_splitkv_mla_kernel(params, stream); 192 | run_flash_mla_combine_kernel(params, stream); 193 | } else if (q_dtype == torch::kHalf) { 194 | #ifdef FLASH_MLA_DISABLE_FP16 195 | TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); 196 | #else 197 | run_flash_splitkv_mla_kernel(params, stream); 198 | run_flash_mla_combine_kernel(params, stream); 199 | #endif 200 | } else { 201 | TORCH_CHECK(false, "Unsupported tensor dtype for query"); 202 | } 203 | 204 | out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) 205 | .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); 206 | softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) 207 | .reshape({batch_size, num_heads_q, seqlen_q_ori}); 208 | 209 | return {out, softmax_lse}; 210 | } 211 | 212 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 213 | m.doc() = "FlashMLA"; 214 | m.def("get_mla_metadata", &get_mla_metadata); 215 | m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); 216 | } 217 | -------------------------------------------------------------------------------- /csrc/kernels/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Config { 4 | 5 | static constexpr int BLOCK_SIZE_M = 64; 6 | static constexpr int PAGE_BLOCK_SIZE = 64; 7 | 8 | static constexpr int HEAD_DIM_K = 576; 9 | static constexpr int HEAD_DIM_V = 512; 10 | 11 | static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5; 12 | 13 | } 14 | -------------------------------------------------------------------------------- /csrc/kernels/get_mla_metadata.cu: -------------------------------------------------------------------------------- 1 | #include "get_mla_metadata.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "utils.h" 7 | 8 | __global__ void __launch_bounds__(32, 1, 1) 9 | get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { 10 | int *seqlens_k_ptr = params.seqlens_k_ptr; 11 | int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; 12 | int *num_splits_ptr = params.num_splits_ptr; 13 | int 
batch_size = params.batch_size; 14 | int block_size_n = params.block_size_n; 15 | int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; 16 | int num_sm_parts = params.num_sm_parts; 17 | 18 | extern __shared__ int shared_mem[]; 19 | int* num_blocks_shared = shared_mem; // [batch_size] 20 | int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] 21 | 22 | int total_num_blocks = 0; 23 | for (int i = threadIdx.x; i < batch_size; i += 32) { 24 | int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); 25 | total_num_blocks += num_blocks + fixed_overhead_num_blocks; 26 | num_blocks_shared[i] = num_blocks; 27 | } 28 | for (int offset = 16; offset >= 1; offset /= 2) { 29 | total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); 30 | } 31 | __syncwarp(); 32 | 33 | if (threadIdx.x == 0) { 34 | int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks); 35 | 36 | int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; 37 | num_splits_shared[0] = 0; 38 | for (int i = 0; i < num_sm_parts; ++i) { 39 | int tile_scheduler_metadata0[4], tile_scheduler_metadata1; 40 | tile_scheduler_metadata0[0] = now_idx; 41 | tile_scheduler_metadata0[1] = now_block * block_size_n; 42 | tile_scheduler_metadata1 = now_n_split_idx; 43 | int remain_payload = payload; 44 | while (now_idx < batch_size) { 45 | int num_blocks = num_blocks_shared[now_idx]; 46 | int now_remain_blocks = num_blocks - now_block; 47 | if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { 48 | cum_num_splits += now_n_split_idx + 1; 49 | num_splits_shared[now_idx + 1] = cum_num_splits; 50 | remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; 51 | ++now_idx; 52 | now_block = 0; 53 | now_n_split_idx = 0; 54 | } else { 55 | if (remain_payload - fixed_overhead_num_blocks > 0) { 56 | now_block += remain_payload - fixed_overhead_num_blocks; 57 | ++now_n_split_idx; 58 | 
remain_payload = 0; 59 | } 60 | break; 61 | } 62 | } 63 | tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; 64 | tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; 65 | *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); 66 | tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; 67 | } 68 | FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); 69 | } 70 | __syncwarp(); 71 | 72 | for (int i = threadIdx.x; i <= batch_size; i += 32) { 73 | num_splits_ptr[i] = num_splits_shared[i]; 74 | } 75 | } 76 | 77 | void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) { 78 | int smem_size = sizeof(int) * (params.batch_size*2+1); 79 | CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 80 | get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); 81 | CHECK_CUDA_KERNEL_LAUNCH(); 82 | } 83 | -------------------------------------------------------------------------------- /csrc/kernels/get_mla_metadata.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "params.h" 4 | 5 | void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream); 6 | -------------------------------------------------------------------------------- /csrc/kernels/mla_combine.cu: -------------------------------------------------------------------------------- 1 | #include "mla_combine.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "params.h" 9 | #include "utils.h" 10 | #include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V 11 | 12 | using namespace cute; 13 | 14 | template 15 | __global__ void __launch_bounds__(NUM_THREADS) 16 | flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params 
params) { 17 | // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] 18 | // Each CTA gathers the activations of some heads from one batch, scales and accumulates them, and saves the result 19 | static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m 20 | const int batch_idx = blockIdx.x; 21 | const int m_block_idx = blockIdx.y; 22 | const int warp_idx = threadIdx.x / 32; 23 | const int lane_idx = threadIdx.x % 32; 24 | 25 | const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); 26 | const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); 27 | const int my_num_splits = end_split_idx - start_split_idx; 28 | FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); 29 | if (my_num_splits == 1) { 30 | return; 31 | } 32 | 33 | const int num_q_seqs = params.q_seq_per_hk * params.h_k; 34 | const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M); 35 | Tensor gLseAccum = make_tensor( 36 | make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), 37 | Shape, Int>{}, 38 | make_stride(num_q_seqs, _1{}) 39 | ); 40 | Tensor gLse = make_tensor( 41 | make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), 42 | Shape>{}, 43 | Stride<_1>{} 44 | ); 45 | 46 | extern __shared__ float smem_buf[]; 47 | Tensor sLseScale = make_tensor( 48 | make_smem_ptr(smem_buf), 49 | Shape, Int>{}, 50 | Stride, _1>{} // +1 to avoid bank conflicts 51 | ); 52 | 53 | // Wait for the previous kernel (the MLA kernel) to finish 54 | cudaGridDependencySynchronize(); 55 | 56 | // Read gLseAccum into sLseScale 57 | { 58 | #pragma unroll 4 59 | for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) { 60 | int split_idx = elem_idx / BLOCK_SIZE_M; 61 | int seq_idx = elem_idx % BLOCK_SIZE_M; 62 | sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? 
gLseAccum(split_idx, seq_idx) : -INFINITY; 63 | } 64 | __syncthreads(); 65 | } 66 | 67 | if (warp_idx >= num_cur_valid_q_seqs) 68 | return; 69 | 70 | // Warp #i gathers LseAccum for seq #i 71 | { 72 | constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); 73 | float local_lse[NUM_LSE_PER_THREAD]; 74 | CUTLASS_PRAGMA_UNROLL 75 | for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { 76 | const int split_idx = i*32 + lane_idx; 77 | local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY; 78 | } 79 | 80 | float max_lse = -INFINITY; 81 | CUTLASS_PRAGMA_UNROLL 82 | for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) 83 | max_lse = max(max_lse, local_lse[i]); 84 | CUTLASS_PRAGMA_UNROLL 85 | for (int offset = 16; offset >= 1; offset /= 2) 86 | max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); 87 | max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf 88 | 89 | float sum_lse = 0; 90 | CUTLASS_PRAGMA_UNROLL 91 | for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) 92 | sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); 93 | CUTLASS_PRAGMA_UNROLL 94 | for (int offset = 16; offset >= 1; offset /= 2) 95 | sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); 96 | 97 | float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
INFINITY : log2f(sum_lse) + max_lse; 98 | if (lane_idx == 0) 99 | gLse(warp_idx) = global_lse / (float)M_LOG2E; 100 | 101 | CUTLASS_PRAGMA_UNROLL 102 | for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { 103 | const int split_idx = i*32 + lane_idx; 104 | if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse); 105 | } 106 | } 107 | 108 | __syncwarp(); 109 | 110 | // Warp #i accumulates activation for seq #i 111 | { 112 | const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V; 113 | Tensor gOaccum = make_tensor( 114 | make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), 115 | Shape, Int>{}, 116 | make_stride(num_q_seqs*HEAD_DIM_V, _1{}) 117 | ); 118 | 119 | static_assert(HEAD_DIM_V % 32 == 0); 120 | constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32; 121 | float result[ELEMS_PER_THREAD]; 122 | CUTLASS_PRAGMA_UNROLL 123 | for (int i = 0; i < ELEMS_PER_THREAD; ++i) 124 | result[i] = 0.0f; 125 | 126 | #pragma unroll 2 127 | for (int split = 0; split < my_num_splits; ++split) { 128 | float lse_scale = sLseScale(warp_idx, split); 129 | if (lse_scale != 0.f) { 130 | CUTLASS_PRAGMA_UNROLL 131 | for (int i = 0; i < ELEMS_PER_THREAD; ++i) { 132 | result[i] += lse_scale * gOaccum(split, lane_idx + i*32); 133 | } 134 | } 135 | } 136 | 137 | cudaTriggerProgrammaticLaunchCompletion(); 138 | 139 | const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx; 140 | const int k_head_idx = q_seq_idx / params.q_seq_per_hk; 141 | auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride; 142 | Tensor gO = make_tensor( 143 | make_gmem_ptr(o_ptr), 144 | Shape>{}, 145 | Stride<_1>{} 146 | ); 147 | 148 | CUTLASS_PRAGMA_UNROLL 149 | for (int i = 0; i < ELEMS_PER_THREAD; ++i) 150 | gO(lane_idx+i*32) = (ElementT)result[i]; 151 | } 152 | } 153 | 154 | 155 | #define 
MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ 156 | [&] { \ 157 | if (NUM_SPLITS <= 32) { \ 158 | constexpr static int NAME = 32; \ 159 | return __VA_ARGS__(); \ 160 | } else if (NUM_SPLITS <= 64) { \ 161 | constexpr static int NAME = 64; \ 162 | return __VA_ARGS__(); \ 163 | } else if (NUM_SPLITS <= 96) { \ 164 | constexpr static int NAME = 96; \ 165 | return __VA_ARGS__(); \ 166 | } else if (NUM_SPLITS <= 128) { \ 167 | constexpr static int NAME = 128; \ 168 | return __VA_ARGS__(); \ 169 | } else if (NUM_SPLITS <= 160) { \ 170 | constexpr static int NAME = 160; \ 171 | return __VA_ARGS__(); \ 172 | } else { \ 173 | FLASH_ASSERT(false); \ 174 | } \ 175 | }() 176 | 177 | 178 | template 179 | void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) { 180 | MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { 181 | constexpr int BLOCK_SIZE_M = 8; 182 | constexpr int NUM_THREADS = BLOCK_SIZE_M*32; 183 | constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); 184 | auto combine_kernel = &flash_fwd_mla_combine_kernel; 185 | CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 186 | // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) 187 | cudaLaunchAttribute attribute[1]; 188 | attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; 189 | attribute[0].val.programmaticStreamSerializationAllowed = 1; 190 | cudaLaunchConfig_t combine_kernel_config = { 191 | dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1), 192 | dim3(NUM_THREADS, 1, 1), 193 | smem_size, 194 | stream, 195 | attribute, 196 | 1 197 | }; 198 | cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params); 199 | }); 200 | CHECK_CUDA_KERNEL_LAUNCH(); 201 | } 202 | 203 | template void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 204 | 205 | #ifndef FLASH_MLA_DISABLE_FP16 206 | template void 
run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 207 | #endif -------------------------------------------------------------------------------- /csrc/kernels/mla_combine.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "params.h" 4 | 5 | template 6 | void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 7 | -------------------------------------------------------------------------------- /csrc/kernels/params.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | //////////////////////////////////////////////////////////////////////////////////////////////////// 4 | 5 | struct Flash_fwd_mla_params { 6 | using index_t = int64_t; 7 | 8 | int b; // batch size 9 | int s_q; 10 | int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q 11 | int d, d_v; // K/V dimension 12 | int h_q, h_k; // The number of Q/K heads 13 | int num_blocks; // Number of blocks in total 14 | int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k 15 | bool is_causal; 16 | float scale_softmax, scale_softmax_log2; 17 | 18 | void *__restrict__ q_ptr; 19 | void *__restrict__ k_ptr; 20 | void *__restrict__ o_ptr; 21 | void *__restrict__ softmax_lse_ptr; 22 | 23 | index_t q_batch_stride; 24 | index_t k_batch_stride; 25 | index_t o_batch_stride; 26 | index_t q_row_stride; 27 | index_t k_row_stride; 28 | index_t o_row_stride; 29 | index_t q_head_stride; 30 | index_t k_head_stride; 31 | index_t o_head_stride; 32 | 33 | int *__restrict__ block_table; 34 | index_t block_table_batch_stride; 35 | int page_block_size; 36 | int *__restrict__ seqlens_k_ptr; 37 | 38 | int *__restrict__ tile_scheduler_metadata_ptr; 39 | int num_sm_parts; 40 | int *__restrict__ num_splits_ptr; 41 | 42 | int total_num_splits; 43 | void *__restrict__ softmax_lseaccum_ptr; 44 | void *__restrict__ oaccum_ptr; 45 | }; 46 | 47 | 
static constexpr int TileSchedulerMetaDataSize = 8; 48 | // [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] 49 | 50 | struct Mla_metadata_params { 51 | int *__restrict__ seqlens_k_ptr; 52 | int *__restrict__ tile_scheduler_metadata_ptr; 53 | int *__restrict__ num_splits_ptr; 54 | int batch_size; 55 | int block_size_n; 56 | int fixed_overhead_num_blocks; 57 | int num_sm_parts; 58 | }; 59 | -------------------------------------------------------------------------------- /csrc/kernels/splitkv_mla.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "params.h" 4 | #include "utils.h" 5 | #include "config.h" 6 | #include "traits.h" 7 | 8 | using namespace cute; 9 | using cutlass::arch::NamedBarrier; 10 | 11 | // Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking 12 | // The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) 13 | // so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM 14 | static constexpr float MAX_INIT_VAL_SM = -1e30f; 15 | static constexpr float MAX_INIT_VAL = -1e33f; 16 | 17 | 18 | __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { 19 | // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. 
This function converts the local_row_idx (0 or 1) to the actual row_idx 20 | // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a 21 | int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); 22 | return row_idx; 23 | } 24 | 25 | // Launch TMA copy for a range of KV tiles 26 | // A tile has a shape of PAGE_BLOCK_SIZE (64) x 64 27 | template< 28 | int START_HEAD_DIM_TILE_IDX, 29 | int END_HEAD_DIM_TILE_IDX, 30 | typename TMA_K_OneTile, 31 | typename Engine0, typename Layout0, 32 | typename Engine1, typename Layout1 33 | > 34 | __forceinline__ __device__ void launch_kv_tiles_copy_tma( 35 | Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) 36 | Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled 37 | TMA_K_OneTile &tma_K, 38 | TMABarrier* barriers_K, 39 | int idx_in_warpgroup 40 | ) { 41 | if (idx_in_warpgroup == 0) { 42 | auto thr_tma = tma_K.get_slice(_0{}); 43 | Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); 44 | Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{}); 45 | cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV); 46 | if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { 47 | launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup); 48 | } 49 | } 50 | } 51 | 52 | // Prefetch some KV tiles 53 | // Currently this is not used because it leads to performance degradation 54 | template< 55 | int START_HEAD_DIM_TILE_IDX, 56 | int END_HEAD_DIM_TILE_IDX, 57 | typename TMA_K_OneTile, 58 | typename Engine0, typename Layout0 59 | > 60 | __forceinline__ __device__ void prefetch_kv_tiles( 61 | Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) 62 | TMA_K_OneTile &tma_K, 63 | int idx_in_warpgroup 64 | ) { 65 | if (idx_in_warpgroup == 0) { 66 | auto thr_tma = tma_K.get_slice(_0{}); 67 | Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, 
Int{}); 68 | cute::prefetch(tma_K, cur_gKV); 69 | if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { 70 | prefetch_kv_tiles(gKV, tma_K, idx_in_warpgroup); 71 | } 72 | } 73 | } 74 | 75 | // Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h 76 | // * Copyright (c) 2024, Tri Dao. 77 | template 78 | __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { 79 | constexpr bool Is_RS = !cute::is_base_of::value; 80 | // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const 81 | if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } 82 | warpgroup_fence_operand(tCrC); 83 | if constexpr (arrive) { 84 | warpgroup_arrive(); 85 | } 86 | if constexpr (zero_init) { 87 | tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; 88 | // Unroll the K mode manually to set scale D to 1 89 | CUTLASS_PRAGMA_UNROLL 90 | for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { 91 | cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); 92 | tiled_mma.accumulate_ = GMMA::ScaleOut::One; 93 | } 94 | } else { 95 | // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); 96 | // Unroll the K mode manually to set scale D to 1 97 | CUTLASS_PRAGMA_UNROLL 98 | for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { 99 | cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); 100 | tiled_mma.accumulate_ = GMMA::ScaleOut::One; 101 | } 102 | } 103 | if constexpr (commit) { 104 | warpgroup_commit_batch(); 105 | } 106 | if constexpr (wg_wait >= 0) { warpgroup_wait(); } 107 | warpgroup_fence_operand(tCrC); 108 | if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } 109 | } 110 | 111 | 112 | // Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64) 113 | // The Q-tile should be in shared memory 114 | template< 115 | 
typename TiledMMA, 116 | typename Engine0, typename Layout0, 117 | typename Engine1, typename Layout1, 118 | typename Engine2, typename Layout2 119 | > 120 | __forceinline__ __device__ void qkt_gemm_one_tile_sQ( 121 | TiledMMA &tiled_mma, 122 | Tensor const &thr_mma_sQ_tile, // (MMA, 1, 4) 123 | Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) 124 | Tensor &rP, // ((2, 2, 8), 1, 1) 125 | TMABarrier* barrier, 126 | bool &cur_phase, 127 | int idx_in_warpgroup 128 | ) { 129 | if (idx_in_warpgroup == 0) { 130 | barrier->arrive_and_expect_tx(64*64*2); 131 | } 132 | barrier->wait(cur_phase ? 1 : 0); 133 | 134 | warpgroup_fence_operand(rP); 135 | warpgroup_arrive(); 136 | cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); 137 | tiled_mma.accumulate_ = GMMA::ScaleOut::One; 138 | cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); 139 | cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); 140 | cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); 141 | warpgroup_commit_batch(); 142 | warpgroup_fence_operand(rP); 143 | } 144 | 145 | template< 146 | typename TiledMMA, 147 | typename Engine0, typename Layout0, 148 | typename Engine1, typename Layout1, 149 | typename Engine2, typename Layout2 150 | > 151 | __forceinline__ __device__ void qkt_gemm_one_tile_rQ( 152 | TiledMMA &tiled_mma, 153 | Tensor const &thr_mma_rQ_tile, // (MMA, 1, 4) 154 | Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) 155 | Tensor &rP, // ((2, 2, 8), 1, 1) 156 | TMABarrier* barrier, 157 | bool &cur_phase, 158 | int idx_in_warpgroup 159 | ) { 160 | if (idx_in_warpgroup == 0) { 161 | barrier->arrive_and_expect_tx(64*64*2); 162 | } 163 | barrier->wait(cur_phase ? 
1 : 0); 164 | 165 | warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); 166 | warpgroup_fence_operand(rP); 167 | warpgroup_arrive(); 168 | cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); 169 | tiled_mma.accumulate_ = GMMA::ScaleOut::One; 170 | cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); 171 | cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); 172 | cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); 173 | warpgroup_commit_batch(); 174 | warpgroup_fence_operand(rP); 175 | warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); 176 | } 177 | 178 | // Pipelined TMA wait and Q K^T gemm 179 | // In order to overlap memory copy (G->S copy for K) and computation, we divide Q and K into tiles of shape (BLOCK_SIZE_M, 64) and (PAGE_BLOCK_SIZE, 64), respectively, and then do the computation as follows: 180 | // - Wait for the 0-th tile to be ready using `barrier.wait()` 181 | // - Compute Q K^T for the 0-th tile 182 | // - Wait for the 1-st tile to be ready 183 | // - Compute Q K^T for the 1-st tile 184 | // ... 185 | // This gives later tiles more time to become ready, so the memory copy overlaps with the computation 186 | template< 187 | typename T, // Traits 188 | int PHASE_IDX, // See comments in the code 189 | typename Engine0, typename Layout0, 190 | typename Engine1, typename Layout1, 191 | typename Engine2, typename Layout2, 192 | typename Engine3, typename Layout3 193 | > 194 | __forceinline__ __device__ void warpgroup_cooperative_qkt_gemm( 195 | Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) 196 | Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) 197 | Tensor &rP, // ((2, 2, 8), 1, 1) 198 | Tensor &rQ8, // The 8-th tile of Q. 
We store it separately to leave some room for storing sP1 199 | TMABarrier* barriers, 200 | bool &cur_phase, 201 | int idx_in_warpgroup 202 | ) { 203 | Tensor sQ_tiled = flat_divide(sQ, Shape, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9) 204 | Tensor sKV_tiled = flat_divide(sKV, Shape, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9) 205 | TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){}; 206 | ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup); 207 | Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9) 208 | Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9) 209 | TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){}; 210 | 211 | #define QKT_GEMM_ONE_TILE(TILE_IDX) \ 212 | if constexpr(TILE_IDX != 8) { \ 213 | qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int{}), thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ 214 | } else { \ 215 | qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ 216 | } 217 | 218 | if constexpr (PHASE_IDX == 0) { 219 | // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles 220 | tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; 221 | tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; 222 | QKT_GEMM_ONE_TILE(0); 223 | QKT_GEMM_ONE_TILE(1); 224 | QKT_GEMM_ONE_TILE(2); 225 | QKT_GEMM_ONE_TILE(3); 226 | } else if constexpr (PHASE_IDX == 1) { 227 | // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles 228 | tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; 229 | tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; 230 | QKT_GEMM_ONE_TILE(4); 231 | QKT_GEMM_ONE_TILE(5); 232 | QKT_GEMM_ONE_TILE(6); 233 | QKT_GEMM_ONE_TILE(7); 234 | QKT_GEMM_ONE_TILE(8); 235 | QKT_GEMM_ONE_TILE(0); 236 | QKT_GEMM_ONE_TILE(1); 237 | QKT_GEMM_ONE_TILE(2); 238 | QKT_GEMM_ONE_TILE(3); 239 | cur_phase ^= 1; 240 | } else { 241 | // 
In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles 242 | static_assert(PHASE_IDX == 2); 243 | tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One; 244 | tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; 245 | QKT_GEMM_ONE_TILE(4); 246 | QKT_GEMM_ONE_TILE(5); 247 | QKT_GEMM_ONE_TILE(6); 248 | QKT_GEMM_ONE_TILE(7); 249 | QKT_GEMM_ONE_TILE(8); 250 | cur_phase ^= 1; 251 | } 252 | } 253 | 254 | 255 | template< 256 | typename T, 257 | typename Engine0, typename Layout0, 258 | typename Engine1, typename Layout1, 259 | typename Engine2, typename Layout2 260 | > 261 | __forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline( 262 | Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) 263 | Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) 264 | Tensor &rP, // ((2, 2, 8), 1, 1) 265 | int idx_in_warpgroup 266 | ) { 267 | TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){}; 268 | ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); 269 | Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36) 270 | Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36) 271 | gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP); 272 | } 273 | 274 | 275 | // Compute O += PV, where P resides in registers 276 | template< 277 | typename T, 278 | typename Engine0, typename Layout0, 279 | typename Engine1, typename Layout1, 280 | typename Engine2, typename Layout2 281 | > 282 | __forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP( 283 | Tensor &rP, // ((2, 2, 8), 1, 1), fragment A layout 284 | Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) 285 | Tensor &rO, // ((2, 2, 32), 1, 1) 286 | int idx_in_warpgroup 287 | ) { 288 | TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){}; 289 | ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); 290 | Tensor rP_retiled = make_tensor(rP.data(), Layout< 291 | Shape, _1, _4>, 292 | Stride, _0, _8> 293 | >{}); 294 | Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 
64/16=4) 295 | gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO); 296 | } 297 | 298 | 299 | // Compute O += PV, where P resides in shared memory 300 | template< 301 | typename T, 302 | typename Engine0, typename Layout0, 303 | typename Engine1, typename Layout1, 304 | typename Engine2, typename Layout2 305 | > 306 | __forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP( 307 | Tensor<Engine0, Layout0> &sP, 308 | Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) 309 | Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1) 310 | int idx_in_warpgroup 311 | ) { 312 | TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){}; 313 | ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); 314 | Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP); 315 | Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) 316 | gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO); 317 | } 318 | 319 | 320 | template< 321 | typename T, 322 | bool DO_OOB_FILLING, 323 | typename Engine0, typename Layout0, 324 | typename Engine1, typename Layout1, 325 | typename Engine2, typename Layout2, 326 | typename Engine3, typename Layout3, 327 | typename Engine4, typename Layout4 328 | > 329 | __forceinline__ __device__ void wg0_bunch_0( 330 | Tensor<Engine0, Layout0> &rPb, // ((2, 2, 8), 1, 1) 331 | Tensor<Engine1, Layout1> &rP0, // ((2, 2, 8), 1, 1) 332 | Tensor<Engine2, Layout2> &rO0, // ((2, 2, 32), 1, 1) 333 | Tensor<Engine3, Layout3> &sScale0, // (BLOCK_SIZE_M) 334 | Tensor<Engine4, Layout4> &sM, // (BLOCK_SIZE_M) 335 | float rL[2], 336 | int rRightBorderForQSeq[2], 337 | float scale_softmax_log2, 338 | int start_token_idx, 339 | int idx_in_warpgroup 340 | ) { 341 | // This piece of code is tightly coupled with the [accumulator's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png) 342 | CUTLASS_PRAGMA_UNROLL 343 | for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { 344 | int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); 345 | 346 | // Mask, and get row-wise max 347 | float cur_max = MAX_INIT_VAL; 348 | 
CUTLASS_PRAGMA_UNROLL 349 | for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { 350 | if constexpr (DO_OOB_FILLING) { 351 | int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; 352 | rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL; 353 | rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL; 354 | } 355 | cur_max = max(cur_max, max(rP0(i), rP0(i+1))); 356 | } 357 | cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); 358 | cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); 359 | 360 | // Update sM and sL 361 | cur_max *= scale_softmax_log2; 362 | float new_max = max(sM(row_idx), cur_max); 363 | float scale_for_old = exp2f(sM(row_idx) - new_max); 364 | __syncwarp(); // Make sure all reads have finished before updating sM 365 | if (idx_in_warpgroup%4 == 0) { 366 | sScale0(row_idx) = scale_for_old; 367 | sM(row_idx) = new_max; 368 | } 369 | 370 | // Scale-O 371 | CUTLASS_PRAGMA_UNROLL 372 | for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { 373 | rO0(i) *= scale_for_old; 374 | rO0(i+1) *= scale_for_old; 375 | } 376 | 377 | // Scale, exp, and get row-wise expsum 378 | float cur_sum = 0; 379 | CUTLASS_PRAGMA_UNROLL 380 | for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) { 381 | rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max); 382 | rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max); 383 | rPb(i) = (typename T::InputT)rP0(i); 384 | rPb(i+1) = (typename T::InputT)rP0(i+1); 385 | cur_sum += rP0(i) + rP0(i+1); 386 | } 387 | rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; 388 | } 389 | } 390 | 391 | 392 | template< 393 | typename T, 394 | bool IS_BLK0_LAST, 395 | bool IS_BLK1_LAST, 396 | bool IS_BLK2_LAST, 397 | typename Engine0, typename Layout0, 398 | typename Engine1, typename Layout1, 399 | typename Engine2, typename Layout2, 400 | typename Engine3, typename Layout3, 401 | typename Engine4, typename Layout4, 402 | typename Engine5, typename Layout5 403 | > 404 | __forceinline__ __device__ void wg1_bunch_0( 405 | Tensor &rP1b, // ((2, 2, 8), 1, 1) 406 | Tensor &sScale1, // (BLOCK_SIZE_M) 407 | Tensor &rO1, // ((2, 2, 32), 1, 1) 408 | Tensor &sM, // (BLOCK_SIZE_M) 409 | float rL[2], 410 | int rRightBorderForQSeq[2], 411 | Tensor const &sScale0, // (BLOCK_SIZE_M) 412 | Tensor &rP1, // ((2, 2, 8), 1, 1) 413 | float scale_softmax_log2, 414 | int start_token_idx, 415 | int idx_in_warpgroup 416 | ) { 417 | CUTLASS_PRAGMA_UNROLL 418 | for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { 419 | int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); 420 | 421 | // Mask, and get row-wise max 422 | float cur_max = MAX_INIT_VAL; 423 | CUTLASS_PRAGMA_UNROLL 424 | for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { 425 | if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) { 426 | // Need to apply the mask when either this block is the last one, or 427 | // the next block is the last one (because of the causal mask) 428 | int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; 429 | rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL; 430 | rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? 
rP1(i+1) : MAX_INIT_VAL; 431 | } else if constexpr (IS_BLK0_LAST) { 432 | rP1(i) = rP1(i+1) = MAX_INIT_VAL; 433 | } 434 | cur_max = max(cur_max, max(rP1(i), rP1(i+1))); 435 | } 436 | cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); 437 | cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); 438 | cur_max *= scale_softmax_log2; 439 | 440 | float old_max = sM(row_idx); 441 | float new_max = max(old_max, cur_max); 442 | float scale_for_old = exp2f(old_max - new_max); 443 | __syncwarp(); 444 | if (idx_in_warpgroup%4 == 0) { 445 | sM(row_idx) = new_max; 446 | sScale1(row_idx) = scale_for_old; 447 | } 448 | 449 | // Scale, exp, and get row-wise expsum 450 | float cur_sum = 0; 451 | if constexpr (!IS_BLK0_LAST) { 452 | CUTLASS_PRAGMA_UNROLL 453 | for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { 454 | rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max); 455 | rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max); 456 | rP1b(i) = (typename T::InputT)rP1(i); 457 | rP1b(i+1) = (typename T::InputT)rP1(i+1); 458 | cur_sum += rP1(i) + rP1(i+1); 459 | } 460 | } 461 | 462 | // Scale O 463 | float cur_scale_for_o1 = scale_for_old * sScale0(row_idx); 464 | CUTLASS_PRAGMA_UNROLL 465 | for (int i = local_row_idx ? 
2 : 0; i < size(rO1); i += 4) { 466 | rO1(i) *= cur_scale_for_o1; 467 | rO1(i+1) *= cur_scale_for_o1; 468 | } 469 | 470 | // Update rL 471 | rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum; 472 | } 473 | } 474 | 475 | 476 | // Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction 477 | template< 478 | typename T, 479 | typename Engine0, typename Layout0, 480 | typename Engine1, typename Layout1 481 | > 482 | __forceinline__ __device__ void save_rPb_to_sP( 483 | Tensor &rPb, 484 | Tensor &sP, 485 | int idx_in_warpgroup 486 | ) { 487 | auto r2s_copy = make_tiled_copy_C( 488 | Copy_Atom{}, 489 | (typename T::TiledMMA_QK_sQ){} 490 | ); 491 | ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); 492 | Tensor thr_copy_rPb = thr_copy.retile_S(rPb); 493 | Tensor thr_copy_sP = thr_copy.partition_D(sP); 494 | cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); 495 | } 496 | 497 | 498 | // Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction 499 | template< 500 | typename T, 501 | typename Engine0, typename Layout0, 502 | typename Engine1, typename Layout1 503 | > 504 | __forceinline__ __device__ void retrieve_rP_from_sP( 505 | Tensor &rPb, 506 | Tensor const &sP, 507 | int idx_in_warpgroup 508 | ) { 509 | TiledCopy s2r_copy = make_tiled_copy_A( 510 | Copy_Atom{}, 511 | (typename T::TiledMMA_PV_LocalP){} 512 | ); 513 | ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); 514 | Tensor thr_copy_sP = thr_copy.partition_S(sP); 515 | Tensor thr_copy_rPb = thr_copy.retile_D(rPb); 516 | cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); 517 | } 518 | 519 | 520 | // Rescale rP0 and save the result to rPb 521 | template< 522 | typename T, 523 | typename Engine0, typename Layout0, 524 | typename Engine1, typename Layout1, 525 | typename Engine2, typename Layout2 526 | > 527 | __forceinline__ __device__ void wg0_scale_rP0( 528 | Tensor const &sScale1, // (BLOCK_M) 529 | Tensor const &rP0, // ((2, 2, 8), 1, 1) 530 | Tensor &rPb, // 
((2, 2, 8), 1, 1) 531 | int idx_in_warpgroup 532 | ) { 533 | CUTLASS_PRAGMA_UNROLL 534 | for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { 535 | int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); 536 | float scale_factor = sScale1(row_idx); 537 | CUTLASS_PRAGMA_UNROLL 538 | for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { 539 | rPb(i) = (typename T::InputT)(rP0(i)*scale_factor); 540 | rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor); 541 | } 542 | } 543 | } 544 | 545 | 546 | // Rescale rO0 according to sScale1 547 | template< 548 | typename Engine0, typename Layout0, 549 | typename Engine1, typename Layout1 550 | > 551 | __forceinline__ __device__ void wg0_rescale_rO0( 552 | Tensor &rO0, 553 | Tensor &sScale1, 554 | float rL[2], 555 | int idx_in_warpgroup 556 | ) { 557 | CUTLASS_PRAGMA_UNROLL 558 | for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { 559 | int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); 560 | float scale_factor = sScale1(row_idx); 561 | CUTLASS_PRAGMA_UNROLL 562 | for (int i = local_row_idx ? 
2 : 0; i < size(rO0); i += 4) { 563 | rO0(i) *= scale_factor; 564 | rO0(i+1) *= scale_factor; 565 | } 566 | rL[local_row_idx] *= scale_factor; 567 | } 568 | } 569 | 570 | 571 | // Fill out-of-bound V with 0.0 572 | // We must fill it since it may contain NaN, which may propagate to the final result 573 | template< 574 | typename T, 575 | typename Engine0, typename Layout0 576 | > 577 | __forceinline__ __device__ void fill_oob_V( 578 | Tensor<Engine0, Layout0> &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, LayoutRight{} ); 579 | int valid_window_size, 580 | int idx_in_warpgroup 581 | ) { 582 | Tensor sV_int64 = make_tensor( 583 | make_smem_ptr((int64_t*)(sV.data().get().get())), 584 | tile_to_shape( 585 | GMMA::Layout_MN_SW128_Atom{}, 586 | Shape, Int>{}, 587 | LayoutRight{} 588 | ) 589 | ); 590 | valid_window_size = max(valid_window_size, 0); 591 | int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should hold 592 | for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) { 593 | sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0; 594 | } 595 | } 596 | 597 | 598 | // Store O / OAccum 599 | template< 600 | typename T, 601 | bool IS_NO_SPLIT, 602 | typename TMAParams, 603 | typename Engine0, typename Layout0, 604 | typename Engine1, typename Layout1 605 | > 606 | __forceinline__ __device__ void store_o( 607 | Tensor<Engine0, Layout0> &rO, // ((2, 2, 32), 1, 1) 608 | Tensor<Engine1, Layout1> &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) 609 | float rL[2], 610 | char* sO_addr, 611 | TMAParams &tma_params, 612 | int batch_idx, 613 | int k_head_idx, 614 | int m_block_idx, 615 | int num_valid_seq_q, 616 | int warpgroup_idx, 617 | int idx_in_warpgroup 618 | ) { 619 | using InputT = typename T::InputT; 620 | if constexpr (IS_NO_SPLIT) { 621 | // Should convert the output to bfloat16 / float16, and save it to O 622 | Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape( 623 | 
GMMA::Layout_K_SW128_Atom{}, 624 | Shape, Int>{} 625 | )); 626 | 627 | Tensor rOb = make_tensor_like(rO); 628 | CUTLASS_PRAGMA_UNROLL 629 | for (int idx = 0; idx < size(rO); ++idx) { 630 | rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]); 631 | } 632 | 633 | Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); 634 | TiledCopy r2s_tiled_copy = make_tiled_copy_C( 635 | Copy_Atom{}, 636 | (typename T::TiledMMA_PV_LocalP){} 637 | ); 638 | ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); 639 | Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); 640 | Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); 641 | cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); 642 | cutlass::arch::fence_view_async_shared(); 643 | 644 | __syncthreads(); 645 | 646 | if (threadIdx.x == 0) { 647 | Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) 648 | auto thr_tma = tma_params.tma_O.get_slice(_0{}); 649 | Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, m_block_idx, _0{}); 650 | cute::copy( 651 | tma_params.tma_O, 652 | thr_tma.partition_S(sOutputBuf), 653 | thr_tma.partition_D(my_tma_gO) 654 | ); 655 | cute::tma_store_arrive(); 656 | } 657 | } else { 658 | // Should save the result to OAccum 659 | Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout< 660 | Shape<_64, _512>, 661 | Stride, _1> // We use stride = 520 here to avoid bank conflict 662 | >{}); 663 | 664 | CUTLASS_PRAGMA_UNROLL 665 | for (int idx = 0; idx < size(rO); idx += 2) { 666 | int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 
8 : 0); 667 | int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; 668 | *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 { 669 | rO(idx) / rL[idx%4 >= 2], 670 | rO(idx+1) / rL[idx%4 >= 2], 671 | }; 672 | } 673 | cutlass::arch::fence_view_async_shared(); 674 | 675 | __syncthreads(); 676 | 677 | int row = threadIdx.x; 678 | if (row < num_valid_seq_q) { 679 | SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float)); 680 | cute::tma_store_arrive(); 681 | } 682 | } 683 | } 684 | 685 | template< 686 | typename T, 687 | typename TmaParams, typename Tensor0 688 | > 689 | __forceinline__ __device__ void launch_q_copy( 690 | TmaParams const &tma_params, 691 | int batch_idx, 692 | int m_block_idx, 693 | int k_head_idx, 694 | Tensor0 &sQ, 695 | TMABarrier* barrier_Q 696 | ) { 697 | if (threadIdx.x == 0) { 698 | Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) 699 | auto thr_tma = tma_params.tma_Q.get_slice(_0{}); 700 | Tensor my_tma_gQ = flat_divide(tma_gQ, Shape, Int>{})(_, _, m_block_idx, _0{}); 701 | cute::copy( 702 | tma_params.tma_Q.with(reinterpret_cast(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), 703 | thr_tma.partition_S(my_tma_gQ), 704 | thr_tma.partition_D(sQ) 705 | ); 706 | barrier_Q->arrive_and_expect_tx(64*576*2); 707 | } 708 | } 709 | 710 | template< 711 | typename T, 712 | bool IS_R, 713 | typename Engine0, typename Layout0 714 | > 715 | __forceinline__ __device__ auto get_half_V( 716 | Tensor &sK 717 | ) { 718 | Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){}); 719 | return flat_divide(sV, Shape, Int>{})(_, _, Int<(int)IS_R>{}, _0{}); 720 | } 721 | 722 | template< 723 | typename T, 724 | bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ... 
725 | bool IS_BLK1_LAST, 726 | typename TMAParams, 727 | typename Engine0, typename Layout0, 728 | typename Engine1, typename Layout1, 729 | typename Engine2, typename Layout2, 730 | typename Engine3, typename Layout3, 731 | typename Engine4, typename Layout4, 732 | typename Engine5, typename Layout5, 733 | typename Engine6, typename Layout6, 734 | typename Engine7, typename Layout7, 735 | typename Engine8, typename Layout8, 736 | typename Engine9, typename Layout9, 737 | typename Engine10, typename Layout10, 738 | typename Engine11, typename Layout11 739 | > 740 | __forceinline__ __device__ void wg0_subroutine( 741 | Tensor<Engine0, Layout0> &tma_gK, 742 | Tensor<Engine1, Layout1> &sQ, 743 | Tensor<Engine2, Layout2> &sK0, 744 | Tensor<Engine3, Layout3> &sK1, 745 | Tensor<Engine4, Layout4> &sP0, 746 | Tensor<Engine5, Layout5> &sP1, 747 | Tensor<Engine6, Layout6> &sM, 748 | Tensor<Engine7, Layout7> &sScale0, 749 | Tensor<Engine8, Layout8> &sScale1, 750 | Tensor<Engine9, Layout9> &rQ8, 751 | Tensor<Engine10, Layout10> &rP0, 752 | Tensor<Engine11, Layout11> &rO0, 753 | float rL[2], 754 | int rRightBorderForQSeq[2], 755 | TMABarrier barriers_K0[9], 756 | TMABarrier barriers_K1[9], 757 | bool &cur_phase_K0, 758 | const TMAParams &tma_params, 759 | const Flash_fwd_mla_params &params, 760 | int* block_table_ptr, 761 | int seqlen_k, 762 | int block_idx, 763 | int end_block_idx, 764 | int idx_in_warpgroup 765 | ) { 766 | int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; 767 | #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 
0 : __ldg(block_table_ptr + (block_idx))) 768 | int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); 769 | int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); 770 | 771 | Tensor sV0L = get_half_V<T, 0>(sK0); 772 | Tensor sV1L = get_half_V<T, 0>(sK1); 773 | 774 | Tensor rPb = make_tensor(Shape, _1, _4>{}); 775 | // Calc P0 = softmax(P0) 776 | wg0_bunch_0(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup); 777 | NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready); 778 | 779 | // Issue rO0 += rPb @ sV0L 780 | if constexpr (IS_BLK0_LAST) { 781 | fill_oob_V<T>(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup); 782 | cutlass::arch::fence_view_async_shared(); 783 | } 784 | warpgroup_cooperative_pv_gemm_localP<T>(rPb, sV0L, rO0, idx_in_warpgroup); 785 | 786 | // Wait for rO0, launch TMA for the next V0L 787 | cute::warpgroup_wait<0>(); 788 | 789 | // Wait for warpgroup 1, rescale P0, notify warpgroup 1 790 | NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready); 791 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { 792 | // Launching the TMA copy here seems to be faster; the reason is unknown 793 | launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); 794 | } 795 | wg0_scale_rP0<T>(sScale1, rP0, rPb, idx_in_warpgroup); 796 | save_rPb_to_sP<T>(rPb, sP0, idx_in_warpgroup); 797 | cutlass::arch::fence_view_async_shared(); 798 | NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready); 799 | 800 | // Wait for warpgroup 1, rescale O0, issue rO0 += sP1 @ sV1L 801 | if constexpr (!IS_BLK0_LAST) { 802 | if constexpr (IS_BLK1_LAST) { 803 | fill_oob_V<T>(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); 804 | cutlass::arch::fence_view_async_shared(); 805 | } 806 | NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); 807 | wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup); 808 | 
warpgroup_cooperative_pv_gemm_remoteP(sP1, sV1L, rO0, idx_in_warpgroup); 809 | } 810 | 811 | // Issue P0 = Q @ K0^T 812 | // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline. 813 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { 814 | warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); 815 | } 816 | 817 | // Wait for rO0 += rPb @ sV1L, launch TMA 818 | if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) { 819 | cute::warpgroup_wait<4>(); 820 | launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); 821 | } 822 | 823 | // Issue P0 = Q @ K0^T 824 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { 825 | warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); 826 | } 827 | 828 | // Wait for P0 = Q @ K0^T 829 | cute::warpgroup_wait<0>(); 830 | } 831 | 832 | 833 | template< 834 | typename T, 835 | bool IS_BLK0_LAST, 836 | bool IS_BLK1_LAST, 837 | bool IS_BLK2_LAST, 838 | typename TMAParams, 839 | typename Engine0, typename Layout0, 840 | typename Engine1, typename Layout1, 841 | typename Engine2, typename Layout2, 842 | typename Engine3, typename Layout3, 843 | typename Engine4, typename Layout4, 844 | typename Engine5, typename Layout5, 845 | typename Engine6, typename Layout6, 846 | typename Engine7, typename Layout7, 847 | typename Engine8, typename Layout8, 848 | typename Engine9, typename Layout9, 849 | typename Engine10, typename Layout10, 850 | typename Engine11, typename Layout11 851 | > 852 | __forceinline__ __device__ void wg1_subroutine( 853 | Tensor &tma_gK, 854 | Tensor &sQ, 855 | Tensor &sK0, 856 | Tensor &sK1, 857 | Tensor &sP0, 858 | Tensor &sP1, 859 | Tensor &sM, 860 | Tensor &sScale0, 861 | Tensor &sScale1, 862 | Tensor &rQ8, 863 | Tensor &rP1, 864 
| Tensor<Engine11, Layout11> &rO1, 865 | float rL[2], 866 | int rRightBorderForQSeq[2], 867 | TMABarrier barriers_K0[9], 868 | TMABarrier barriers_K1[9], 869 | bool &cur_phase_K1, 870 | const TMAParams &tma_params, 871 | const Flash_fwd_mla_params &params, 872 | int* block_table_ptr, 873 | int seqlen_k, 874 | int block_idx, 875 | int end_block_idx, 876 | int idx_in_warpgroup 877 | ) { 878 | int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; 879 | int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); 880 | int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); 881 | 882 | Tensor rP1b = make_tensor(Shape, _1, _4>{}); 883 | 884 | Tensor sV0R = get_half_V<T, 1>(sK0); 885 | Tensor sV1R = get_half_V<T, 1>(sK1); 886 | 887 | // Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0 888 | NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready); 889 | wg1_bunch_0<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup); 890 | NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready); 891 | 892 | // Save rP1b to sP1, and issue rO1 += rP1b @ sV1R 893 | // We do this after notifying warpgroup 0, since both saving rP1b to sP1 and issuing the WGMMA are high-latency operations 894 | if constexpr (!IS_BLK0_LAST) { 895 | save_rPb_to_sP<T>(rP1b, sP1, idx_in_warpgroup); 896 | } 897 | if constexpr (!IS_BLK0_LAST) { 898 | if constexpr (IS_BLK1_LAST) { 899 | fill_oob_V<T>(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); 900 | cutlass::arch::fence_view_async_shared(); 901 | } 902 | warpgroup_cooperative_pv_gemm_localP<T>(rP1b, sV1R, rO1, idx_in_warpgroup); 903 | if constexpr (!IS_BLK1_LAST) { 904 | // We use this fence to make sP1 visible to the async proxy 905 | // We skip it if IS_BLK1_LAST, since in that case we have already issued a fence above 906 | cutlass::arch::fence_view_async_shared(); 907 | } 908 | } 909 | 910 | // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0 911 | 
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready); 912 | if constexpr (IS_BLK0_LAST) { 913 | fill_oob_V(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup); 914 | cutlass::arch::fence_view_async_shared(); 915 | } 916 | warpgroup_cooperative_pv_gemm_remoteP(sP0, sV0R, rO1, idx_in_warpgroup); 917 | if constexpr (!IS_BLK0_LAST) { 918 | NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); 919 | } 920 | 921 | // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R 922 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { 923 | cute::warpgroup_wait<1>(); 924 | launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); 925 | } 926 | 927 | // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R 928 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { 929 | cute::warpgroup_wait<0>(); 930 | launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); 931 | } 932 | 933 | if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { 934 | // Issue rP1 = sQ @ sK1, wait 935 | warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); 936 | } 937 | 938 | // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise 939 | // nvcc cannot correctly analyse the loop, and will think that we are using accumulator 940 | // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized. 
941 | // This is also the reason why we put QK^T here, instead of making it the first operation in the loop 942 | cute::warpgroup_wait<0>(); 943 | } 944 | 945 | // A helper function for determining the length of the causal mask for one q token 946 | __forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params &params, int m_block_idx, int local_seq_q_idx) { 947 | int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; 948 | if (global_seq_q_idx < params.q_seq_per_hk) { 949 | int s_q_idx = global_seq_q_idx / params.q_head_per_hk; 950 | return params.s_q - s_q_idx - 1; 951 | } else { 952 | // Out-of-bound row: treat it as having no mask 953 | return 0; 954 | } 955 | } 956 | 957 | template<typename T, typename TmaParams> 958 | __global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) 959 | flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { 960 | // grid shape: [ 961 | // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads), BLOCK_SIZE_M)), 962 | // num_kv_heads, 963 | // num_sm_parts 964 | // ] 965 | // An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx)) of one request (as specified by batch_idx). 966 | // If is_no_split is true, then this request is exclusively assigned to this sm_part, so we write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, we write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]). 967 | // For the complete schedule of the kernel, please read our deep-dive write-up (the link can be found in the README.md file). 
968 | 969 | const int m_block_idx = blockIdx.x; 970 | const int k_head_idx = blockIdx.y; 971 | const int partition_idx = blockIdx.z; 972 | const int warpgroup_idx = threadIdx.x / 128; 973 | const int idx_in_warpgroup = threadIdx.x % 128; 974 | 975 | // Define shared tensors 976 | extern __shared__ char wksp_buf[]; 977 | using SharedMemoryPlan = typename T::SharedMemoryPlan; 978 | SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf); 979 | Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){}); 980 | Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){}); 981 | Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){}); 982 | Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){}); 983 | Tensor sP1 = flat_divide(sQ, Shape, Int>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile 984 | Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int<T::BLOCK_SIZE_M>{})); 985 | Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{})); 986 | Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int<T::BLOCK_SIZE_M>{})); 987 | Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int<T::BLOCK_SIZE_M>{})); 988 | char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1 989 | 990 | // Prefetch TMA descriptors 991 | if (threadIdx.x == 0) { 992 | cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); 993 | cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor()); 994 | cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); 995 | } 996 | 997 | // Define TMA-related tensors and barriers 998 | Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _); 999 | TMABarrier* barriers_K0 = plan.barriers_K0; 1000 | TMABarrier* barriers_K1 = plan.barriers_K1; 1001 | TMABarrier* barrier_Q = 
&(plan.barrier_Q); 1002 | 1003 | // Initialize TMA barriers 1004 | if (threadIdx.x == 0) { 1005 | barrier_Q->init(1); 1006 | CUTLASS_PRAGMA_UNROLL 1007 | for (int i = 0; i < 9; ++i) { 1008 | barriers_K0[i].init(1); 1009 | barriers_K1[i].init(1); 1010 | } 1011 | cutlass::arch::fence_view_async_shared(); 1012 | } 1013 | __syncthreads(); 1014 | bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; 1015 | 1016 | // Programmatic Dependent Launch: Wait for the previous kernel to finish 1017 | cudaGridDependencySynchronize(); 1018 | 1019 | int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; 1020 | // We don't use __ldg here; otherwise NVCC (ptxas, in particular) will reorder instructions and place the __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to a data race. 1021 | int4 tile_scheduler_metadata = *(reinterpret_cast<int4*>(tile_scheduler_metadata_ptr)); 1022 | int begin_idx = tile_scheduler_metadata.x; 1023 | int begin_seqlen = tile_scheduler_metadata.y; 1024 | int end_idx = tile_scheduler_metadata.z; 1025 | int end_seqlen = tile_scheduler_metadata.w; 1026 | if (begin_idx >= params.b) return; 1027 | int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); 1028 | 1029 | // Copy the first Q 1030 | launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); 1031 | 1032 | #pragma unroll 1 1033 | for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { 1034 | constexpr int kBlockN = T::PAGE_BLOCK_SIZE; 1035 | const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; 1036 | int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); 1037 | const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; 1038 | int end_block_idx = batch_idx == end_idx ? 
cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); 1039 | const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); 1040 | 1041 | int rRightBorderForQSeq[2]; 1042 | if (params.is_causal) { 1043 | // The causal mask looks like: 1044 | // XXXX 1045 | // XXXX 1046 | // ... 1047 | // XXXX 1048 | // XXX 1049 | // XXX 1050 | // ... 1051 | // XXX 1052 | // XX 1053 | // XX 1054 | // ... 1055 | // XX 1056 | // First, there is a common_mask_len, which is the minimum length of the causal masks among all tokens. Since the length of the causal mask decreases monotonically, common_mask_len is the length of the causal mask for the last token. We treat common_mask_len as a reduction in the length of the k-sequence and adjust end_block_idx accordingly, which saves some computation. 1057 | // Besides, a token may have extra masked positions beyond the common mask. We use rRightBorderForQSeq to record the right border of the k-sequence for each particular q token. As a result, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks. 1058 | // NOTE: This may lead to start_block_idx >= end_block_idx, which needs special handling 1059 | int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); 1060 | end_block_idx = batch_idx == end_idx ? 
cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); 1061 | 1062 | CUTLASS_PRAGMA_UNROLL 1063 | for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { 1064 | int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); 1065 | rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE); 1066 | } 1067 | } else { 1068 | rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k; 1069 | } 1070 | 1071 | // Define global tensors 1072 | using InputT = typename T::InputT; 1073 | InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1) 1074 | float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) 1075 | int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1) 1076 | 1077 | Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( 1078 | Shape, Int>{}, 1079 | make_stride(params.o_row_stride, _1{}) 1080 | )); 1081 | Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout< 1082 | Shape>, 1083 | Stride<_1> 1084 | >{}); 1085 | 1086 | // Copy K0 and K1 1087 | launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x); 1088 | if (start_block_idx+1 < end_block_idx) { 1089 | launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); 1090 | launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); 1091 | } 1092 | 1093 | Tensor rO = partition_fragment_C((typename 
T::TiledMMA_PV_LocalP){}, Shape, Int>{}); // ((2, 2, 32), 1, 1) 1094 | float rL[2]; 1095 | rL[0] = rL[1] = 0.0f; 1096 | 1097 | // Clear buffers 1098 | cute::fill(rO, 0.); 1099 | if (threadIdx.x < size(sM)) { 1100 | sM[threadIdx.x] = MAX_INIT_VAL_SM; 1101 | } 1102 | 1103 | // Wait for Q 1104 | barrier_Q->wait(cur_phase_Q); 1105 | cur_phase_Q ^= 1; 1106 | 1107 | Tensor rQ8 = make_tensor(Shape, _1, _4>{}); 1108 | retrieve_rP_from_sP(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup); 1109 | 1110 | if (warpgroup_idx == 0) { 1111 | // Warpgroup 0 1112 | Tensor rP0 = make_tensor((typename T::rP0Layout){}); 1113 | 1114 | // NOTE We don't use the pipelined version of Q K^T here since it leads 1115 | // to a slow-down (or even register spilling, thanks to the great NVCC) 1116 | // Wait for K0 1117 | CUTLASS_PRAGMA_UNROLL 1118 | for (int i = 0; i < 9; ++i) { 1119 | if (idx_in_warpgroup == 0) 1120 | barriers_K0[i].arrive_and_expect_tx(64*64*2); 1121 | barriers_K0[i].wait(cur_phase_K0); 1122 | } 1123 | cur_phase_K0 ^= 1; 1124 | 1125 | // Issue P0 = Q @ K0^T, wait 1126 | warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); 1127 | // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 1128 | NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); 1129 | cute::warpgroup_wait<0>(); 1130 | 1131 | #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ 1132 | wg0_subroutine( \ 1133 | tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ 1134 | rQ8, rP0, rO, rL, rRightBorderForQSeq, \ 1135 | barriers_K0, barriers_K1, cur_phase_K0, \ 1136 | tma_params, params, \ 1137 | block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ 1138 | ); 1139 | 1140 | int block_idx = start_block_idx; 1141 | #pragma unroll 1 1142 | for (; block_idx < end_block_idx-2; block_idx += 2) { 1143 | LAUNCH_WG0_SUBROUTINE(false, false); 1144 | } 1145 | 1146 | if (block_idx+1 < end_block_idx) { 1147 | 
LAUNCH_WG0_SUBROUTINE(false, true); 1148 | } else if (block_idx < end_block_idx) { 1149 | LAUNCH_WG0_SUBROUTINE(true, false); 1150 | } 1151 | 1152 | } else { 1153 | // Warpgroup 1 1154 | Tensor rP1 = make_tensor((typename T::rP0Layout){}); 1155 | 1156 | if (start_block_idx+1 < end_block_idx) { 1157 | // Issue rP1 = sQ @ sK1, wait 1158 | warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); 1159 | cute::warpgroup_wait<0>(); 1160 | } 1161 | 1162 | #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \ 1163 | wg1_subroutine( \ 1164 | tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ 1165 | rQ8, rP1, rO, rL, rRightBorderForQSeq, \ 1166 | barriers_K0, barriers_K1, cur_phase_K1, \ 1167 | tma_params, params, \ 1168 | block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ 1169 | ); 1170 | 1171 | int block_idx = start_block_idx; 1172 | #pragma unroll 1 1173 | for (; block_idx < end_block_idx-3; block_idx += 2) { 1174 | LAUNCH_WG1_SUBROUTINE(false, false, false); 1175 | } 1176 | 1177 | if (block_idx+2 < end_block_idx) { 1178 | LAUNCH_WG1_SUBROUTINE(false, false, true); 1179 | block_idx += 2; 1180 | LAUNCH_WG1_SUBROUTINE(true, false, false); 1181 | } else if (block_idx+1 < end_block_idx) { 1182 | LAUNCH_WG1_SUBROUTINE(false, true, false); 1183 | } else if (block_idx < end_block_idx) { 1184 | LAUNCH_WG1_SUBROUTINE(true, false, false); 1185 | } 1186 | } 1187 | 1188 | // Reduce rL across threads within the same warp 1189 | rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); 1190 | rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); 1191 | rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); 1192 | rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); 1193 | 1194 | // Reduce rL across warpgroups 1195 | int my_row = get_AorC_row_idx(0, idx_in_warpgroup); 1196 | if (idx_in_warpgroup%4 == 0) { 1197 | sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0]; 1198 | sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1]; 
1199 | } 1200 | __syncthreads(); 1201 | if (warpgroup_idx == 0) { 1202 | rL[0] += sL_reduction_wksp[my_row + 64]; 1203 | rL[1] += sL_reduction_wksp[my_row + 8 + 64]; 1204 | } else { 1205 | if (idx_in_warpgroup%4 == 0) { 1206 | sL_reduction_wksp[my_row] += rL[0]; 1207 | sL_reduction_wksp[my_row + 8] += rL[1]; 1208 | } 1209 | __syncwarp(); 1210 | rL[0] = sL_reduction_wksp[my_row]; 1211 | rL[1] = sL_reduction_wksp[my_row+8]; 1212 | } 1213 | 1214 | // Prune out when rL is 0.0f or NaN 1215 | // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads 1216 | // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error. 1217 | // When this happens, we set rL to 1.0f. This aligns with the old version 1218 | // of the MLA kernel. 1219 | CUTLASS_PRAGMA_UNROLL 1220 | for (int i = 0; i < 2; ++i) 1221 | rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i]; 1222 | 1223 | // Copy Q for the next batch 1224 | if (batch_idx+1 <= end_idx) { 1225 | launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); 1226 | } else { 1227 | // Allow the next kernel (the combine kernel) to launch 1228 | // The next kernel MUST be the combine kernel 1229 | cudaTriggerProgrammaticLaunchCompletion(); 1230 | } 1231 | 1232 | int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M); 1233 | if (is_no_split) { 1234 | store_o(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); 1235 | 1236 | int i = threadIdx.x; 1237 | if (i < num_valid_seq_q) { 1238 | float cur_L = sL_reduction_wksp[i]; 1239 | gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E; 1240 | } 1241 | 1242 | cute::tma_store_wait<0>(); 1243 | } else { 1244 | // Don't use __ldg because of PDL and instruction reordering 1245 | int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; 1246 | float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) 1247 | float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) 1248 | Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< 1249 | Shape, Int>, 1250 | Stride, _1> 1251 | >{}); 1252 | Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout< 1253 | Shape>, 1254 | Stride<_1> 1255 | >{}); 1256 | store_o(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); 1257 | 1258 | int i = threadIdx.x; 1259 | if (i < num_valid_seq_q) { 1260 | float cur_L = sL_reduction_wksp[i]; 1261 | gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
-INFINITY : log2f(cur_L) + sM(i); 1262 | } 1263 | 1264 | cute::tma_store_wait<0>(); 1265 | } 1266 | 1267 | if (batch_idx != end_idx) 1268 | __syncthreads(); 1269 | } 1270 | } 1271 | 1272 | 1273 | template 1274 | void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) { 1275 | using T = Traits; 1276 | auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); 1277 | auto tma_Q = cute::make_tma_copy( 1278 | SM90_TMA_LOAD{}, 1279 | make_tensor( 1280 | make_gmem_ptr((InputT*)params.q_ptr), 1281 | make_layout( 1282 | shape_Q, 1283 | make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride) 1284 | ) 1285 | ), 1286 | tile_to_shape( 1287 | GMMA::Layout_K_SW128_Atom{}, 1288 | Shape, Int>{} 1289 | ) 1290 | ); 1291 | auto shape_K = make_shape(Int{}, Int{}, params.h_k, params.num_blocks); 1292 | auto tma_K = cute::make_tma_copy( 1293 | SM90_TMA_LOAD{}, 1294 | make_tensor( 1295 | make_gmem_ptr((InputT*)params.k_ptr), 1296 | make_layout( 1297 | shape_K, 1298 | make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride) 1299 | ) 1300 | ), 1301 | tile_to_shape( 1302 | GMMA::Layout_K_SW128_Atom{}, 1303 | Layout< 1304 | Shape, Int<64>>, 1305 | Stride, _1> 1306 | >{} 1307 | ) 1308 | ); 1309 | auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b); 1310 | auto tma_O = cute::make_tma_copy( 1311 | SM90_TMA_STORE{}, 1312 | make_tensor( 1313 | make_gmem_ptr((InputT*)params.o_ptr), 1314 | make_layout( 1315 | shape_O, 1316 | make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride) 1317 | ) 1318 | ), 1319 | tile_to_shape( 1320 | GMMA::Layout_K_SW128_Atom{}, 1321 | Shape, Int>{} 1322 | ) 1323 | ); 1324 | TmaParams tma_params = { 1325 | shape_Q, tma_Q, 1326 | shape_K, tma_K, 1327 | shape_O, tma_O 1328 | }; 1329 | auto mla_kernel = &flash_fwd_splitkv_mla_kernel; 1330 | constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); 1331 | 
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 1332 | 1333 | // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) 1334 | const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); 1335 | cudaLaunchAttribute mla_kernel_attributes[1]; 1336 | mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; 1337 | mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; 1338 | cudaLaunchConfig_t mla_kernel_config = { 1339 | dim3(num_m_block, params.h_k, params.num_sm_parts), 1340 | dim3(T::NUM_THREADS, 1, 1), 1341 | smem_size, 1342 | stream, 1343 | mla_kernel_attributes, 1344 | 1 1345 | }; 1346 | cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); 1347 | CHECK_CUDA_KERNEL_LAUNCH(); 1348 | } 1349 | 1350 | template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 1351 | 1352 | #ifndef FLASH_MLA_DISABLE_FP16 1353 | template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 1354 | #endif 1355 | -------------------------------------------------------------------------------- /csrc/kernels/splitkv_mla.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "params.h" 4 | 5 | template 6 | void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); 7 | -------------------------------------------------------------------------------- /csrc/kernels/traits.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "config.h" 9 | 10 | using TMABarrier = cutlass::arch::ClusterTransactionBarrier; 11 | using namespace cute; 12 | 13 | template 14 | struct Traits { 15 | using InputT = InputT_; 16 | 17 | static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; 18 | static constexpr 
int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; 19 | static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; 20 | static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; 21 | 22 | static constexpr int NUM_THREADS = 256; 23 | 24 | static_assert(std::is_same_v || std::is_same_v); 25 | 26 | using TiledMMA_QK_sQ = decltype(make_tiled_mma( 27 | GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), 28 | Layout>{} 29 | )); 30 | 31 | using TiledMMA_QK_rQ = decltype(make_tiled_mma( 32 | GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), 33 | Layout>{} 34 | )); 35 | 36 | using TiledMMA_PV_LocalP = decltype(make_tiled_mma( 37 | GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), 38 | Layout>{} 39 | )); 40 | 41 | using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( 42 | GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), 43 | Layout>{} 44 | )); 45 | 46 | using SmemLayoutQ = decltype(tile_to_shape( 47 | GMMA::Layout_K_SW128_Atom{}, 48 | Shape, Int>{} 49 | )); 50 | 51 | using SmemLayoutK = decltype(tile_to_shape( 52 | GMMA::Layout_K_SW128_Atom{}, 53 | Shape, Int>{} 54 | )); 55 | 56 | using SmemLayoutV = decltype(composition( 57 | SmemLayoutK{}, 58 | make_layout(Shape, Int>{}, GenRowMajor{}) 59 | )); // A transposed version of SmemLayoutK 60 | 61 | using SmemLayoutP0 = decltype(tile_to_shape( 62 | GMMA::Layout_K_SW128_Atom{}, 63 | Shape, Int>{} 64 | )); 65 | 66 | using rP0Layout = decltype(layout(partition_fragment_C( 67 | TiledMMA_QK_sQ{}, 68 | Shape, Int>{} 69 | ))); 70 | 71 | struct SharedMemoryPlan { 72 | cute::array_aligned> smem_sQ; 73 | cute::array_aligned> smem_sK0; 74 | cute::array_aligned> smem_sK1; 75 | cute::array_aligned> smem_sP0; 76 | cute::array_aligned smem_sM; 77 | cute::array_aligned sL_reduction_wksp; 78 | cute::array_aligned smem_sScale0; 79 | cute::array_aligned smem_sScale1; 80 | TMABarrier barriers_K0[HEAD_DIM_K/64]; 81 | TMABarrier barriers_K1[HEAD_DIM_K/64]; 82 | TMABarrier barrier_Q; 83 | }; 84 
| 85 | }; 86 | 87 | template< 88 | typename ShapeQ, typename TMA_Q, 89 | typename ShapeK, typename TMA_K, 90 | typename ShapeO, typename TMA_O 91 | > 92 | struct TmaParams { 93 | ShapeQ shape_Q; 94 | TMA_Q tma_Q; 95 | ShapeK shape_K; 96 | TMA_K tma_K; 97 | ShapeO shape_O; 98 | TMA_O tma_O; 99 | }; 100 | 101 | enum NamedBarriers : int { 102 | sScale0Ready = 0, 103 | sScale1Ready = 1, 104 | sP0Ready = 2, 105 | rO1sP0sV0RIssued = 3, 106 | sMInitialized = 4, 107 | }; 108 | -------------------------------------------------------------------------------- /csrc/kernels/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define CHECK_CUDA(call) \ 4 | do { \ 5 | cudaError_t status_ = call; \ 6 | if (status_ != cudaSuccess) { \ 7 | fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ 8 | exit(1); \ 9 | } \ 10 | } while(0) 11 | 12 | #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) 13 | 14 | 15 | #define FLASH_ASSERT(cond) \ 16 | do { \ 17 | if (not (cond)) { \ 18 | fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ 19 | exit(1); \ 20 | } \ 21 | } while(0) 22 | 23 | 24 | #define FLASH_DEVICE_ASSERT(cond) \ 25 | do { \ 26 | if (not (cond)) { \ 27 | printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ 28 | asm("trap;"); \ 29 | } \ 30 | } while(0) 31 | 32 | #define println(fmt, ...) 
{ print(fmt, ##__VA_ARGS__); print("\n"); }

--------------------------------------------------------------------------------
/docs/20250422-new-kernel-deep-dive.md:
--------------------------------------------------------------------------------
# A Deep-Dive Into the New Flash MLA Kernel

In the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) of the Flash MLA kernel, we achieved impressive performance: 3000 GB/s in memory-intensive settings and 580 TFlops in compute-bound settings. Now, we're pushing these numbers even further, reaching up to 660 TFlops.

In this blog, we present a deep dive into the new kernel, explaining the optimizations and techniques behind this performance boost. We'll first explain why the MLA kernel is compute-bound despite being a decoding-stage attention kernel, then discuss our high-level kernel schedule design, and finally cover the technical details of the new kernel.

## A Theoretical Analysis of the MLA Algorithm

GPU kernels can be classified as either compute-bound (limited by the rate of floating-point operations, FLOPS) or memory-bound (limited by memory bandwidth). To identify the kernel's bottleneck, we calculate the ratio of its FLOPs to its memory access volume (FLOPs/byte) and compare it with the corresponding ratio for the GPU (peak FLOPS divided by peak memory bandwidth).

Assume the number of q heads is $h_q$, the number of q tokens per request is $s_q$ (should be 1 if MTP / speculative decoding is disabled), the number of kv tokens per request is $s_k\ (s_k \gg h_q s_q)$, and the head dimensions of K and V are $d_k$ and $d_v$ respectively. The number of FLOPs is roughly $2 (h_q s_q \cdot d_k \cdot s_k + h_q s_q \cdot s_k \cdot d_v) = 2 h_q s_q s_k (d_k+d_v)$, and the memory access volume (in bytes) is $\mathop{\text{sizeof}}(\text{bfloat16}) \times (h_q s_q d_k + s_k d_k + h_q s_q d_v) \approx 2s_k d_k$. Thus, the compute-memory ratio is $h_q s_q \cdot \frac{d_k+d_v}{d_k} \approx 2 h_q s_q$.

An NVIDIA H800 SXM5 GPU has a peak memory bandwidth of 3.35 TB/s and peak compute throughput of 990 TFlops. However, due to throttling (the clock drops to ~1600 MHz in our case), the practical peak drops to ~865 TFlops. Therefore, when $h_q s_q \ge \frac{1}{2} \cdot \frac{865}{3.35} \approx 128$, the kernel is compute-bound; otherwise, it's memory-bound.

According to [the overview of DeepSeek's Online Inference System](https://github.com/deepseek-ai/open-infra-index/blob/main/202502OpenSourceWeek/day_6_one_more_thing_deepseekV3R1_inference_system_overview.md), we don't use Tensor Parallel for decoding instances, meaning $h_q$ is 128 and the kernel is compute-bound. Thus, we need to optimize the kernel for compute-bound settings.

## High-Level Design of the New Kernel

To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule."

[FlashAttention-3's paper](https://arxiv.org/abs/2407.08608) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM.
This eliminates the possibility of having two output matrices and letting them use CUDA Core and Tensor Core in an interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation.

(You might pause here to ponder - perhaps you can find a better solution than ours!)

Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows:

0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$.
2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m - m\_new_0)`$. Update $`m \gets m\_new_0`$.
4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$.
5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m - m\_new_1)`$. Update $`m \gets m\_new_1`$.
7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$.
8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$.
11. [0] Update $\vec o_L \gets \vec o_L \cdot scale_1 + \vec p_1 V_{1L}$.
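
The eleven steps can be checked numerically with a small pure-Python sketch. This is only an illustration under stated assumptions, not the kernel's implementation: it processes one q head on the CPU, takes each scale factor as $\exp(\text{old max} - \text{new max})$ (the usual online-softmax rescaling), assumes an even number of KV blocks, and tracks the softmax denominator explicitly (the list above omits it; the variable `denom` is our addition). The helper names `dot`, `vec_mat`, `seesaw_attention`, and `reference_attention` are hypothetical.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def vec_mat(p, V, lo, hi):
    """Return p @ V[:, lo:hi] for a weight vector p and a row-major block V."""
    return [sum(p[j] * V[j][c] for j in range(len(p))) for c in range(lo, hi)]


def seesaw_attention(q, K_blocks, V_blocks, qk_scale):
    """Single-head attention following steps 0-11, two KV blocks per iteration."""
    d_v = len(V_blocks[0][0])
    half = d_v // 2
    m = -math.inf                     # step 0: running max
    o_L = [0.0] * half                # left half of the output (warpgroup 0)
    o_R = [0.0] * (d_v - half)        # right half of the output (warpgroup 1)
    denom = 0.0                       # softmax denominator (assumed bookkeeping)
    for b in range(0, len(K_blocks), 2):
        K0, K1 = K_blocks[b], K_blocks[b + 1]
        V0, V1 = V_blocks[b], V_blocks[b + 1]
        p0 = [dot(q, k) / qk_scale for k in K0]                      # step 1 [0]
        p1 = [dot(q, k) / qk_scale for k in K1]                      # step 2 [1]
        m_new0 = max(m, max(p0))                                     # step 3 [0]
        scale0 = math.exp(m - m_new0)
        m = m_new0
        p0 = [math.exp(x - m) for x in p0]                           # step 4 [0]
        upd = vec_mat(p0, V0, 0, half)                               # step 5 [0]
        o_L = [o * scale0 + u for o, u in zip(o_L, upd)]
        m_new1 = max(m, max(p1))                                     # step 6 [1]
        scale1 = math.exp(m - m_new1)
        m = m_new1
        p1 = [math.exp(x - m) for x in p1]                           # step 7 [1]
        upd = vec_mat(p1, V1, half, d_v)                             # step 8 [1]
        o_R = [o * scale0 * scale1 + u for o, u in zip(o_R, upd)]
        p0 = [x * scale1 for x in p0]                                # step 9 [0]
        upd = vec_mat(p0, V0, half, d_v)                             # step 10 [1]
        o_R = [o + u for o, u in zip(o_R, upd)]
        upd = vec_mat(p1, V1, 0, half)                               # step 11 [0]
        o_L = [o * scale1 + u for o, u in zip(o_L, upd)]
        # Both p0 and p1 are now expressed relative to the newest max.
        denom = denom * scale0 * scale1 + sum(p0) + sum(p1)
    return [x / denom for x in o_L + o_R]


def reference_attention(q, K_blocks, V_blocks, qk_scale):
    """Plain softmax attention over the concatenated blocks, for comparison."""
    K = [k for blk in K_blocks for k in blk]
    V = [v for blk in V_blocks for v in blk]
    s = [dot(q, k) / qk_scale for k in K]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    total = sum(w)
    return [sum(w[j] * V[j][c] for j in range(len(V))) / total
            for c in range(len(V[0]))]
```

Calling both functions on the same random inputs should give outputs that agree up to floating-point rounding. In the sketch `K_blocks` and `V_blocks` are passed separately for clarity, although in MLA they refer to the same cached data.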

Note: We assume one q head for simplicity, so $\vec q$ and $\vec o$ are vectors. Bracketed numbers indicate the warpgroup performing the operation. Assume $\vec o_L$ resides in warpgroup 0's registers and $\vec o_R$ resides in warpgroup 1's registers.

This schedule can be viewed as a "ping-pong" variant using one output matrix; we call it "seesaw" scheduling. It's mathematically equivalent to FlashAttention's online softmax algorithm. This schedule allows us to overlap CUDA Core and Tensor Core operations by interleaving the two warpgroups, and also allows us to overlap memory access with computation, since we can launch the corresponding Tensor Memory Accelerator (TMA) instructions as soon as the data in a buffer is no longer needed.

The complete schedule is shown below (remember that in MLA, $K$ and $V$ are the same data under different names):

![MLA Kernel Sched](assets/MLA%20Kernel%20Sched.drawio.svg)

## Discussion of Technical Details

This section covers technical details of the new kernel.

First, although the kernel targets compute-bound scenarios (where memory bandwidth isn't the bottleneck), we can't ignore memory latency. If the data is not ready when we want to use it, we have to wait. To solve this problem, we employ the following techniques:

- **Fine-grained TMA copy - GEMM pipelining:** For a $64 \times 576$ K block, we launch 9 TMA copies (each moving a $64 \times 64$ block). GEMM operations begin as soon as each TMA copy completes (when the first TMA copy is done, we can start the first GEMM operation, and so on), improving memory latency tolerance.
- **Cache hints:** Using `cute::TMA::CacheHintSm90::EVICT_FIRST` for TMA copies improved L2 cache hit rates in our experiments.

These optimizations achieve up to 80% Tensor Core utilization (of the throttled theoretical peak) and 3 TB/s memory bandwidth on an H800 SXM5 GPU. The new kernel is slightly slower (~2%) than the old ping-pong buffer version in memory-bound settings, which we consider acceptable.

Other performance improvements include:
- **Programmatic Dependent Launch.** We use [programmatic dependent launch](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) to overlap the `splitkv_mla` and `combine` kernels.
- **Tile Scheduler.** We implement a tile scheduler to allocate jobs (requests and blocks) to SMs. This ensures a balanced load across SMs.

## Acknowledgements

FlashMLA's algorithm and scheduling are inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work.

## Citation

```bibtex
@misc{flashmla2025,
      title={FlashMLA: Efficient MLA decoding kernels},
      author={Jiashi Li and Shengyu Liu},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```

--------------------------------------------------------------------------------
/docs/assets/MLA Kernel Sched.drawio.svg:
--------------------------------------------------------------------------------
[Figure: MLA kernel schedule diagram (draw.io SVG export; the markup did not survive extraction). Labels recovered from the diagram: lanes "Warpgroup 0" and "Warpgroup 1", each with "Tensor" and "CUDA" rows; bunches "wg0-bunch-0" and "wg1-bunch-0"; operations "rP0 = sQ @ sK0", "rP1 = sQ @ sK1", "Get sScale0 / Update sM / rPb = rP0 = Softmax(rP0) / rO0 = Scale(rO0) / Update rL", "rO0 += rPb @ sV0L", "Get sScale1 / Update sM / rP1b = Softmax(rP1) / rO1 = Scale(rO1) / Update rL", "rO1 += rP1b @ sV1R", "rO0 = Scale(rO0)", "rO0 += sP1 @ sV1L", "rO1 += sP0 @ sV0R", "sP0 = Scale(rP0)", "sP1 = rP1b"; TMA annotations "Issue", "Issue TMA (nxt V0L)", "Issue TMA (nxt V1L)", "Issue TMA (nxt V1R)", "Issue TMA (nxt V0R)", "Pipelined TMA wait and issue"; legend "sXX: stored in shared memory, rXX: stored in the register file"; note "Loop boundary in our code (please refer to the comments in `wg1_subroutine`)".]
-------------------------------------------------------------------------------- /flash_mla/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from flash_mla.flash_mla_interface import ( 4 | get_mla_metadata, 5 | flash_mla_with_kvcache, 6 | ) 7 | -------------------------------------------------------------------------------- /flash_mla/flash_mla_interface.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | import flash_mla_cuda 6 | 7 | 8 | def get_mla_metadata( 9 | cache_seqlens: torch.Tensor, 10 | num_heads_per_head_k: int, 11 | num_heads_k: int, 12 | ) -> Tuple[torch.Tensor, torch.Tensor]: 13 | """ 14 | Arguments: 15 | cache_seqlens: (batch_size), dtype torch.int32. 16 | num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. 17 | num_heads_k: num_heads_k. 18 | 19 | Returns: 20 | tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. 21 | num_splits: (batch_size + 1), dtype torch.int32. 22 | """ 23 | return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) 24 | 25 | 26 | def flash_mla_with_kvcache( 27 | q: torch.Tensor, 28 | k_cache: torch.Tensor, 29 | block_table: torch.Tensor, 30 | cache_seqlens: torch.Tensor, 31 | head_dim_v: int, 32 | tile_scheduler_metadata: torch.Tensor, 33 | num_splits: torch.Tensor, 34 | softmax_scale: Optional[float] = None, 35 | causal: bool = False, 36 | ) -> Tuple[torch.Tensor, torch.Tensor]: 37 | """ 38 | Arguments: 39 | q: (batch_size, seq_len_q, num_heads_q, head_dim). 40 | k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). 41 | block_table: (batch_size, max_num_blocks_per_seq), torch.int32. 42 | cache_seqlens: (batch_size), torch.int32. 43 | head_dim_v: Head dimension of v. 
44 | tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. 45 | num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. 46 | softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). 47 | causal: bool. Whether to apply causal attention mask. 48 | 49 | Returns: 50 | out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 51 | softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 52 | """ 53 | if softmax_scale is None: 54 | softmax_scale = q.shape[-1] ** (-0.5) 55 | out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( 56 | q, 57 | k_cache, 58 | head_dim_v, 59 | cache_seqlens, 60 | block_table, 61 | softmax_scale, 62 | causal, 63 | tile_scheduler_metadata, 64 | num_splits, 65 | ) 66 | return out, softmax_lse 67 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from datetime import datetime 4 | import subprocess 5 | 6 | from setuptools import setup, find_packages 7 | 8 | from torch.utils.cpp_extension import ( 9 | BuildExtension, 10 | CUDAExtension, 11 | IS_WINDOWS, 12 | ) 13 | 14 | 15 | def append_nvcc_threads(nvcc_extra_args): 16 | nvcc_threads = os.getenv("NVCC_THREADS") or "32" 17 | return nvcc_extra_args + ["--threads", nvcc_threads] 18 | 19 | 20 | def get_features_args(): 21 | features_args = [] 22 | DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] 23 | if DISABLE_FP16: 24 | features_args.append("-DFLASH_MLA_DISABLE_FP16") 25 | return features_args 26 | 27 | 28 | subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) 29 | 30 | cc_flag = [] 31 | cc_flag.append("-gencode") 32 | cc_flag.append("arch=compute_90a,code=sm_90a") 33 | 34 | this_dir = os.path.dirname(os.path.abspath(__file__)) 35 | 36 | if IS_WINDOWS: 37 | cxx_args = 
["/O2", "/std:c++17", "/DNDEBUG", "/W0"] 38 | else: 39 | cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"] 40 | 41 | ext_modules = [] 42 | ext_modules.append( 43 | CUDAExtension( 44 | name="flash_mla_cuda", 45 | sources=[ 46 | "csrc/flash_api.cpp", 47 | "csrc/kernels/get_mla_metadata.cu", 48 | "csrc/kernels/mla_combine.cu", 49 | "csrc/kernels/splitkv_mla.cu", 50 | ], 51 | extra_compile_args={ 52 | "cxx": cxx_args + get_features_args(), 53 | "nvcc": append_nvcc_threads( 54 | [ 55 | "-O3", 56 | "-std=c++17", 57 | "-DNDEBUG", 58 | "-D_USE_MATH_DEFINES", 59 | "-Wno-deprecated-declarations", 60 | "-U__CUDA_NO_HALF_OPERATORS__", 61 | "-U__CUDA_NO_HALF_CONVERSIONS__", 62 | "-U__CUDA_NO_HALF2_OPERATORS__", 63 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 64 | "--expt-relaxed-constexpr", 65 | "--expt-extended-lambda", 66 | "--use_fast_math", 67 | "--ptxas-options=-v,--register-usage-level=10" 68 | ] 69 | + cc_flag 70 | ) + get_features_args(), 71 | }, 72 | include_dirs=[ 73 | Path(this_dir) / "csrc", 74 | Path(this_dir) / "csrc" / "cutlass" / "include", 75 | ], 76 | ) 77 | ) 78 | 79 | 80 | try: 81 | cmd = ['git', 'rev-parse', '--short', 'HEAD'] 82 | rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() 83 | except Exception as _: 84 | now = datetime.now() 85 | date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S") 86 | rev = '+' + date_time_str 87 | 88 | 89 | setup( 90 | name="flash_mla", 91 | version="1.0.0" + rev, 92 | packages=find_packages(include=['flash_mla']), 93 | ext_modules=ext_modules, 94 | cmdclass={"build_ext": BuildExtension}, 95 | ) 96 | -------------------------------------------------------------------------------- /tests/test_flash_mla.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | 5 | import torch 6 | import triton 7 | 8 | from flash_mla import flash_mla_with_kvcache, get_mla_metadata 9 | 10 | 11 | def 
scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): 12 | query = query.float() 13 | key = key.float() 14 | value = value.float() 15 | key = key.repeat_interleave(h_q // h_kv, dim=0) 16 | value = value.repeat_interleave(h_q // h_kv, dim=0) 17 | attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) 18 | if is_causal: 19 | s_q = query.shape[-2] 20 | s_k = key.shape[-2] 21 | attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) 22 | temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) 23 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 24 | attn_bias = attn_bias.to(query.dtype) 25 | attn_weight += attn_bias 26 | lse = attn_weight.logsumexp(dim=-1) 27 | attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) 28 | return attn_weight @ value, lse 29 | 30 | 31 | def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: 32 | x, y = x.double(), y.double() 33 | RMSE = ((x - y) * (x - y)).mean().sqrt().item() 34 | cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) 35 | amax_diff = (x - y).abs().max().item() 36 | # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") 37 | assert cos_diff < 1e-5, f"{name}: {cos_diff=}" 38 | 39 | 40 | @torch.inference_mode() 41 | def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): 42 | print( 43 | f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" 44 | ) 45 | 46 | cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) 47 | if varlen: 48 | for i in range(b): 49 | cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) 50 | total_seqlens = cache_seqlens.sum().item() 51 | mean_seqlens = cache_seqlens.float().mean().int().item() 52 | max_seqlen = cache_seqlens.max().item() 53 | max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 54 | # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") 55 | 56 | q = torch.randn(b, s_q, h_q, d) 57 | block_size = 64 58 | block_table = 
torch.arange( 59 | b * max_seqlen_pad // block_size, dtype=torch.int32 60 | ).view(b, max_seqlen_pad // block_size) 61 | blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) 62 | for i in range(b): 63 | blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( 64 | float("nan") 65 | ) 66 | blocked_v = blocked_k[..., :dv] 67 | 68 | tile_scheduler_metadata, num_splits = get_mla_metadata( 69 | cache_seqlens, s_q * h_q // h_kv, h_kv 70 | ) 71 | 72 | def flash_mla(): 73 | return flash_mla_with_kvcache( 74 | q, 75 | blocked_k, 76 | block_table, 77 | cache_seqlens, 78 | dv, 79 | tile_scheduler_metadata, 80 | num_splits, 81 | causal=causal, 82 | ) 83 | 84 | def ref_mla(): 85 | out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) 86 | lse = torch.empty(b, h_q, s_q, dtype=torch.float32) 87 | for i in range(b): 88 | begin = i * max_seqlen_pad 89 | end = begin + cache_seqlens[i] 90 | O, LSE = scaled_dot_product_attention( 91 | q[i].transpose(0, 1), 92 | blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), 93 | blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), 94 | h_q=h_q, 95 | h_kv=h_kv, 96 | is_causal=causal, 97 | ) 98 | out[i] = O.transpose(0, 1) 99 | lse[i] = LSE 100 | return out, lse 101 | 102 | out_flash, lse_flash = flash_mla() 103 | out_torch, lse_torch = ref_mla() 104 | cal_diff(out_flash, out_torch, "out") 105 | cal_diff(lse_flash, lse_torch, "lse") 106 | 107 | t = triton.testing.do_bench(flash_mla) 108 | FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 109 | num_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( 110 | torch.finfo(q.dtype).bits // 8 111 | ) 112 | print( 113 | f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {num_bytes / 10 ** 6 / t:.0f} GB/s" 114 | ) 115 | 116 | 117 | def main(torch_dtype): 118 | device = torch.device("cuda:0") 119 | torch.set_default_dtype(torch_dtype) 120 | torch.set_default_device(device) 121 | torch.cuda.set_device(device) 122 | torch.manual_seed(0) 123 | 
random.seed(0) 124 | 125 | h_kv = 1 126 | d, dv = 576, 512 127 | causal = True 128 | 129 | for b in [128]: 130 | for s in [4096, 8192, 16384]: 131 | for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 132 | for s_q in [1, 2]: # MTP = 1, 2 133 | for varlen in [False, True]: 134 | test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument( 140 | "--dtype", 141 | type=str, 142 | choices=["bf16", "fp16"], 143 | default="bf16", 144 | help="Data type to use for testing (bf16 or fp16)", 145 | ) 146 | 147 | args = parser.parse_args() 148 | 149 | torch_dtype = torch.bfloat16 150 | if args.dtype == "fp16": 151 | torch_dtype = torch.float16 152 | 153 | main(torch_dtype) 154 | --------------------------------------------------------------------------------
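The throughput figures printed by `tests/test_flash_mla.py` depend on a unit conversion that is easy to misread: the timing is in milliseconds, so dividing a FLOP count by `10 ** 9` and then by the time gives TFLOPS, and dividing a byte count by `10 ** 6` and then by the time gives GB/s. The sketch below reproduces that arithmetic without needing a GPU. The function names `attention_cost` and `throughput`, the `dtype_bytes` parameter, and the 1.0 ms timing are illustrative assumptions and are not part of the repository.

```python
def attention_cost(b, s_q, total_seqlens, h_q, h_kv, d, dv, dtype_bytes=2):
    """Return (flops, num_bytes) using the same formulas as the test script.

    dtype_bytes is a hypothetical parameter standing in for
    torch.finfo(q.dtype).bits // 8, which is 2 for bf16 and fp16.
    """
    # QK^T contributes 2 * d FLOPs and PV contributes 2 * dv FLOPs
    # per (query token, key token, query head) triple.
    flops = s_q * total_seqlens * h_q * (d + dv) * 2
    # Elements counted: the KV cache, the query, and the output.
    num_bytes = (
        total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv
    ) * dtype_bytes
    return flops, num_bytes


def throughput(flops, num_bytes, t_ms):
    """Convert raw counts and a time in milliseconds to (TFLOPS, GB/s)."""
    # 1e9 * 1e3 (ms per s) = 1e12, so this is FLOPs per second / 1e12.
    tflops = flops / 10**9 / t_ms
    # 1e6 * 1e3 (ms per s) = 1e9, so this is bytes per second / 1e9.
    gbps = num_bytes / 10**6 / t_ms
    return tflops, gbps


if __name__ == "__main__":
    # One configuration from the test sweep: b=128, s_q=1, s_k=4096, h_q=128.
    b, s_q, s_k = 128, 1, 4096
    flops, num_bytes = attention_cost(
        b, s_q, total_seqlens=b * s_k, h_q=128, h_kv=1, d=576, dv=512
    )
    # The 1.0 ms kernel time is a made-up value, not a measurement.
    tflops, gbps = throughput(flops, num_bytes, t_ms=1.0)
    print(f"{tflops:.1f} TFLOPS, {gbps:.1f} GB/s")
```

Keeping the cost model separate from the timing makes it possible to check the formulas on a machine without CUDA, and to compare the arithmetic intensity (`flops / num_bytes`) of different `h_q` and `s_q` settings before benchmarking.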