├── flash_attn_v100
│   ├── __init__.py
│   └── flash_attn_interface.py
├── include
│   └── fused_mha.h
├── README.md
├── test.py
├── setup.py
└── kernel
    ├── fused_mha_api.cpp
    └── fused_mha_kernel.cu
/flash_attn_v100/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | 
3 | from flash_attn_v100.flash_attn_interface import (
4 |     flash_attn_func,
5 | )
6 | 
--------------------------------------------------------------------------------
/include/fused_mha.h:
--------------------------------------------------------------------------------
1 | void fused_mha_forward(const void *query_ptr, const void *key_ptr, const void *value_ptr, void *output_ptr, void *max_ptr, void *sum_ptr,
2 |                        int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);
3 | 
4 | void fused_mha_backward(const void *query_ptr, const void *key_ptr, const void *value_ptr,
5 |                         void *output_ptr, void *d_output_ptr, void *d_ptr, void *max_ptr, void *sum_ptr,
6 |                         void *d_query_ptr, void *d_key_ptr, void *d_value_ptr, int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flash_Attention_V100
2 | 
3 | The official flash attention only supports GPUs of the Ampere architecture and newer; the V100, a Volta-architecture card, is not supported. So, out of interest, I followed the CUTLASS tutorials and the FlashAttention-2 paper and wrote this version for the V100. However, due to a busy work schedule and limited hardware, I have not been able to tune performance carefully, so this repo does not beat PyTorch's attention computation overall: the forward pass is currently about 40% faster than PyTorch, but the backward pass is about 20% slower, and the two roughly cancel out. In addition, the implementation does not handle boundary conditions, so sequences must be right-padded to a multiple of 32 (see the padding sketch at the end of this README). This does not affect normal training; simply ignore the padded positions when computing the loss.
4 | 
5 | ## Installation
6 | Before installing, make sure you have:
7 | 
8 | - PyTorch >= 2.0.1
9 | - CUDA >= 11.6
10 | - Linux OS
11 | - The CUTLASS source code
12 | 
13 | Edit line 146 of setup.py so that it points to the location of your CUTLASS source checkout:
14 | 
15 | ```py
16 | include_dirs=[
17 |     Path(this_dir) / "include",
18 |     "/home/user/cutlass/include",
19 | ],
20 | ```
21 | 
22 | After editing, install from source with:
23 | ```bash
24 | python setup.py install --user
25 | ```
26 | 
27 | ## Usage
28 | 
29 | ```python
30 | import torch
31 | from flash_attn_v100 import flash_attn_func
32 | # Example shapes/settings (same as test.py): batch Z, heads H, sequence length N_CTX, head dim D_HEAD
33 | Z, H, N_CTX, D_HEAD, dtype, causal = 2, 40, 2048, 128, torch.float16, True
34 | sm_scale = D_HEAD ** -0.5
35 | q = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
36 | k = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
37 | v = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
38 | cuda_out = flash_attn_func(q, k, v, sm_scale, causal)
39 | ```
40 | 
41 | ## References
42 | - [Flash-Attention](https://github.com/Dao-AILab/flash-attention)
43 | - [CUTLASS](https://github.com/NVIDIA/cutlass)
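44 | 
45 | ## Padding note
46 | 
47 | As mentioned above, the kernel assumes the sequence length is already a multiple of 32, so inputs have to be right-padded before calling `flash_attn_func`. The helper below is only an illustrative sketch (it is not part of this repo, and the function name is made up): it pads the sequence dimension of a `(batch, seqlen, heads, head_dim)` tensor and returns a boolean mask that can be used to drop the padded positions from the loss.
48 | 
49 | ```python
50 | import torch
51 | import torch.nn.functional as F
52 | 
53 | def right_pad_to_multiple_of_32(x: torch.Tensor, pad_value: float = 0.0):
54 |     # x: (batch, seqlen, heads, head_dim); pad the seqlen dimension (dim 1) on the right
55 |     seqlen = x.shape[1]
56 |     pad_len = (-seqlen) % 32
57 |     mask = torch.zeros(x.shape[0], seqlen + pad_len, dtype=torch.bool, device=x.device)
58 |     mask[:, :seqlen] = True
59 |     if pad_len > 0:
60 |         # F.pad lists pads for the trailing dims first: (head_dim, heads, seqlen) pairs
61 |         x = F.pad(x, (0, 0, 0, 0, 0, pad_len), value=pad_value)
62 |     return x, mask
63 | 
64 | # e.g. ignore the padded positions when computing the loss:
65 | # loss = F.cross_entropy(logits[mask], labels[mask])
66 | ```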
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from flash_attn_v100 import flash_attn_func
4 | Z, H, N_CTX, D_HEAD, causal, dtype = 2, 40, 2048, 128, True, torch.float16
5 | torch.manual_seed(20)
6 | q = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
7 | k = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
8 | v = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
9 | sm_scale = 1 / 10
10 | dout = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1)
11 | begin = time.time()
12 | for i in range(1):
13 |     q_transposed = q.transpose(1, 2)
14 |     k_transposed = k.transpose(1, 2)
15 |     v_transposed = v.transpose(1, 2)
16 |     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
17 |     p = torch.matmul(q_transposed, k_transposed.transpose(2, 3)) * sm_scale
18 |     if causal:
19 |         p[:, :, M == 0] = float("-inf")
20 |     p = torch.softmax(p.float(), dim=-1).half()
21 |     ref_out = torch.matmul(p, v_transposed)
22 |     ref_out = ref_out.transpose(1, 2)
23 |     ref_out.backward(dout)
24 |     ref_dv, v.grad = v.grad.clone(), None
25 |     ref_dk, k.grad = k.grad.clone(), None
26 |     ref_dq, q.grad = q.grad.clone(), None
27 | torch.cuda.synchronize(device="cuda:0")
28 | end = time.time()
29 | print(f"torch cost : {end - begin}")
30 | begin = time.time()
31 | # flash_attn_v100 CUDA implementation
32 | for i in range(1):
33 |     cuda_out = flash_attn_func(q, k, v, sm_scale, causal)
34 |     cuda_out.backward(dout)
35 |     dq, q.grad = q.grad.clone(), None
36 |     dk, k.grad = k.grad.clone(), None
37 |     dv, v.grad = v.grad.clone(), None
38 | torch.cuda.synchronize(device="cuda:0")
39 | end = time.time()
40 | print(f"flash_attn_v100 cost : {end - begin}")
41 | # compare
42 | assert torch.allclose(ref_out, cuda_out, atol=1e-2, rtol=0)
43 | assert torch.allclose(ref_dq, dq, atol=1e-2, rtol=0)
44 | assert torch.allclose(ref_dk, dk, atol=1e-2, rtol=0)
45 | assert torch.allclose(ref_dv, dv, atol=1e-2, rtol=0)
--------------------------------------------------------------------------------
/flash_attn_v100/flash_attn_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 | 
3 | import os
4 | from typing import Optional, Union
5 | 
6 | import torch
7 | import torch.nn as nn
8 | # We need to import the CUDA kernels after importing torch
9 | import flash_attn_v100_cuda as flash_attn_cuda
10 | 
11 | def _flash_attn_forward(
12 |     q, k, v, softmax_scale, causal
13 | ):
14 |     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
15 |     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
16 |     out, softmax_max, softmax_sum = flash_attn_cuda.fwd(
17 |         q,
18 |         k,
19 |         v,
20 |         None,
21 |         softmax_scale,
22 |         causal
23 |     )
24 |     return out, q, k, v, softmax_max, softmax_sum
25 | 
26 | def _flash_attn_backward(
27 |     dout,
28 |     q,
29 |     k,
30 |     v,
31 |     out,
32 |     softmax_max,
33 |     softmax_sum,
34 |     dq,
35 |     dk,
36 |     dv,
37 |     softmax_scale,
38 |     causal,
39 | ):
40 |     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
41 |     # dq, dk, dv are allocated by us so they should already be contiguous
42 |     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
43 |     D = out * dout
44 |     D = D.sum(-1)  # D = rowsum(dO * O) per query position, consumed by the backward kernel
45 |     dq, dk, dv = flash_attn_cuda.bwd(
46 |         dout,
47 |         q,
48 |         k,
49 |         v,
50 |         out,
51 |         D,
52 |         softmax_sum,
53 |         softmax_max,
54 |         dq,
55 |         dk,
56 |         dv,
57 |         softmax_scale,
58 |         causal,
59 |     )
60 |     return dq, dk, dv
61 | 
62 | class FlashAttnFunc(torch.autograd.Function):
63 |     @staticmethod
64 |     def forward(
65 |         ctx, q, k, v, softmax_scale, causal
66 |     ):
67 |         if softmax_scale is None:
68 |             softmax_scale = q.shape[-1] ** (-0.5)
69 |         out, q, k, v, softmax_max, softmax_sum = _flash_attn_forward(
70 |             q,
71 |             k,
72 |             v,
73 |             softmax_scale,
74 |             causal=causal,
75 |         )
76 |         ctx.save_for_backward(q, k, v, out, softmax_max, softmax_sum)
77 |         ctx.softmax_scale = softmax_scale
78 |         ctx.causal = causal
79 |         return out
80 | 
81 |     @staticmethod
82 |     def backward(ctx, dout, *args):
83 |         q, k, v, out, softmax_max, softmax_sum = ctx.saved_tensors
84 |         dq, dk, dv = 
torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) 85 | _flash_attn_backward( 86 | dout, 87 | q, 88 | k, 89 | v, 90 | out, 91 | softmax_max, 92 | softmax_sum, 93 | dq, 94 | dk, 95 | dv, 96 | ctx.softmax_scale, 97 | ctx.causal, 98 | ) 99 | return dq, dk, dv, None, None 100 | 101 | def flash_attn_func( 102 | q, 103 | k, 104 | v, 105 | softmax_scale=None, 106 | causal=False, 107 | ): 108 | return FlashAttnFunc.apply( 109 | q, k, v, softmax_scale, causal 110 | ) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import sys 4 | import warnings 5 | import os 6 | import re 7 | import ast 8 | from pathlib import Path 9 | from packaging.version import parse, Version 10 | import platform 11 | 12 | from setuptools import setup, find_packages 13 | import subprocess 14 | 15 | import urllib.request 16 | import urllib.error 17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 18 | 19 | import torch 20 | from torch.utils.cpp_extension import ( 21 | BuildExtension, 22 | CppExtension, 23 | CUDAExtension, 24 | CUDA_HOME, 25 | ) 26 | 27 | # ninja build does not work unless include_dirs are abs path 28 | this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | 30 | PACKAGE_NAME = "flash_attn_v100" 31 | 32 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 33 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 34 | FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" 35 | SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 36 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 37 | FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" 38 | 39 | 40 | def get_platform(): 41 | """ 42 | Returns the platform name as used in wheel filenames. 43 | """ 44 | if sys.platform.startswith("linux"): 45 | return "linux_x86_64" 46 | elif sys.platform == "darwin": 47 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 48 | return f"macosx_{mac_version}_x86_64" 49 | elif sys.platform == "win32": 50 | return "win_amd64" 51 | else: 52 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 53 | 54 | 55 | def get_cuda_bare_metal_version(cuda_dir): 56 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 57 | output = raw_output.split() 58 | release_idx = output.index("release") + 1 59 | bare_metal_version = parse(output[release_idx].split(",")[0]) 60 | 61 | return raw_output, bare_metal_version 62 | 63 | 64 | def check_if_cuda_home_none(global_option: str) -> None: 65 | if CUDA_HOME is not None: 66 | return 67 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 68 | # in that case. 69 | warnings.warn( 70 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 71 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 72 | "only images whose names contain 'devel' will provide nvcc." 
73 | ) 74 | 75 | 76 | def append_nvcc_threads(nvcc_extra_args): 77 | return nvcc_extra_args + ["--threads", "4"] 78 | 79 | 80 | cmdclass = {} 81 | ext_modules = [] 82 | 83 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 84 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 85 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 86 | 87 | # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h 88 | # See https://github.com/pytorch/pytorch/pull/70650 89 | generator_flag = [] 90 | torch_dir = torch.__path__[0] 91 | if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): 92 | generator_flag = ["-DOLD_GENERATOR_PATH"] 93 | 94 | check_if_cuda_home_none("flash_attn") 95 | # Check, if CUDA11 is installed for compute capability 8.0 96 | cc_flag = [] 97 | if CUDA_HOME is not None: 98 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 99 | if bare_metal_version < Version("11.6"): 100 | raise RuntimeError( 101 | "FlashAttention is only supported on CUDA 11.6 and above. " 102 | "Note: make sure nvcc has a supported version by running nvcc -V." 103 | ) 104 | # cc_flag.append("-gencode") 105 | # cc_flag.append("arch=compute_75,code=sm_75") 106 | # cc_flag.append("-gencode") 107 | # cc_flag.append("arch=compute_80,code=sm_80") 108 | cc_flag.append("-gencode") 109 | cc_flag.append("arch=compute_70,code=sm_70") 110 | 111 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 112 | # torch._C._GLIBCXX_USE_CXX11_ABI 113 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 114 | if FORCE_CXX11_ABI: 115 | torch._C._GLIBCXX_USE_CXX11_ABI = True 116 | ext_modules.append( 117 | CUDAExtension( 118 | name="flash_attn_v100_cuda", 119 | sources=[ 120 | "kernel/fused_mha_api.cpp", 121 | "kernel/fused_mha_kernel.cu" 122 | ], 123 | extra_compile_args={ 124 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 125 | "nvcc": append_nvcc_threads( 126 | [ 127 | "-O3", 128 | "-std=c++17", 129 | "-U__CUDA_NO_HALF_OPERATORS__", 130 | "-U__CUDA_NO_HALF_CONVERSIONS__", 131 | "-U__CUDA_NO_HALF2_OPERATORS__", 132 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 133 | "--expt-relaxed-constexpr", 134 | "--expt-extended-lambda", 135 | "--use_fast_math", 136 | # "--ptxas-options=-v", 137 | # "--ptxas-options=-O2", 138 | # "-lineinfo", 139 | ] 140 | + generator_flag 141 | + cc_flag 142 | ), 143 | }, 144 | include_dirs=[ 145 | Path(this_dir) / "include", 146 | "/home/user/cutlass/include", 147 | ], 148 | ) 149 | ) 150 | 151 | setup( 152 | name=PACKAGE_NAME, 153 | version="v0.0.1", 154 | packages=find_packages( 155 | exclude=( 156 | "build", 157 | "kernel", 158 | "include", 159 | "flash_attn_v100.egg-info", 160 | ) 161 | ), 162 | description="Flash Attention 2 for v100", 163 | ext_modules=ext_modules, 164 | cmdclass={"build_ext": BuildExtension}, 165 | python_requires=">=3.7", 166 | install_requires=[ 167 | "torch", 168 | "einops", 169 | "packaging", 170 | "ninja", 171 | ], 172 | ) 173 | -------------------------------------------------------------------------------- /kernel/fused_mha_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include "fused_mha.h" 8 | 9 | #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") 10 | #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | 13 | std::vector 14 | mha_fwd(at::Tensor &q, // batch_size x num_heads x seqlen x head_size 15 | const at::Tensor &k, // batch_size x num_heads x seqlen x head_size 16 | const at::Tensor &v, // batch_size x num_heads x seqlen x head_size 17 | c10::optional &out_, // batch_size x num_heads x seqlen x head_size 18 | const float softmax_scale, 19 | bool is_causal) 20 | { 21 | auto dprops = at::cuda::getCurrentDeviceProperties(); 22 | bool is_sm70 = dprops->major == 7 && dprops->minor == 0; 23 | TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs."); 24 | 25 | auto q_dtype = q.dtype(); 26 | TORCH_CHECK(q_dtype == torch::kFloat16, 27 | "This repo only supports fp16 data type"); 28 | TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); 29 | TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); 30 | 31 | CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); 32 | 33 | TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 34 | TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 35 | TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 36 | 37 | const auto sizes = q.sizes(); 38 | 39 | const int batch_size = sizes[0]; 40 | const int num_heads = sizes[2]; 41 | const int seqlen_q = sizes[1]; 42 | const int head_size = sizes[3]; 43 | TORCH_CHECK(batch_size > 0, "batch size must be postive"); 44 | TORCH_CHECK(head_size == 128, "current repo only supports head dimension 128, we will support more in the fulture"); 45 | CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); 46 | CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size); 47 | CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size); 48 | 49 | at::Tensor out; 50 | if (out_.has_value()) { 51 | out = out_.value(); 52 | TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); 53 | CHECK_DEVICE(out); 54 | TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); 55 | CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); 56 | } else { 57 | out = torch::empty_like(q); 58 | } 59 | // Otherwise the kernel will be launched from cuda:0 device 60 | // Cast to char to avoid compiler warning about narrowing 61 | at::cuda::CUDAGuard device_guard{(char)q.get_device()}; 62 | 63 | auto opts = q.options(); 64 | 65 | auto softmax_sum = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); 66 | auto softmax_max = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); 67 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 68 | fused_mha_forward(q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(), 69 | batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream); 70 | 71 | return {out, softmax_max, softmax_sum}; 72 | } 73 | 74 | std::vector 75 | mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og 76 | const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size 77 | const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size 78 | const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size 79 | const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size 80 | const 
at::Tensor &D, // batch_size x seqlen_q x num_heads x head_size 81 | const at::Tensor &softmax_sum, // b x h x seqlen_q 82 | const at::Tensor &softmax_max, // b x h x seqlen_q 83 | c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size 84 | c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size 85 | c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size 86 | const float softmax_scale, 87 | const bool is_causal) 88 | { 89 | auto dprops = at::cuda::getCurrentDeviceProperties(); 90 | bool is_sm70 = dprops->major == 7 && dprops->minor == 0; 91 | TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs."); 92 | 93 | auto q_dtype = q.dtype(); 94 | TORCH_CHECK(q_dtype == torch::kFloat16, 95 | "This repo only supports fp16 data type"); 96 | TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); 97 | TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); 98 | 99 | CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); 100 | CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_sum); CHECK_DEVICE(softmax_max); 101 | 102 | TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 103 | TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 104 | TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); 105 | TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); 106 | TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); 107 | 108 | const auto sizes = q.sizes(); 109 | 110 | const int batch_size = sizes[0]; 111 | const int num_heads = sizes[2]; 112 | const int seqlen_q = sizes[1]; 113 | const int head_size = sizes[3]; 114 | TORCH_CHECK(batch_size > 0, "batch size must be postive"); 115 | TORCH_CHECK(head_size == 128, "current repo only supports head dimension 128, we will support more in the fulture"); 116 | CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); 117 | CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size); 118 | CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size); 119 | CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); 120 | CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); 121 | CHECK_SHAPE(softmax_sum, batch_size, num_heads, seqlen_q); 122 | CHECK_SHAPE(softmax_max, batch_size, num_heads, seqlen_q); 123 | CHECK_SHAPE(D, batch_size, seqlen_q, num_heads); 124 | auto opts = q.options(); 125 | at::Tensor dq, dk, dv; 126 | if (dq_.has_value()) { 127 | dq = dq_.value(); 128 | TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); 129 | CHECK_DEVICE(dq); 130 | TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); 131 | CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); 132 | } else { 133 | // dq = torch::empty_like(q); 134 | dq = torch::zeros({batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kHalf)); 135 | } 136 | if (dk_.has_value()) { 137 | dk = dk_.value(); 138 | TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); 139 | CHECK_DEVICE(dk); 140 | TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); 141 | CHECK_SHAPE(dk, batch_size, seqlen_q, num_heads, head_size); 142 | } else { 143 | dk = torch::empty_like(k); 144 | } 145 | if (dv_.has_value()) { 146 | dv = dv_.value(); 147 | TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); 148 | CHECK_DEVICE(dv); 149 | TORCH_CHECK(dv.stride(-1) == 1, "dv must have 
contiguous last dimension"); 150 | CHECK_SHAPE(dv, batch_size, seqlen_q, num_heads, head_size); 151 | } else { 152 | dv = torch::empty_like(k); 153 | } 154 | 155 | at::cuda::CUDAGuard device_guard{(char)q.get_device()}; 156 | 157 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 158 | fused_mha_backward(q.data_ptr(), k.data_ptr(), v.data_ptr(), 159 | out.data_ptr(), dout.data_ptr(), D.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(), 160 | dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream); 161 | return {dq, dk, dv}; 162 | } 163 | 164 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 165 | m.doc() = "FlashAttention"; 166 | m.def("fwd", &mha_fwd, "Forward pass"); 167 | m.def("bwd", &mha_bwd, "Backward pass"); 168 | } 169 | -------------------------------------------------------------------------------- /kernel/fused_mha_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // 从flash attention复制过来,对register中的数据类型做转换 9 | template 10 | inline __device__ auto convert_type(cute::Tensor const &tensor) { 11 | using namespace cute; 12 | using From_type = typename Engine::value_type; 13 | constexpr int numel = decltype(size(tensor))::value; 14 | cutlass::NumericArrayConverter convert_op; 15 | // HACK: this requires tensor to be "contiguous" 16 | auto frag = convert_op(*reinterpret_cast *>(tensor.data())); 17 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 18 | } 19 | 20 | template 21 | __global__ void fused_mha_forward_kernel(const void *query_ptr, const void *key_ptr, const void *value_ptr, void *output_ptr, void *max_ptr, void *sum_ptr, int head, int m, int n, int k, float scale, bool causal) 22 | { 23 | using namespace cute; 24 | using X = Underscore; 25 | using ElementType = typename Config::ElementType; 26 | using ComputeType = typename Config::ComputeType; 27 | 28 | using SmemLayoutQuery = typename Config::SmemLayoutQuery; 29 | using SmemLayoutKey = typename Config::SmemLayoutKey; 30 | using SmemLayoutValue = typename Config::SmemLayoutValue; 31 | using SmemLayoutValueTransposed = typename Config::SmemLayoutValueTransposed; 32 | using SmemLayoutAcc = typename Config::SmemLayoutAcc; 33 | using SmemLayoutOutput = typename Config::SmemLayoutOutput; 34 | using TiledMMA_TN = typename Config::MMA_TN; 35 | 36 | using Global2SharedCopyQuery = typename Config::Global2SharedCopyQuery; 37 | using Global2SharedCopyKey = typename Config::Global2SharedCopyKey; 38 | using Global2SharedCopyValue = typename Config::Global2SharedCopyValue; 39 | using Shared2RegisterCopyAcc = typename Config::Shared2RegisterCopyAcc; 40 | using Shared2RegisterCopyFp16Acc = typename Config::Shared2RegisterCopyFp16Acc; 41 | using Register2SharedCopyFp16Acc = typename Config::Register2SharedCopyFp16Acc; 42 | using Global2SharedCopyAtom = typename Config::Global2SharedCopyAtom; 43 | using SmemLayoutMax = typename Config::SmemLayoutMax; 44 | using SmemLayoutSum = typename Config::SmemLayoutSum; 45 | 46 | constexpr int kTileM = Config::kTileM; 47 | constexpr int kTileN = Config::kTileN; 48 | constexpr int kTileK = Config::kTileK; 49 | 50 | extern __shared__ ElementType shm_data[]; 51 | 52 | int idx = threadIdx.x; 53 | int b = blockIdx.y / head; 54 | int h = blockIdx.y % head; 55 | int row_m = blockIdx.x; 56 | 57 | Tensor global_query = make_tensor(make_gmem_ptr((ElementType *)query_ptr + b * head * m * k + 
h * k + row_m * kTileM * head * k), 58 | Shape, Int>{}, 59 | make_stride(head * k, Int<1>{})); 60 | Tensor global_output = make_tensor(make_gmem_ptr((ElementType *)output_ptr + b * head * m * k + h * k + row_m * kTileM * head * k), Shape, Int>{}, 61 | make_stride(head * k, Int<1>{})); 62 | Tensor global_max = make_tensor(make_gmem_ptr((ComputeType*)max_ptr + b * head * m + h * m + row_m * kTileM), 63 | Shape, Int<1>>{}, 64 | make_stride(1, Int<1>{})); 65 | Tensor global_sum = make_tensor(make_gmem_ptr((ComputeType*)sum_ptr + b * head * m + h * m + row_m * kTileM), 66 | Shape, Int<1>>{}, 67 | make_stride(1, Int<1>{})); 68 | ComputeType *max_shm = (ComputeType*)((char*)shm_data + Config::kShmSizeQKVO); 69 | auto shared_max = make_tensor(make_smem_ptr(max_shm), SmemLayoutMax{}); 70 | 71 | if (idx <32) 72 | shared_max(idx, 0) = -INFINITY;//global_max(idx, 0); 73 | __syncthreads(); 74 | ComputeType *old_max_shm = max_shm + cute::cosize(SmemLayoutMax{}); 75 | auto shared_old_max = make_tensor(make_smem_ptr(old_max_shm), SmemLayoutMax{}); 76 | if (idx < 32) 77 | shared_old_max(idx, 0) = -INFINITY; 78 | __syncthreads(); 79 | ComputeType *sum_shm = old_max_shm + cute::cosize(SmemLayoutMax{}); 80 | auto shared_sum = make_tensor(make_smem_ptr(sum_shm), SmemLayoutSum{}); 81 | clear(shared_sum); 82 | ElementType *query_shm = shm_data; 83 | auto shared_query = make_tensor(make_smem_ptr(query_shm), SmemLayoutQuery{}); // (kTileM, kTileK) 84 | Global2SharedCopyQuery global_2_shared_query_copy_tile; 85 | auto global_2_shared_query_copy_thread = global_2_shared_query_copy_tile.get_slice(idx); 86 | auto global_query_thread_copy_tensor = global_2_shared_query_copy_thread.partition_S(global_query); // (CPY, CPY_M, CPY_K) 87 | auto shared_query_thread_copy_tensor = global_2_shared_query_copy_thread.partition_D(shared_query); // (CPY, CPY_M, CPY_K) 88 | float4* shared_ptr = reinterpret_cast(shm_data); 89 | float4* global_ptr = reinterpret_cast((ElementType *)query_ptr + b * head * m * k + h * k + row_m * kTileM * head * k); 90 | for (int i = 0; i < 4; i++) { 91 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 92 | } 93 | __syncthreads(); 94 | TiledMMA_TN tiled_mma_tn; 95 | auto thr_mma_tn = tiled_mma_tn.get_slice(idx); 96 | auto mma_thread_query_register_tensor = thr_mma_tn.partition_fragment_A(global_query); // (MMA, MMA_M, MMA_K) 97 | auto mma_thread_query_shared_tensor = thr_mma_tn.partition_A(shared_query); 98 | cute::copy(mma_thread_query_shared_tensor, mma_thread_query_register_tensor); 99 | __syncthreads(); 100 | auto shared_key = make_tensor(make_smem_ptr(query_shm), SmemLayoutKey{}); // (kTileN, kTileK) 101 | 102 | Tensor global_key = make_tensor(make_gmem_ptr((ElementType *)key_ptr + b * head * n * k + h * k), 103 | Shape, Int>{}, 104 | make_stride(head * k, Int<1>{})); 105 | auto mma_thread_key_register_tensor = thr_mma_tn.partition_fragment_B(global_key); // (MMA, MMA_N, MMA_K) 106 | auto mma_thread_key_shared_tensor = thr_mma_tn.partition_B(shared_key); 107 | 108 | ElementType *acc_shm = (ElementType*)shm_data; 109 | auto shared_acc = make_tensor(make_smem_ptr(acc_shm), SmemLayoutAcc{}); 110 | auto mma_thread_acc_register_tensor = thr_mma_tn.partition_fragment_C(shared_acc); 111 | auto mma_thread_acc_shared_tensor = thr_mma_tn.partition_C(shared_acc); 112 | Shared2RegisterCopyAcc shared_2_register_acc_copy_tile; 113 | auto shared_2_register_acc_copy_thread = shared_2_register_acc_copy_tile.get_slice(idx); 114 | auto 
shared_2_register_acc_thread_shared_tensor = shared_2_register_acc_copy_thread.partition_D(shared_acc); // (CPY, CPY_M, CPY_K) 115 | auto shared_2_register_acc_thread_register_tensor = make_fragment_like(shared_2_register_acc_thread_shared_tensor); 116 | ElementType *fp16_acc_shm = shm_data; 117 | auto shared_fp16_acc = make_tensor(make_smem_ptr(fp16_acc_shm), SmemLayoutAcc{}); 118 | Shared2RegisterCopyFp16Acc shared_2_register_fp16_acc_copy_tile; 119 | auto shared_2_register_fp16_acc_copy_thread = shared_2_register_fp16_acc_copy_tile.get_slice(idx); 120 | auto register_2_shared_fp16_acc_thread_shared_tensor = shared_2_register_fp16_acc_copy_thread.partition_D(shared_fp16_acc); // (CPY, CPY_M, CPY_K) 121 | auto mma_thread_acc_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_fp16_acc); 122 | auto mma_thread_acc_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_fp16_acc); 123 | ElementType *value_shm = shm_data; 124 | auto shared_value = make_tensor(make_smem_ptr(value_shm), SmemLayoutValue{}); 125 | auto shared_value_transposed = make_tensor(make_smem_ptr(value_shm), SmemLayoutValueTransposed{}); 126 | auto mma_thread_value_register_tensor = thr_mma_tn.partition_fragment_B(shared_value_transposed); // (MMA, MMA_N, MMA_K) 127 | auto mma_thread_value_shared_tensor = thr_mma_tn.partition_B(shared_value_transposed); 128 | 129 | ElementType *output_shm = (ElementType*)shm_data; 130 | auto shared_output = make_tensor(make_smem_ptr(output_shm), SmemLayoutOutput{}); 131 | auto mma_thread_output_register_tensor = thr_mma_tn.partition_fragment_C(global_output); // (MMA, MMA_M, MMA_K) 132 | auto mma_thread_output_shared_tensor = thr_mma_tn.partition_C(shared_output); 133 | auto mma_thread_output_global_tensor = thr_mma_tn.partition_C(global_output); 134 | clear(mma_thread_output_register_tensor); 135 | ElementType *o_scale_shm = (ElementType *)(value_shm + cute::cosize(SmemLayoutValue{})); 136 | auto shared_o_scale = make_tensor(make_smem_ptr(o_scale_shm), SmemLayoutOutput{}); 137 | auto mma_thread_o_scale_shared_tensor = thr_mma_tn.partition_C(shared_o_scale); 138 | ElementType *sum_scale_shm = (ElementType *)shm_data; 139 | auto shared_sum_scale = make_tensor(make_smem_ptr(sum_scale_shm), SmemLayoutOutput{}); 140 | auto mma_thread_sum_scale_shared_tensor = thr_mma_tn.partition_C(shared_sum_scale); 141 | int end_tile = n / kTileN - 1; 142 | if (causal) { 143 | end_tile = row_m; 144 | } 145 | #pragma unroll 146 | for (int n_tile = 0; n_tile < end_tile; n_tile++) { 147 | clear(mma_thread_acc_register_tensor); 148 | Tensor global_key = make_tensor(make_gmem_ptr((ElementType *)key_ptr + b * head * n * k + h * k + n_tile * kTileN * head * k), 149 | Shape, Int>{}, 150 | make_stride(head * k, Int<1>{})); 151 | auto mma_thread_key_global_tensor = thr_mma_tn.partition_B(global_key); 152 | cute::copy(mma_thread_key_global_tensor, mma_thread_key_register_tensor); 153 | cute::gemm(tiled_mma_tn, mma_thread_acc_register_tensor, 154 | mma_thread_query_register_tensor, 155 | mma_thread_key_register_tensor, 156 | mma_thread_acc_register_tensor); 157 | cute::copy(mma_thread_acc_register_tensor, mma_thread_acc_shared_tensor); 158 | __syncthreads(); 159 | if (idx < 32) { 160 | ComputeType thread_max = shared_acc(idx, 0); 161 | #pragma unroll 162 | for (int j = 1; j < kTileN; j++) { 163 | ComputeType tmp = shared_acc(idx, j); 164 | thread_max = thread_max < tmp ? 
tmp : thread_max; 165 | } 166 | thread_max *= scale; 167 | ComputeType old_thread_max = shared_max(idx, 0); 168 | ComputeType new_thread_max = old_thread_max < thread_max ? thread_max : old_thread_max; 169 | ComputeType block_thread_sum = 0; 170 | #pragma unroll 171 | for (int j = 0; j < kTileN; j++) { 172 | ComputeType thread_acc = shared_acc(idx, j); 173 | ComputeType new_thread_acc = __expf(thread_acc * scale - new_thread_max); 174 | block_thread_sum += new_thread_acc; 175 | shared_acc(idx, j) = new_thread_acc; 176 | } 177 | shared_sum(idx, 0) = __expf(old_thread_max - new_thread_max) * shared_sum(idx, 0) + block_thread_sum; 178 | shared_max(idx, 0) = new_thread_max; 179 | // shared_old_max(idx, 0) = __expf(old_thread_max - new_thread_max); 180 | shared_old_max(idx, 0) = __expf(old_thread_max - new_thread_max); 181 | } 182 | __syncthreads(); 183 | #pragma unroll 184 | for (int i = 0; i < kTileN; i++) { 185 | shared_o_scale(i, idx) = shared_old_max(i, 0); 186 | } 187 | cute::copy(mma_thread_acc_shared_tensor_a_matrix, mma_thread_acc_register_tensor_a_matrix); 188 | __syncthreads(); 189 | Tensor global_value = make_tensor(make_gmem_ptr((ElementType *)value_ptr + b * head * n * k + h * k + n_tile * kTileN * head * k), 190 | Shape, Int>{}, 191 | make_stride(head * k, Int<1>{})); 192 | 193 | Global2SharedCopyValue global_2_shared_value_copy_tile; 194 | auto global_2_shared_value_copy_thread = global_2_shared_value_copy_tile.get_slice(idx); 195 | auto global_value_thread_copy_tensor = global_2_shared_value_copy_thread.partition_S(global_value); // (CPY, CPY_N, CPY_K) 196 | auto shared_value_thread_copy_tensor = global_2_shared_value_copy_thread.partition_D(shared_value); // (CPY, CPY_N , CPY_K) 197 | shared_ptr = reinterpret_cast(shm_data); 198 | global_ptr = reinterpret_cast((ElementType *)value_ptr + b * head * n * k + h * k + n_tile * kTileN * head * k); 199 | for (int i = 0; i < 4; i++) { 200 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 201 | } 202 | __syncthreads(); 203 | cute::copy(mma_thread_value_shared_tensor, mma_thread_value_register_tensor); 204 | #pragma unroll 205 | for (int i = 0; i < size<0>(mma_thread_output_register_tensor); i++) { 206 | #pragma unroll 207 | for (int j = 0; j < size<1>(mma_thread_output_register_tensor); j++) { 208 | #pragma unroll 209 | for (int kk = 0; kk < size<2>(mma_thread_output_register_tensor); kk++) { 210 | mma_thread_output_register_tensor(i, j, kk) *= mma_thread_o_scale_shared_tensor(i, j, kk); 211 | } 212 | } 213 | } 214 | cute::gemm(thr_mma_tn, mma_thread_output_register_tensor, 215 | mma_thread_acc_register_tensor_a_matrix, 216 | mma_thread_value_register_tensor, 217 | mma_thread_output_register_tensor); 218 | } 219 | 220 | clear(mma_thread_acc_register_tensor); 221 | global_key = make_tensor(make_gmem_ptr((ElementType *)key_ptr + b * head * n * k + h * k + end_tile * kTileN * head * k), 222 | Shape, Int>{}, 223 | make_stride(head * k, Int<1>{})); 224 | auto mma_thread_key_global_tensor = thr_mma_tn.partition_B(global_key); 225 | cute::copy(mma_thread_key_global_tensor, mma_thread_key_register_tensor); 226 | cute::gemm(tiled_mma_tn, mma_thread_acc_register_tensor, 227 | mma_thread_query_register_tensor, 228 | mma_thread_key_register_tensor, 229 | mma_thread_acc_register_tensor); 230 | cute::copy(mma_thread_acc_register_tensor, mma_thread_acc_shared_tensor); 231 | __syncthreads(); 232 | #pragma unroll 233 | for (int i = 0; i < kTileN; i+=4) { 234 | shared_acc(idx / kTileN + i, idx % 
kTileN) = (idx / kTileN + i >= idx % kTileN ? shared_acc(idx / kTileN + i, idx % kTileN) : -INFINITY); 235 | } 236 | __syncthreads(); 237 | if (idx < 32) { 238 | ComputeType thread_max = shared_acc(idx, 0); 239 | #pragma unroll 240 | for (int j = 1; j < kTileN; j++) { 241 | ComputeType tmp = shared_acc(idx, j); 242 | thread_max = thread_max < tmp ? tmp : thread_max; 243 | } 244 | thread_max *= scale; 245 | ComputeType old_thread_max = shared_max(idx, 0); 246 | ComputeType new_thread_max = old_thread_max < thread_max ? thread_max : old_thread_max; 247 | ComputeType block_thread_sum = 0; 248 | #pragma unroll 249 | for (int j = 0; j < kTileN; j++) { 250 | ComputeType thread_acc = shared_acc(idx, j); 251 | ComputeType new_thread_acc = __expf(thread_acc * scale - new_thread_max); 252 | block_thread_sum += new_thread_acc; 253 | shared_acc(idx, j) = new_thread_acc; 254 | } 255 | shared_sum(idx, 0) = __expf(old_thread_max - new_thread_max) * shared_sum(idx, 0) + block_thread_sum; 256 | shared_max(idx, 0) = new_thread_max; 257 | shared_old_max(idx, 0) = __expf(old_thread_max - new_thread_max); 258 | } 259 | __syncthreads(); 260 | #pragma unroll 261 | for (int i = 0; i < kTileN; i++) { 262 | shared_o_scale(i, idx) = shared_old_max(i, 0); 263 | } 264 | 265 | cute::copy(mma_thread_acc_shared_tensor_a_matrix, mma_thread_acc_register_tensor_a_matrix); 266 | __syncthreads(); 267 | Tensor global_value = make_tensor(make_gmem_ptr((ElementType *)value_ptr + b * head * n * k + h * k + end_tile * kTileN * head * k), 268 | Shape, Int>{}, 269 | make_stride(head * k, Int<1>{})); 270 | 271 | Global2SharedCopyValue global_2_shared_value_copy_tile; 272 | auto global_2_shared_value_copy_thread = global_2_shared_value_copy_tile.get_slice(idx); 273 | auto global_value_thread_copy_tensor = global_2_shared_value_copy_thread.partition_S(global_value); // (CPY, CPY_N, CPY_K) 274 | auto shared_value_thread_copy_tensor = global_2_shared_value_copy_thread.partition_D(shared_value); // (CPY, CPY_N , CPY_K) 275 | shared_ptr = reinterpret_cast(shm_data); 276 | global_ptr = reinterpret_cast((ElementType *)value_ptr + b * head * n * k + h * k + end_tile * kTileN * head * k); 277 | for (int i = 0; i < 4; i++) { 278 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 279 | } 280 | __syncthreads(); 281 | cute::copy(mma_thread_value_shared_tensor, mma_thread_value_register_tensor); 282 | __syncthreads(); 283 | #pragma unroll 284 | for (int i = 0; i < size<0>(mma_thread_output_register_tensor); i++) { 285 | #pragma unroll 286 | for (int j = 0; j < size<1>(mma_thread_output_register_tensor); j++) { 287 | #pragma unroll 288 | for (int kk = 0; kk < size<2>(mma_thread_output_register_tensor); kk++) { 289 | mma_thread_output_register_tensor(i, j, kk) *= mma_thread_o_scale_shared_tensor(i, j, kk); 290 | } 291 | } 292 | } 293 | __syncthreads(); 294 | cute::gemm(thr_mma_tn, mma_thread_output_register_tensor, 295 | mma_thread_acc_register_tensor_a_matrix, 296 | mma_thread_value_register_tensor, 297 | mma_thread_output_register_tensor); 298 | if (idx < 32) { 299 | global_max(idx, 0) = shared_max(idx, 0); 300 | } 301 | else if (idx < 64) { 302 | global_sum(idx - 32, 0) = shared_sum(idx - 32, 0); 303 | } 304 | #pragma unroll 305 | for (int i = 0; i < kTileN; i++) { 306 | shared_sum_scale(i, idx) = 1.f / shared_sum(i, 0); 307 | } 308 | __syncthreads(); 309 | output_shm = (ElementType*)(sum_scale_shm + cute::cosize(SmemLayoutOutput{})); 310 | shared_output = 
make_tensor(make_smem_ptr(output_shm), SmemLayoutOutput{}); 311 | mma_thread_output_shared_tensor = thr_mma_tn.partition_C(shared_output); 312 | cute::copy(mma_thread_output_register_tensor, mma_thread_output_shared_tensor); 313 | __syncthreads(); 314 | for (int i = 0; i < kTileM; i++) { 315 | shared_output(i, idx) = shared_output(i, idx) * shared_sum_scale(i, idx); 316 | } 317 | __syncthreads(); 318 | shared_ptr = reinterpret_cast(output_shm); 319 | global_ptr = reinterpret_cast((ElementType *)output_ptr + b * head * m * k + h * k + row_m * kTileM * head * k); 320 | for (int i = 0; i < 4; i++) { 321 | global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16] = shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16]; 322 | } 323 | } 324 | 325 | template 326 | __global__ void 327 | fused_mha_backward_kernel(const void *query_ptr, const void *key_ptr, const void *value_ptr, 328 | void *output_ptr, void *d_output_ptr, void *d_ptr, void *max_ptr, void *sum_ptr, 329 | void *d_query_ptr, void *d_key_ptr, void *d_value_ptr, int head, int m, int n, int k, float scale, bool causal) 330 | { 331 | using namespace cute; 332 | using X = Underscore; 333 | using ElementType = typename Config::ElementType; 334 | using ComputeType = typename Config::ComputeType; 335 | 336 | using SmemLayoutQuery = typename Config::SmemLayoutQuery; 337 | using SmemLayoutQueryTransposed = typename Config::SmemLayoutQueryTransposed; 338 | using SmemLayoutKey = typename Config::SmemLayoutKey; 339 | using SmemLayoutKeyTransposed = typename Config::SmemLayoutKeyTransposed; 340 | using SmemLayoutValue = typename Config::SmemLayoutValue; 341 | using SmemLayoutValueTransposed = typename Config::SmemLayoutValueTransposed; 342 | using SmemLayoutAcc = typename Config::SmemLayoutAcc; 343 | using SmemLayoutAccTransposed = typename Config::SmemLayoutAccTransposed; 344 | using SmemLayoutOutput = typename Config::SmemLayoutOutput; 345 | using SmemLayoutDOutput = typename Config::SmemLayoutDOutput; 346 | using SmemLayoutDOutputTransposed = typename Config::SmemLayoutDOutputTransposed; 347 | using TiledMMA_TN = typename Config::MMA_TN; 348 | 349 | using Global2SharedCopyQuery = typename Config::Global2SharedCopyQuery; 350 | using Global2SharedCopyKey = typename Config::Global2SharedCopyKey; 351 | using Global2SharedCopyValue = typename Config::Global2SharedCopyValue; 352 | using Global2SharedCopyOutput = typename Config::Global2SharedCopyOutput; 353 | using Global2SharedCopyDOutput = typename Config::Global2SharedCopyDOutput; 354 | using Shared2RegisterCopyAcc = typename Config::Shared2RegisterCopyAcc; 355 | using Shared2RegisterCopyFp16Acc = typename Config::Shared2RegisterCopyFp16Acc; 356 | using Register2SharedCopyFp16Acc = typename Config::Register2SharedCopyFp16Acc; 357 | using Global2SharedCopyAtom = typename Config::Global2SharedCopyAtom; 358 | using SmemLayoutMax = typename Config::SmemLayoutMax; 359 | using SmemLayoutSum = typename Config::SmemLayoutSum; 360 | 361 | constexpr int kTileM = Config::kTileM; 362 | constexpr int kTileN = Config::kTileN; 363 | constexpr int kTileK = Config::kTileK; 364 | 365 | extern __shared__ ElementType shm_data[]; 366 | 367 | ElementType *qkvodo_shm = shm_data; 368 | ComputeType *max_shm = (ComputeType*)(qkvodo_shm + 2 * cosize(SmemLayoutQuery{})); 369 | auto shared_max = make_tensor(make_smem_ptr(max_shm), SmemLayoutMax{}); 370 | 371 | ComputeType *sum_shm = (ComputeType*)(max_shm + cosize(SmemLayoutMax{})); 372 | auto shared_sum = make_tensor(make_smem_ptr(sum_shm), SmemLayoutSum{}); 373 | 374 | int idx = 
threadIdx.x; 375 | int b = blockIdx.y / head; 376 | int h = blockIdx.y % head; 377 | int col_n = blockIdx.x; 378 | TiledMMA_TN tiled_mma_tn; 379 | auto thr_mma_tn = tiled_mma_tn.get_slice(idx); 380 | 381 | Tensor global_d_key = make_tensor(make_gmem_ptr((ElementType *)d_key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k), 382 | Shape, Int>{}, 383 | make_stride(head * k, Int<1>{})); 384 | auto mma_thread_d_key_register_tensor = thr_mma_tn.partition_fragment_C(global_d_key); 385 | auto mma_thread_d_key_global_tensor = thr_mma_tn.partition_C(global_d_key); 386 | clear(mma_thread_d_key_register_tensor); 387 | Tensor global_d_value = make_tensor(make_gmem_ptr((ElementType *)d_value_ptr + b * head * n * k + h * k + col_n * kTileN * head * k), 388 | Shape, Int>{}, 389 | make_stride(head * k, Int<1>{})); 390 | auto mma_thread_d_value_global_tensor = thr_mma_tn.partition_C(global_d_value); 391 | auto mma_thread_d_value_register_tensor = thr_mma_tn.partition_fragment_C(global_d_value); 392 | clear(mma_thread_d_value_register_tensor); 393 | int begin_m_tile = 0; 394 | ElementType* value_shm = shm_data; 395 | auto shared_value = make_tensor(make_smem_ptr(value_shm), SmemLayoutValue{}); 396 | float4* shared_ptr1 = reinterpret_cast(value_shm); 397 | float4* global_ptr1 = reinterpret_cast((ElementType *)value_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 398 | #pragma unroll 399 | for (int i = 0; i < 4; i++) { 400 | shared_ptr1[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr1[(idx / 16 + i * 8) * head * 16 + idx % 16]; 401 | } 402 | __syncthreads(); 403 | auto mma_thread_value_register_tensor = thr_mma_tn.partition_fragment_B(shared_value); // (MMA, MMA_N, MMA_K) 404 | auto mma_thread_value_shared_tensor = thr_mma_tn.partition_B(shared_value); 405 | cute::copy(mma_thread_value_shared_tensor, mma_thread_value_register_tensor); 406 | __syncthreads(); 407 | if (causal) { 408 | begin_m_tile = col_n; 409 | } 410 | { 411 | ElementType* key_shm = shm_data; 412 | Tensor global_key = make_tensor(make_gmem_ptr((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k), 413 | Shape, Int>{}, 414 | make_stride(head * k, Int<1>{})); 415 | float4* shared_ptr = reinterpret_cast(key_shm); 416 | float4* global_ptr = reinterpret_cast((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 417 | auto mma_thread_key_register_tensor = thr_mma_tn.partition_fragment_B(global_key); // (MMA, MMA_N, MMA_K) 418 | ElementType* query_shm = key_shm + cosize(SmemLayoutKey{}); 419 | Tensor global_query = make_tensor(make_gmem_ptr((ElementType *)query_ptr + b * head * m * k + h * k + begin_m_tile * kTileM * head * k), 420 | Shape, Int>{}, 421 | make_stride(head * k, Int<1>{})); 422 | auto mma_thread_query_register_tensor = thr_mma_tn.partition_fragment_A(global_query); // (MMA, MMA_M, MMA_K) 423 | ComputeType *acc_shm = (ComputeType*)(sum_shm + cosize(SmemLayoutSum{})); 424 | auto shared_acc = make_tensor(make_smem_ptr(acc_shm), SmemLayoutAcc{}); 425 | auto mma_thread_acc_shared_tensor = thr_mma_tn.partition_C(shared_acc); 426 | auto mma_thread_acc_register_tensor = thr_mma_tn.partition_fragment_C(shared_acc); 427 | clear(mma_thread_acc_register_tensor); 428 | auto mma_thread_key_global_tensor = thr_mma_tn.partition_B(global_key); 429 | auto mma_thread_query_global_tensor = thr_mma_tn.partition_A(global_query); 430 | cute::copy(mma_thread_key_global_tensor, mma_thread_key_register_tensor); 431 | cute::copy(mma_thread_query_global_tensor, 
mma_thread_query_register_tensor); 432 | cute::gemm(tiled_mma_tn, mma_thread_acc_register_tensor, 433 | mma_thread_query_register_tensor, 434 | mma_thread_key_register_tensor, 435 | mma_thread_acc_register_tensor); 436 | cute::copy(mma_thread_acc_register_tensor, mma_thread_acc_shared_tensor); 437 | __syncthreads(); 438 | if (causal) { 439 | if (idx < 32) { 440 | #pragma unroll 441 | for (int j = 0; j < kTileN; j++) { 442 | shared_acc(idx, j) = (idx >= j ? shared_acc(idx, j) : -INFINITY); 443 | } 444 | } 445 | __syncthreads(); 446 | } 447 | Tensor global_max = make_tensor(make_gmem_ptr((ComputeType *)max_ptr + b * head * m + h * m + begin_m_tile * kTileM), 448 | Shape, Int<1>>{}, 449 | make_stride(1, Int<1>{})); 450 | Tensor global_sum = make_tensor(make_gmem_ptr((ComputeType *)sum_ptr + b * head * m + h * m + begin_m_tile * kTileM), 451 | Shape, Int<1>>{}, 452 | make_stride(1, Int<1>{})); 453 | #pragma unroll 454 | for (int i = 0; i < kTileM; i += 4) { 455 | auto thread_max = global_max(idx / kTileN + i, 0); 456 | auto thread_sum = 1.f / global_sum(idx / kTileN + i, 0); 457 | shared_acc(idx / kTileN + i, idx % kTileN) = __expf(scale * shared_acc(idx / kTileN + i, idx % kTileN) - thread_max) * thread_sum; 458 | } 459 | __syncthreads(); 460 | 461 | ElementType *fp16_acc_shm = (ElementType*)(acc_shm + cosize(SmemLayoutAcc{})); 462 | ElementType* d_output_shm = key_shm; 463 | auto shared_d_output = make_tensor(make_smem_ptr(d_output_shm), SmemLayoutDOutput{}); 464 | auto shared_d_output_transposed = make_tensor(make_smem_ptr(d_output_shm), SmemLayoutDOutputTransposed{}); 465 | shared_ptr = reinterpret_cast(d_output_shm); 466 | global_ptr = reinterpret_cast((ElementType *)d_output_ptr + b * head * m * k + h * k + begin_m_tile * kTileM * head * k); 467 | #pragma unroll 468 | for (int i = 0; i < 4; i++) { 469 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 470 | } 471 | __syncthreads(); 472 | 473 | auto mma_thread_d_output_shared_tensor_b_matrix = thr_mma_tn.partition_B(shared_d_output_transposed); 474 | auto mma_thread_d_output_register_tensor_b_matrix = thr_mma_tn.partition_fragment_B(shared_d_output_transposed); // (MMA, MMA_M, MMA_K) 475 | cute::copy(mma_thread_d_output_shared_tensor_b_matrix, mma_thread_d_output_register_tensor_b_matrix); 476 | __syncthreads(); 477 | auto shared_acc_transposed = make_tensor(make_smem_ptr(acc_shm), SmemLayoutAccTransposed{}); 478 | auto mma_thread_acc_transposed_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_acc_transposed); 479 | auto mma_thread_acc_transposed_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_acc_transposed); 480 | cute::copy(mma_thread_acc_transposed_shared_tensor_a_matrix, mma_thread_acc_transposed_register_tensor_a_matrix); 481 | cute::gemm(thr_mma_tn, mma_thread_d_value_register_tensor, 482 | mma_thread_acc_transposed_register_tensor_a_matrix, 483 | mma_thread_d_output_register_tensor_b_matrix, 484 | mma_thread_d_value_register_tensor); 485 | auto mma_thread_d_output_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_d_output); // (MMA, MMA_M, MMA_K) 486 | auto mma_thread_d_output_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_d_output); 487 | cute::copy(mma_thread_d_output_shared_tensor_a_matrix, mma_thread_d_output_register_tensor_a_matrix); 488 | __syncthreads(); 489 | ComputeType *d_p_shm = (ComputeType*)(fp16_acc_shm); 490 | auto shared_d_p = make_tensor(make_smem_ptr(d_p_shm), SmemLayoutAcc{}); 491 | auto 
mma_thread_d_p_register_tensor = thr_mma_tn.partition_fragment_C(shared_d_p); 492 | auto mma_thread_d_p_shared_tensor = thr_mma_tn.partition_C(shared_d_p); 493 | clear(mma_thread_d_p_register_tensor); 494 | cute::gemm(thr_mma_tn, mma_thread_d_p_register_tensor, 495 | mma_thread_d_output_register_tensor_a_matrix, 496 | mma_thread_value_register_tensor, 497 | mma_thread_d_p_register_tensor); 498 | cute::copy(mma_thread_d_p_register_tensor, mma_thread_d_p_shared_tensor); 499 | __syncthreads(); 500 | Tensor global_d = make_tensor(make_gmem_ptr((ElementType *)d_ptr + b * head * m + h + begin_m_tile * kTileM * head), 501 | Shape, Int<1>>{}, 502 | make_stride(head, Int<1>{})); 503 | #pragma unroll 504 | for (int i = 0; i < kTileM; i+=4) { 505 | shared_acc(i + idx / 32, idx % 32) *= (shared_d_p(i + idx / 32, idx % 32) - global_d(i + idx / 32, 0)) * scale; 506 | } 507 | 508 | __syncthreads(); 509 | auto mma_thread_d_s_fp32_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_acc); 510 | auto mma_thread_d_s_fp16_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_acc); 511 | cute::copy(mma_thread_d_s_fp32_shared_tensor_a_matrix, mma_thread_d_s_fp16_register_tensor_a_matrix); 512 | Tensor global_d_query = make_tensor(make_gmem_ptr((ElementType *)d_query_ptr + b * head * m * k + h * k + begin_m_tile * kTileM * head * k), 513 | Shape, Int>{}, 514 | make_stride(head * k, Int<1>{})); 515 | auto mma_thread_d_query_register_tensor = thr_mma_tn.partition_fragment_C(global_d_query); 516 | auto mma_thread_d_query_global_tensor = thr_mma_tn.partition_C(global_d_query); 517 | clear(mma_thread_d_query_register_tensor); 518 | auto shared_key_transposed = make_tensor(make_smem_ptr(key_shm), SmemLayoutKeyTransposed{}); 519 | shared_ptr = reinterpret_cast(key_shm); 520 | global_ptr = reinterpret_cast((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 521 | #pragma unroll 522 | for (int i = 0; i < 4; i++) { 523 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 524 | } 525 | __syncthreads(); 526 | auto mma_thread_transposed_key_register_tensor = thr_mma_tn.partition_fragment_B(shared_key_transposed); // (MMA, MMA_N, MMA_K) 527 | auto mma_thread_transposed_key_shared_tensor = thr_mma_tn.partition_B(shared_key_transposed); 528 | cute::copy(mma_thread_transposed_key_shared_tensor, mma_thread_transposed_key_register_tensor); 529 | __syncthreads(); 530 | cute::gemm(thr_mma_tn, mma_thread_d_query_register_tensor, 531 | mma_thread_d_s_fp16_register_tensor_a_matrix, 532 | mma_thread_transposed_key_register_tensor, 533 | mma_thread_d_query_register_tensor); 534 | auto shared_d_query = make_tensor(make_smem_ptr((ComputeType*)shm_data), SmemLayoutQuery{}); // (kTileM, kTileK) 535 | auto mma_thread_d_query_shared_tensor = thr_mma_tn.partition_C(shared_d_query); 536 | cute::copy(mma_thread_d_query_register_tensor, mma_thread_d_query_shared_tensor); 537 | __syncthreads(); 538 | cute::half_t d_query_register_half2[2]; 539 | #pragma unroll 540 | for (int i = 0; i < kTileM; i+=2) { 541 | d_query_register_half2[0] = (cute::half_t)shared_d_query(i + idx / 64, (idx % 64) * 2); 542 | d_query_register_half2[1] = (cute::half_t)shared_d_query(i + idx / 64, (idx % 64) * 2 + 1); 543 | atomicAdd(reinterpret_cast<__half2*>(&global_d_query(i + idx / 64, (idx % 64) * 2)), *(reinterpret_cast<__half2*>(d_query_register_half2))); 544 | } 545 | cute::copy(mma_thread_acc_transposed_shared_tensor_a_matrix, 
mma_thread_acc_transposed_register_tensor_a_matrix); 546 | __syncthreads(); 547 | auto shared_query_transposed = make_tensor(make_smem_ptr(query_shm), SmemLayoutQueryTransposed{}); 548 | shared_ptr = reinterpret_cast(query_shm); 549 | global_ptr = reinterpret_cast((ElementType *)query_ptr + b * head * m * k + h * k + begin_m_tile * kTileM * head * k); 550 | #pragma unroll 551 | for (int i = 0; i < 4; i++) { 552 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 553 | } 554 | __syncthreads(); 555 | auto mma_thread_query_register_tensor_b_matrix = thr_mma_tn.partition_fragment_B(shared_query_transposed); // (MMA, MMA_M, MMA_K) 556 | auto mma_thread_query_shared_tensor_b_matrix = thr_mma_tn.partition_B(shared_query_transposed); // (MMA, MMA_M, MMA_K) 557 | cute::copy(mma_thread_query_shared_tensor_b_matrix, mma_thread_query_register_tensor_b_matrix); 558 | cute::gemm(thr_mma_tn, mma_thread_d_key_register_tensor, 559 | mma_thread_acc_transposed_register_tensor_a_matrix, 560 | mma_thread_query_register_tensor_b_matrix, 561 | mma_thread_d_key_register_tensor); 562 | } 563 | #pragma unroll 564 | for (int m_tile = begin_m_tile + 1; m_tile < m / kTileM; m_tile++) { 565 | ElementType* key_shm = shm_data; 566 | Tensor global_key = make_tensor(make_gmem_ptr((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k), 567 | Shape, Int>{}, 568 | make_stride(head * k, Int<1>{})); 569 | float4* shared_ptr = reinterpret_cast(key_shm); 570 | float4* global_ptr = reinterpret_cast((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 571 | ElementType* query_shm = key_shm + cosize(SmemLayoutKey{}); 572 | Tensor global_query = make_tensor(make_gmem_ptr((ElementType *)query_ptr + b * head * m * k + h * k + m_tile * kTileM * head * k), 573 | Shape, Int>{}, 574 | make_stride(head * k, Int<1>{})); 575 | ComputeType *acc_shm = (ComputeType*)(sum_shm + cute::cosize(SmemLayoutSum{})); 576 | auto shared_acc = make_tensor(make_smem_ptr(acc_shm), SmemLayoutAcc{}); 577 | auto mma_thread_acc_shared_tensor = thr_mma_tn.partition_C(shared_acc); 578 | auto mma_thread_acc_register_tensor = thr_mma_tn.partition_fragment_C(shared_acc); 579 | clear(mma_thread_acc_register_tensor); 580 | auto mma_thread_key_global_tensor = thr_mma_tn.partition_B(global_key); 581 | auto mma_thread_query_global_tensor = thr_mma_tn.partition_A(global_query); 582 | auto mma_thread_key_register_tensor = thr_mma_tn.partition_fragment_B(global_key); 583 | auto mma_thread_query_register_tensor = thr_mma_tn.partition_fragment_A(global_query); 584 | cute::copy(mma_thread_query_global_tensor, mma_thread_query_register_tensor); 585 | cute::copy(mma_thread_key_global_tensor, mma_thread_key_register_tensor); 586 | cute::gemm(thr_mma_tn, mma_thread_acc_register_tensor, 587 | mma_thread_query_register_tensor, 588 | mma_thread_key_register_tensor, 589 | mma_thread_acc_register_tensor); 590 | cute::copy(mma_thread_acc_register_tensor, mma_thread_acc_shared_tensor); 591 | __syncthreads(); 592 | Tensor global_max = make_tensor(make_gmem_ptr((ComputeType *)max_ptr + b * head * m + h * m + m_tile * kTileM), 593 | Shape, Int<1>>{}, 594 | make_stride(1, Int<1>{})); 595 | Tensor global_sum = make_tensor(make_gmem_ptr((ComputeType *)sum_ptr + b * head * m + h * m + m_tile * kTileM), 596 | Shape, Int<1>>{}, 597 | make_stride(1, Int<1>{})); 598 | #pragma unroll 599 | for (int i = 0; i < kTileM; i += 4) { 600 | auto thread_max = global_max(idx / kTileN + i, 0); 601 | auto 
thread_sum = 1.f / global_sum(idx / kTileN + i, 0); 602 | shared_acc(idx / kTileN + i, idx % kTileN) = __expf(scale * shared_acc(idx / kTileN + i, idx % kTileN) - thread_max) * thread_sum; 603 | } 604 | __syncthreads(); 605 | 606 | ElementType *fp16_acc_shm = (ElementType*)(acc_shm + cosize(SmemLayoutAcc{})); 607 | ElementType* d_output_shm = key_shm; 608 | auto shared_d_output = make_tensor(make_smem_ptr(d_output_shm), SmemLayoutDOutput{}); 609 | auto shared_d_output_transposed = make_tensor(make_smem_ptr(d_output_shm), SmemLayoutDOutputTransposed{}); 610 | shared_ptr = reinterpret_cast(d_output_shm); 611 | global_ptr = reinterpret_cast((ElementType *)d_output_ptr + b * head * m * k + h * k + m_tile * kTileM * head * k); 612 | #pragma unroll 613 | for (int i = 0; i < 4; i++) { 614 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 615 | } 616 | __syncthreads(); 617 | 618 | auto mma_thread_d_output_shared_tensor_b_matrix = thr_mma_tn.partition_B(shared_d_output_transposed); 619 | auto mma_thread_d_output_register_tensor_b_matrix = thr_mma_tn.partition_fragment_B(shared_d_output_transposed); // (MMA, MMA_M, MMA_K) 620 | cute::copy(mma_thread_d_output_shared_tensor_b_matrix, mma_thread_d_output_register_tensor_b_matrix); 621 | __syncthreads(); 622 | auto shared_acc_transposed = make_tensor(make_smem_ptr(acc_shm), SmemLayoutAccTransposed{}); 623 | auto mma_thread_acc_transposed_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_acc_transposed); 624 | auto mma_thread_acc_transposed_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_acc_transposed); 625 | cute::copy(mma_thread_acc_transposed_shared_tensor_a_matrix, mma_thread_acc_transposed_register_tensor_a_matrix); 626 | cute::gemm(thr_mma_tn, mma_thread_d_value_register_tensor, 627 | mma_thread_acc_transposed_register_tensor_a_matrix, 628 | mma_thread_d_output_register_tensor_b_matrix, 629 | mma_thread_d_value_register_tensor); 630 | 631 | auto mma_thread_d_output_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_d_output); // (MMA, MMA_M, MMA_K) 632 | auto mma_thread_d_output_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_d_output); 633 | cute::copy(mma_thread_d_output_shared_tensor_a_matrix, mma_thread_d_output_register_tensor_a_matrix); 634 | __syncthreads(); 635 | 636 | ComputeType *d_p_shm = (ComputeType*)(fp16_acc_shm); 637 | auto shared_d_p = make_tensor(make_smem_ptr(d_p_shm), SmemLayoutAcc{}); 638 | auto mma_thread_d_p_register_tensor = thr_mma_tn.partition_fragment_C(shared_d_p); 639 | auto mma_thread_d_p_shared_tensor = thr_mma_tn.partition_C(shared_d_p); 640 | clear(mma_thread_d_p_register_tensor); 641 | cute::gemm(thr_mma_tn, mma_thread_d_p_register_tensor, 642 | mma_thread_d_output_register_tensor_a_matrix, 643 | mma_thread_value_register_tensor, 644 | mma_thread_d_p_register_tensor); 645 | 646 | cute::copy(mma_thread_d_p_register_tensor, mma_thread_d_p_shared_tensor); 647 | __syncthreads(); 648 | 649 | Tensor global_d = make_tensor(make_gmem_ptr((ElementType *)d_ptr + b * head * m + h + m_tile * kTileM * head), 650 | Shape, Int<1>>{}, 651 | make_stride(head, Int<1>{})); 652 | 653 | #pragma unroll 654 | for (int i = 0; i < kTileM; i+=4) { 655 | shared_acc(i + idx / 32, idx % 32) *= (shared_d_p(i + idx / 32, idx % 32) - global_d(i + idx / 32, 0)) * scale; 656 | } 657 | 658 | __syncthreads(); 659 | auto mma_thread_d_s_fp32_shared_tensor_a_matrix = thr_mma_tn.partition_A(shared_acc); 660 | auto 
mma_thread_d_s_fp16_register_tensor_a_matrix = thr_mma_tn.partition_fragment_A(shared_acc); 661 | cute::copy(mma_thread_d_s_fp32_shared_tensor_a_matrix, mma_thread_d_s_fp16_register_tensor_a_matrix); 662 | Tensor global_d_query = make_tensor(make_gmem_ptr((ElementType *)d_query_ptr + b * head * m * k + h * k + m_tile * kTileM * head * k), 663 | Shape, Int>{}, 664 | make_stride(head * k, Int<1>{})); 665 | auto mma_thread_d_query_register_tensor = thr_mma_tn.partition_fragment_C(global_d_query); 666 | auto mma_thread_d_query_global_tensor = thr_mma_tn.partition_C(global_d_query); 667 | clear(mma_thread_d_query_register_tensor); 668 | auto shared_key_transposed = make_tensor(make_smem_ptr(key_shm), SmemLayoutKeyTransposed{}); 669 | shared_ptr = reinterpret_cast(key_shm); 670 | global_ptr = reinterpret_cast((ElementType *)key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 671 | #pragma unroll 672 | for (int i = 0; i < 4; i++) { 673 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 674 | } 675 | __syncthreads(); 676 | auto mma_thread_transposed_key_register_tensor = thr_mma_tn.partition_fragment_B(shared_key_transposed); // (MMA, MMA_N, MMA_K) 677 | auto mma_thread_transposed_key_shared_tensor = thr_mma_tn.partition_B(shared_key_transposed); 678 | cute::copy(mma_thread_transposed_key_shared_tensor, mma_thread_transposed_key_register_tensor); 679 | __syncthreads(); 680 | cute::gemm(thr_mma_tn, mma_thread_d_query_register_tensor, 681 | mma_thread_d_s_fp16_register_tensor_a_matrix, 682 | mma_thread_transposed_key_register_tensor, 683 | mma_thread_d_query_register_tensor); 684 | cute::half_t d_query_register_half2[2]; 685 | #pragma unroll 686 | for (int i = 0; i < size<1>(mma_thread_d_query_register_tensor); i++) { 687 | #pragma unroll 688 | for (int j = 0; j < size<2>(mma_thread_d_query_register_tensor); j++) { 689 | #pragma unroll 690 | for (int kk = 0; kk < size<0>(mma_thread_d_query_register_tensor); kk+=2) { 691 | d_query_register_half2[0] = (cute::half_t)mma_thread_d_query_register_tensor(kk, i, j); 692 | d_query_register_half2[1] = (cute::half_t)mma_thread_d_query_register_tensor(kk + 1, i, j); 693 | atomicAdd(reinterpret_cast<__half2*>(&mma_thread_d_query_global_tensor(kk, i, j)), *(reinterpret_cast<__half2*>(d_query_register_half2))); 694 | } 695 | } 696 | } 697 | cute::copy(mma_thread_acc_transposed_shared_tensor_a_matrix, mma_thread_acc_transposed_register_tensor_a_matrix); 698 | auto shared_query_transposed = make_tensor(make_smem_ptr(query_shm), SmemLayoutQueryTransposed{}); 699 | shared_ptr = reinterpret_cast(query_shm); 700 | global_ptr = reinterpret_cast((ElementType *)query_ptr + b * head * m * k + h * k + m_tile * kTileM * head * k); 701 | #pragma unroll 702 | for (int i = 0; i < 4; i++) { 703 | shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16] = global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16]; 704 | } 705 | __syncthreads(); 706 | auto mma_thread_query_register_tensor_b_matrix = thr_mma_tn.partition_fragment_B(shared_query_transposed); // (MMA, MMA_M, MMA_K) 707 | auto mma_thread_query_shared_tensor_b_matrix = thr_mma_tn.partition_B(shared_query_transposed); // (MMA, MMA_M, MMA_K) 708 | cute::copy(mma_thread_query_shared_tensor_b_matrix, mma_thread_query_register_tensor_b_matrix); 709 | cute::gemm(thr_mma_tn, mma_thread_d_key_register_tensor, 710 | mma_thread_acc_transposed_register_tensor_a_matrix, 711 | mma_thread_query_register_tensor_b_matrix, 712 | mma_thread_d_key_register_tensor); 713 | } 714 | 
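// Epilogue (descriptive comment, added): after the loop over all query (M) tiles, stage the accumulated dK and dV fragments in shared memory, then write them back to d_key_ptr / d_value_ptr with vectorized float4 stores.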
ElementType *d_key_shm = (ElementType*)shm_data; 715 | auto shared_d_key = make_tensor(make_smem_ptr(d_key_shm), SmemLayoutKey{}); 716 | auto mma_thread_d_key_shared_tensor = thr_mma_tn.partition_C(shared_d_key); 717 | cute::copy(mma_thread_d_key_register_tensor, mma_thread_d_key_shared_tensor); 718 | __syncthreads(); 719 | float4* shared_ptr = reinterpret_cast(d_key_shm); 720 | float4* global_ptr = reinterpret_cast((ElementType *)d_key_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 721 | for (int i = 0; i < 4; i++) { 722 | global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16] = shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16]; 723 | } 724 | __syncthreads(); 725 | ElementType *d_value_shm = (ElementType*)shm_data; 726 | auto shared_d_value = make_tensor(make_smem_ptr(d_value_shm), SmemLayoutValue{}); 727 | auto mma_thread_d_value_shared_tensor = thr_mma_tn.partition_C(shared_d_value); 728 | cute::copy(mma_thread_d_value_register_tensor, mma_thread_d_value_shared_tensor); 729 | __syncthreads(); 730 | shared_ptr = reinterpret_cast(d_value_shm); 731 | global_ptr = reinterpret_cast((ElementType *)d_value_ptr + b * head * n * k + h * k + col_n * kTileN * head * k); 732 | for (int i = 0; i < 4; i++) { 733 | global_ptr[(idx / 16 + i * 8) * head * 16 + idx % 16] = shared_ptr[(idx / 16 + i * 8) * 16 + idx % 16]; 734 | } 735 | } 736 | 737 | namespace config { 738 | using namespace cute; 739 | template 740 | struct GemmConfig { 741 | using ElementType = ElementType_; 742 | using ComputeType = ComputeType_; 743 | // tile configuration 744 | static constexpr int kTileM = kTileM_; 745 | static constexpr int kTileN = kTileN_; 746 | static constexpr int kTileK = kTileK_; 747 | 748 | using SmemLayoutQuery = decltype(make_layout( 749 | make_shape(Int{}, Int{}), 750 | make_stride(Int{}, Int<1>{}) 751 | )); 752 | 753 | using SmemLayoutQueryTransposed = decltype(make_layout( 754 | make_shape(Int{}, Int{}), 755 | make_stride(Int<1>{}, Int{}) 756 | )); 757 | 758 | using SmemLayoutKey = decltype(make_layout( 759 | make_shape(Int{}, Int{}), 760 | make_stride(Int{}, Int<1>{}) 761 | )); 762 | 763 | using SmemLayoutKeyTransposed = decltype(make_layout( 764 | make_shape(Int{}, Int{}), 765 | make_stride(Int<1>{}, Int{}) 766 | )); 767 | 768 | using SmemLayoutValue = decltype(make_layout( 769 | make_shape(Int{}, Int{}), 770 | make_stride(Int{}, Int<1>{}) 771 | )); 772 | 773 | using SmemLayoutValueTransposed = decltype(make_layout( 774 | make_shape(Int{}, Int{}), 775 | make_stride(Int<1>{}, Int{}) 776 | )); 777 | 778 | using SmemLayoutAcc = decltype(make_layout( 779 | make_shape(Int{}, Int{}), 780 | make_stride(Int{}, Int<1>{}) 781 | )); 782 | 783 | using SmemLayoutAccTransposed = decltype(make_layout( 784 | make_shape(Int{}, Int{}), 785 | make_stride(Int<1>{}, Int{}) 786 | )); 787 | 788 | using SmemLayoutOutput = SmemLayoutQuery; 789 | 790 | using SmemLayoutDOutput = SmemLayoutQuery; 791 | 792 | using SmemLayoutDOutputTransposed = decltype(make_layout( 793 | make_shape(Int{}, Int{}), 794 | make_stride(Int<1>{}, Int{}) 795 | )); 796 | 797 | using SmemLayoutMax = decltype(make_layout( 798 | make_shape(Int{}, Int<1>{}), 799 | make_stride(Int<1>{}, Int<1>{}) 800 | )); 801 | 802 | using SmemLayoutSum = SmemLayoutMax; 803 | 804 | using mma_tn_op = SM70_8x8x4_F32F16F16F32_TN; 805 | using mma_tn_traits = MMA_Traits; 806 | using mma_tn_atom = MMA_Atom; 807 | 808 | static constexpr int kMmaEURepeatM = 4; 809 | static constexpr int kMmaEURepeatN = 4; 810 | 811 | static constexpr int kMmaVRepeatM = 1; 812 | static 
constexpr int kMmaVRepeatN = 1; 813 | 814 | using MMA_EU_RepeatT = decltype( 815 | make_layout(make_shape(Int{}, Int{})) 816 | ); 817 | using MMA_V_RepeatT = decltype( 818 | make_layout(make_shape(Int{}, Int{})) 819 | ); 820 | using MMA_TN = decltype(make_tiled_mma(mma_tn_atom{}, MMA_EU_RepeatT{}, MMA_V_RepeatT{})); 821 | 822 | // global mem to shared mem copy 823 | using Global2SharedCopyAtom = Copy_Atom; 824 | 825 | using Global2SharedCopyQuery = decltype( 826 | make_tiled_copy(Global2SharedCopyAtom{}, 827 | make_layout( 828 | make_shape(Int<1>{}, Int<128>{}), 829 | make_stride(Int<128>{}, Int<1>{}) 830 | ), 831 | make_layout( 832 | make_shape(Int<1>{}, Int<1>{}) 833 | ) 834 | ) 835 | ); 836 | 837 | using Global2SharedCopyKey = Global2SharedCopyQuery; 838 | using Global2SharedCopyDOutput = Global2SharedCopyQuery; 839 | using Global2SharedCopyOutput = Global2SharedCopyQuery; 840 | using Global2SharedCopyValue = decltype( 841 | make_tiled_copy(Global2SharedCopyAtom{}, 842 | make_layout( 843 | make_shape(Int<1>{}, Int<128>{}), 844 | make_stride(Int<128>{}, Int<1>{}) 845 | ), 846 | make_layout( 847 | make_shape(Int<1>{}, Int<1>{}) 848 | ) 849 | ) 850 | ); 851 | using Shared2RegisterCopyAtomAcc = Copy_Atom; 852 | using Shared2RegisterCopyAcc = decltype( 853 | make_tiled_copy(Shared2RegisterCopyAtomAcc{}, 854 | make_layout( 855 | make_shape(Int<4>{}, Int<32>{}), 856 | make_stride(Int<32>{}, Int<1>{}) 857 | ), 858 | make_layout( 859 | make_shape(Int<1>{}, Int<1>{}) 860 | ) 861 | ) 862 | ); 863 | using Shared2RegisterCopyFp16Acc = decltype( 864 | make_tiled_copy(Shared2RegisterCopyAtomAcc{}, 865 | make_layout( 866 | make_shape(Int<4>{}, Int<32>{}), 867 | make_stride(Int<32>{}, Int<1>{}) 868 | ), 869 | make_layout( 870 | make_shape(Int<1>{}, Int<1>{}) 871 | ) 872 | ) 873 | ); 874 | using Register2SharedCopyFp16Acc = Global2SharedCopyQuery; 875 | 876 | // register to global via shared memory 877 | using MNK = typename MMA_TN::TiledShape_MNK; 878 | 879 | static constexpr int kThreadNum = size(MMA_TN{}); 880 | static constexpr int backward_kShmSize = sizeof(ComputeType) * cosize(SmemLayoutAcc{}) * 2 + 881 | sizeof(ElementType) * cosize(SmemLayoutQuery{}) * 2 + 882 | sizeof(ComputeType) * cosize(SmemLayoutMax{}) * 2; 883 | static constexpr int shm_size_query_key = cute::max(cute::cosize(SmemLayoutQuery{}), cute::cosize(SmemLayoutKey{})) * sizeof(ElementType); 884 | static constexpr int shm_size_query_key_acc = cute::max(shm_size_query_key, cute::cosize(SmemLayoutAcc{}) * sizeof(ElementType)); 885 | static constexpr int kShmSizeQKVO = cute::max(shm_size_query_key_acc , cute::cosize(SmemLayoutValue{}) * sizeof(ElementType)) 886 | + cute::cosize(SmemLayoutOutput{}) * sizeof(ElementType); 887 | static constexpr int forward_kShmSize = kShmSizeQKVO + (cute::cosize(SmemLayoutMax{}) * 2 + cute::cosize(SmemLayoutSum{})) * sizeof(ComputeType); 888 | }; 889 | } 890 | 891 | void fused_mha_forward(const void *query_ptr, const void *key_ptr, const void *value_ptr, void *output_ptr, void *max_ptr, void *sum_ptr, 892 | int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream) 893 | { 894 | config::GemmConfig gemm_config; 895 | // print(typename decltype(gemm_config)::MMA{}); 896 | dim3 block = gemm_config.kThreadNum; 897 | dim3 grid((m + gemm_config.kTileM - 1) / gemm_config.kTileM, head * batch); 898 | int shm_size = gemm_config.forward_kShmSize; 899 | fused_mha_forward_kernel 900 | <<>>(query_ptr, key_ptr, value_ptr, 901 | output_ptr, max_ptr, sum_ptr, head, m, n, k, scale, causal); 
902 | } 903 | 904 | void fused_mha_backward(const void *query_ptr, const void *key_ptr, const void *value_ptr, 905 | void *output_ptr, void *d_output_ptr, void *d_ptr, void *max_ptr, void *sum_ptr, 906 | void *d_query_ptr, void *d_key_ptr, void *d_value_ptr, int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream) 907 | { 908 | config::GemmConfig gemm_config; 909 | // print(typename decltype(gemm_config)::MMA{}); 910 | dim3 block = gemm_config.kThreadNum; 911 | dim3 grid((m + gemm_config.kTileM - 1) / gemm_config.kTileM, head * batch); 912 | int shm_size = gemm_config.backward_kShmSize; 913 | fused_mha_backward_kernel<decltype(gemm_config)> 914 | <<<grid, block, shm_size, stream>>>(query_ptr, key_ptr, value_ptr, output_ptr, d_output_ptr, d_ptr, 915 | max_ptr, sum_ptr, d_query_ptr, d_key_ptr, d_value_ptr, head, m, n, k, scale, causal); 916 | } --------------------------------------------------------------------------------
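For reference, here is a minimal host-side sketch of how the two launchers above can be driven directly, outside the PyTorch extension. Only `fused_mha_forward` and `fused_mha_backward` come from this repository (declared in `include/fused_mha.h`); the file name, problem sizes, and buffer shapes below are assumptions inferred from how the kernels index memory: float16 Q/K/V/O/dO laid out as (batch, seq_len, heads, head_dim), float32 per-row softmax max/sum of size batch x heads x seq_len, and a float16 buffer for the precomputed D term of the same size.

```cpp
// standalone_mha_example.cu (hypothetical file name) -- minimal driver for the
// launchers declared in include/fused_mha.h. All sizes are illustrative.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cmath>
#include "fused_mha.h"

int main() {
    // Assumed problem size: sequences already right-padded to a multiple of the 32-wide tiles.
    int batch = 2, head = 40, m = 2048, n = 2048, k = 128;
    float scale = 1.0f / std::sqrt((float)k);
    bool causal = true;

    size_t qo_elems  = (size_t)batch * head * m * k;  // query / output / d_output / d_query
    size_t kv_elems  = (size_t)batch * head * n * k;  // key / value / d_key / d_value
    size_t row_elems = (size_t)batch * head * m;      // per-row softmax max / sum and D

    half *q, *kk, *v, *o, *dout, *dq, *dk, *dv, *d;   // kk = key (k is the head-dim int)
    float *row_max, *row_sum;
    cudaMalloc((void **)&q,    qo_elems  * sizeof(half));
    cudaMalloc((void **)&kk,   kv_elems  * sizeof(half));
    cudaMalloc((void **)&v,    kv_elems  * sizeof(half));
    cudaMalloc((void **)&o,    qo_elems  * sizeof(half));
    cudaMalloc((void **)&dout, qo_elems  * sizeof(half));
    cudaMalloc((void **)&dq,   qo_elems  * sizeof(half));
    cudaMalloc((void **)&dk,   kv_elems  * sizeof(half));
    cudaMalloc((void **)&dv,   kv_elems  * sizeof(half));
    cudaMalloc((void **)&d,    row_elems * sizeof(half));   // D term consumed by the backward kernel
    cudaMalloc((void **)&row_max, row_elems * sizeof(float));
    cudaMalloc((void **)&row_sum, row_elems * sizeof(float));
    // dQ is accumulated with atomicAdd across key tiles, so it must start zeroed.
    cudaMemset(dq, 0, qo_elems * sizeof(half));
    // Filling q/kk/v/dout/d with real data is omitted in this sketch.

    cudaStream_t stream = 0;
    // Forward pass: writes the attention output plus the per-row max and sum
    // that the backward pass re-uses.
    fused_mha_forward(q, kk, v, o, row_max, row_sum, batch, head, m, n, k, scale, causal, stream);
    // Backward pass: consumes dO, D, max and sum, and produces dQ/dK/dV.
    fused_mha_backward(q, kk, v, o, dout, d, row_max, row_sum, dq, dk, dv,
                       batch, head, m, n, k, scale, causal, stream);
    cudaStreamSynchronize(stream);
    // cudaFree calls omitted for brevity.
    return 0;
}
```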