├── .gitignore ├── .gitmodules ├── README.md ├── benchmark.png ├── csrc ├── attention_api.cpp ├── flash.h ├── flash_api.cpp ├── flash_attention.cu ├── kernel_traits.h ├── static_switch.h └── utils.h ├── include ├── attention_api.cuh └── attention_api.h ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | .cache/ 3 | dist/ 4 | tiny_attention_cutlass.egg-info 5 | deps/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "deps/cutlass"] 2 | path = deps/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This implementation builds on https://github.com/66RING/tiny-flash-attention.git and simplifies the code in a few places, e.g. removing some unnecessary redundant definitions and simplifying the shared memory cute layout definitions. 2 | 3 | Accompanying blog post: https://zhuanlan.zhihu.com/p/708867810 4 | 5 | ![benchmark](./benchmark.png) 6 | -------------------------------------------------------------------------------- /benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weishengying/tiny-flash-attention/ba3e46f195985e032322cb9f4ae490270dbe92ef/benchmark.png -------------------------------------------------------------------------------- /csrc/attention_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_api.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // m.def("package_name", &function_name, "function_docstring") 8 | m.def("flash_attention_v2_cutlass", &flash_attention_v2_cutlass, 9 | "Flash attention v2 implemented in cutlass"); 10 | } 11 | -------------------------------------------------------------------------------- /csrc/flash.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // TODO: effect of special constraint qualifiers, e.g. __restrict__ 6 | struct Qkv_params { 7 | using index_t = uint32_t; 8 | // The QKV matrices. 9 | void *__restrict__ q_ptr; 10 | void *__restrict__ k_ptr; 11 | void *__restrict__ v_ptr; 12 | 13 | // // The stride between rows of the Q, K and V matrices. 14 | // index_t q_batch_stride; 15 | // index_t k_batch_stride; 16 | // index_t v_batch_stride; 17 | // // TODO: 18 | // index_t q_row_stride; 19 | // index_t k_row_stride; 20 | // index_t v_row_stride; 21 | // index_t q_head_stride; 22 | // index_t k_head_stride; 23 | // index_t v_head_stride; 24 | 25 | bool is_bf16; 26 | }; 27 | 28 | 29 | struct Flash_fwd_params : public Qkv_params { 30 | size_t bs; 31 | size_t head; 32 | size_t q_seqlen; 33 | size_t dim; 34 | 35 | size_t k_head; 36 | size_t k_seqlen; 37 | 38 | // TODO: review the impl of flash 39 | // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be 40 | // different from nheads (query).
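  // e.g. with head = 32 query heads and k_head = 8 KV heads, h_h_k_ratio = head / k_head = 4,
  // so query head h would read K/V head h / h_h_k_ratio; note the simplified kernel in this
  // repo indexes Q, K and V with the same head stride, so the MQA/GQA path is not exercised yet.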
41 | size_t h_h_k_ratio; // precompute head / k_head, 42 | size_t flat_seqlen; 43 | size_t kv_head_stride; 44 | size_t qo_head_stride; 45 | 46 | 47 | size_t bs_stride; 48 | size_t head_stride; 49 | size_t seqlen_stride; 50 | size_t dim_stride; 51 | 52 | float softmax_scale; 53 | float softmax_scale_log2; 54 | void *__restrict__ out_ptr; 55 | void *__restrict__ softmax_lse_ptr; 56 | void *__restrict__ score_max; 57 | void *__restrict__ score_sum; 58 | 59 | bool is_causal; 60 | }; 61 | 62 | -------------------------------------------------------------------------------- /csrc/flash_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_api.h" 5 | #include "flash.h" 6 | -------------------------------------------------------------------------------- /csrc/flash_attention.cu: -------------------------------------------------------------------------------- 1 | #include "attention_api.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "static_switch.h" 15 | #include "kernel_traits.h" 16 | #include "flash.h" 17 | #include "utils.h" 18 | 19 | namespace flash { 20 | 21 | using namespace cute; 22 | 23 | template 24 | inline __device__ void mask_within_nblock(Tensor &tensor, const int m_block, const int nbi) { 25 | // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) 26 | static_assert(Layout::rank == 2, "Only support 2D Tensor"); 27 | // NOTE: 根据 mma_tile 的示意图来确定每个线程处理的是第几个 token 28 | 29 | // NOTE: 30 | // 计算thread的处理范围, mask掉超出范围的部分 31 | 32 | const int lane_id = threadIdx.x % 32; 33 | const int col_idx_offset = kBlockN * nbi + (lane_id % 4) * 2; 34 | 35 | const int nrow_group = threadIdx.x / 32; 36 | const int row_idx_offset = kBlockM * m_block + lane_id / 4 + nrow_group * 16 /* 2*8 */; 37 | // (2, nrow), 2*8 for each 38 | const int group_stride = kNWarps * 16; 39 | 40 | #pragma unroll 41 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 42 | // SM80_16x8x16_F32F16F16F32_TN中的一组中, 一行4个线程处理8个value 43 | const int col_idx_base = col_idx_offset + nj * 8; 44 | #pragma unroll 45 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 46 | // j用于计算value 1和value 2对应col 47 | // col_idx最终表示当前thread所处理的value的列号 48 | const int col_idx = col_idx_base + j; 49 | 50 | // mask掉scores中(QK后的结果)超出范围的部分 51 | // 列号和行号对比 52 | 53 | // Without the "make_coord" we get wrong results 54 | // for nrow(2, MMA_M) 55 | #pragma unroll 56 | for (int mi = 0; mi < size<0, 0>(tensor); ++mi) { 57 | 58 | #pragma unroll 59 | for (int mj = 0; mj < size<0, 1>(tensor); ++mj) { 60 | const int row_idx = row_idx_offset + mi * 8 + mj * group_stride; 61 | if (col_idx > row_idx) { 62 | tensor(make_coord(mi, mj), make_coord(j, nj)) = -INFINITY; 63 | } 64 | } 65 | 66 | } 67 | 68 | } 69 | } 70 | } 71 | 72 | // NOTE: A矩阵已经在寄存器中的gemm封装 73 | template 75 | inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, 76 | TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, 77 | ThrCopy smem_thr_copy_B) { 78 | // NOTE: 符合M N K描述: A[M, K] @ B[N, K] = C[M, N] 79 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 80 | CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N 81 | CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K 82 | // NOTE: retile 成拷贝需要的大小 83 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 84 | CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N 
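  // NOTE: software-pipelined smem -> reg copy below: slice 0 of B is loaded first, then each
  // iteration prefetches slice i+1 while the MMA consumes slice i; A (the P fragment produced
  // by Q@K^T) already lives in registers, so only the B operand needs to be staged.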
85 | 86 | cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); 87 | #pragma unroll 88 | for (int i = 0; i < size<2>(tCrA); ++i) { 89 | if (i < size<2>(tCrA) - 1) { 90 | cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); 91 | } 92 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 93 | } 94 | } 95 | 96 | template 100 | inline __device__ void gemm_smem(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, 101 | Tensor4 const& tCsB, TiledMma tiled_mma, 102 | TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 103 | ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { 104 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 105 | CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N 106 | CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K 107 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 108 | CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 109 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 110 | CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N 111 | // NOTE: s -> reg 112 | cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); 113 | cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); 114 | #pragma unroll 115 | for (int i = 0; i < size<2>(tCrA); ++i) { 116 | if (i < size<2>(tCrA) - 1) { 117 | cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); 118 | cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); 119 | } 120 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 121 | } 122 | } 123 | 124 | // Blocks until all but N previous cp.async.commit_group operations have committed. 125 | // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all 126 | // (which is equivalent to commit_group then wait_group 0). 127 | // Instead we just call cp.async.wait_group 0, which is slightly faster. 128 | // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 129 | template 130 | CUTE_HOST_DEVICE 131 | void cp_async_wait() { 132 | #if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) 133 | asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); 134 | #endif 135 | } 136 | 137 | // copy from S to D with tiled_copy 138 | // TODO: 需要支持causal模式的的跳过拷贝 139 | template 140 | inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, 141 | Tensor &D) { 142 | CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); 143 | CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); 144 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA 145 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M 146 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K 147 | 148 | #pragma unroll 149 | for (int m = 0; m < size<1>(S); ++m) { 150 | // TODO: 原版处这里identity_MN是用来跳过大块的block的, predicate用于跳过block内的拷贝 151 | // TODO: 添加predicate逻辑, 用于跳过无用拷贝 152 | // if (get<0>(identity_MN(0, m, 0)) < max_MN) 153 | #pragma unroll 154 | for (int k = 0; k < size<2>(S); ++k) { 155 | cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); 156 | } 157 | } 158 | } 159 | 160 | 161 | // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) 162 | // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
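// i.e. reinterpret the C-fragment of the first GEMM (S = Q@K^T, in rowcol layout) as an
// A-fragment for the second GEMM (P@V), so the scores never have to be moved between GEMMs.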
163 | template 164 | inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { 165 | using X = Underscore; 166 | static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); 167 | static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); 168 | constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); 169 | static_assert(mma_shape_K == 8 || mma_shape_K == 16); 170 | constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; 171 | auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) 172 | // TD [2023-08-13]: Same error as above on Cutlass 3.2 173 | // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), 174 | // get<0, 1>(l), 175 | // get<1, 1, 1>(l)); 176 | return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), 177 | get<1>(get<0>(l)), 178 | get<1>(get<1>(get<1>(l)))); 179 | }; 180 | 181 | 182 | // TODO: not work 183 | template 184 | inline __device__ auto convert_type(Tensor const &tensor) { 185 | using From_type = typename Engine::value_type; 186 | constexpr int numel = decltype(size(tensor))::value; 187 | cutlass::NumericArrayConverter convert_op; 188 | // HACK: this requires tensor to be "contiguous" 189 | auto frag = convert_op(*reinterpret_cast *>(tensor.data())); 190 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 191 | } 192 | 193 | 194 | // https://github.com/NVIDIA/cutlass/issues/802 195 | // TODO: convert出来后数据是否在寄存器? 196 | template 197 | inline __device__ auto convert_type_f32_to_f16(Fragment const &acc_fp32) { 198 | Tensor acc_fp16 = make_tensor(shape(acc_fp32)); 199 | { 200 | Tensor acc_fp32x2 = recast< float2>(acc_fp32); 201 | Tensor acc_fp16x2 = recast<__half2>(acc_fp16); 202 | for (int i = 0; i < size(acc_fp32x2); ++i) { acc_fp16x2(i) = __float22half2_rn(acc_fp32x2(i)); } 203 | } 204 | return acc_fp16; 205 | } 206 | 207 | // Apply the exp to all the elements. 208 | template 209 | inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { 210 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 211 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 212 | CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); 213 | #pragma unroll 214 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 215 | // If max is -inf, then all elements must have been -inf (possibly due to masking). 216 | // We don't want (-inf - (-inf)) since that would give NaN. 217 | // If we don't have float around M_LOG2E the multiplication is done in fp64. 218 | const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); 219 | #pragma unroll 220 | for (int ni = 0; ni < size<1>(tensor); ++ni) { 221 | // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - 222 | // max * log_2(e)) This allows the compiler to use the ffma 223 | // instruction instead of fadd and fmul separately. 
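        // NOTE: this simplified kernel passes the plain softmax_scale down to here and keeps
        // expf(x * scale - max * scale); the exp2f / log2(e) variant described above is what
        // the original flash-attention implementation uses.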
224 | tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled); 225 | } 226 | } 227 | } 228 | 229 | 230 | 231 | // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) 232 | template 233 | inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { 234 | static_assert(decltype(size<0>(acc_layout))::value == 4); 235 | static_assert(decltype(rank(acc_layout))::value == 3); 236 | auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) 237 | // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting 238 | // "int_tuple.hpp(74): error: conversion to inaccessible base class" 239 | // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); 240 | return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); 241 | }; 242 | 243 | // scores:((2, MMA_M),(2, MMA_N)),经过了 causal 之后的 Q_i 和 k_j^T 的乘积, 244 | // scores_max:(2 * MMA_N), rowmax 的结果 245 | // scores_sum:(2 * MMA_N), rowsum 的结果 246 | // acc_o:((2, 2),(MMA_M, MMA_N)), 最后的计算结果 247 | template 248 | inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, 249 | Tensor2 &acc_o, float softmax_scale_log2) { 250 | if (Is_first) { 251 | // NOTE: 第一次softmax不需要rescale, 只需要记录 Sij(kblockM, kblockN) 的 rowmax 和 rowsum 252 | reduce_max(scores, scores_max); 253 | flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); 254 | reduce_sum(scores, scores_sum); 255 | } else { 256 | // 记录上一次的 rowmax 257 | Tensor scores_max_prev = make_fragment_like(scores_max); // 相当于公式中的 m_i^{j-1} 258 | cute::copy(scores_max, scores_max_prev); 259 | // NOTE: 计算最新的 max 260 | // reduce_max包含步: 261 | // 1. 求当前thread内max: 遍历 262 | // 2. reduce thread间的max: 使用线程数洗牌指令做 all reduce,每个线程都获得了最大值 263 | reduce_max(scores, scores_max); // scores_max 变成最新的最大值,相当于公式中的 m_i^{j} 264 | // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) 265 | // 将acc_o转换成符合2D直觉的(nrow, ncol)的形状 266 | Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); 267 | #pragma unroll 268 | for (int mi = 0; mi < size(scores_max); ++mi) { // 遍历每一行 269 | // NOTE: 辅助变量: 当前行max 270 | float scores_max_cur = scores_max(mi); // 当前行的最大值 271 | // NOTE: 计算上一次 score_sum 的 rescale 值 272 | float scores_scale = expf((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); // 想当于公式中的 e^{m_i^{j-1} - m_i^{j}}. 
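      // online-softmax update: rescale the running rowsum and the accumulated output
      // O_i^{j-1} by e^{m_i^{j-1} - m_i^{j}} before adding the contribution of block j.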
273 | scores_sum(mi) *= scores_scale; // 想当于公式中的 e^{m_i^{j-1} - m_i^{j}}l_i^{j-1} 274 | #pragma unroll 275 | for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } // 想当于公式中的 e^{m_i^{j-1} - m_i^{j}}O_i^{j-1} 276 | } 277 | // NOTE: Apply the exp to all the elements with new max value, 这里相当于论文公式里的 P_i^_j 278 | flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); 279 | 280 | Tensor scores_sum_cur = make_fragment_like(scores_sum); // l_i^{j} = e^{m_i^{j-1} - m_i^{j}}O_i^{j-1} 281 | // NOTE: 累计求和 282 | reduce_sum(scores, scores_sum_cur); // rowsum(P_i^_j) 283 | // NOTE: 新分母累加到旧分母 284 | #pragma unroll 285 | for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } // l{ij} = e^{m_i^{j-1} - m_i^{j}}O_i^{j-1} + rowsum(P_i^_j) 286 | } 287 | }; 288 | 289 | } // namespace flash 290 | 291 | void set_params_fprop(Flash_fwd_params ¶ms, 292 | 293 | // device pointers 294 | const torch::Tensor q, 295 | const torch::Tensor k, 296 | const torch::Tensor v, 297 | torch::Tensor out, 298 | 299 | void *softmax_lse_d, 300 | float softmax_scale, 301 | bool is_causal) { 302 | 303 | memset(¶ms, 0, sizeof(params)); 304 | 305 | params.bs = q.size(0); 306 | params.head = q.size(1); 307 | params.q_seqlen = q.size(2); 308 | params.dim = q.size(3); 309 | 310 | params.k_head = k.size(1); 311 | params.k_seqlen = k.size(2); 312 | 313 | params.bs_stride = q.stride(0); 314 | params.head_stride = q.stride(1); 315 | params.seqlen_stride = q.stride(2); 316 | params.dim_stride = q.stride(3); 317 | 318 | params.softmax_scale = softmax_scale; 319 | // TODO: 使用log2做scale 320 | params.softmax_scale_log2 = softmax_scale * M_LOG2E; 321 | params.is_causal = is_causal; 322 | params.is_bf16 = q.dtype() == torch::kBFloat16; 323 | 324 | // LogSumExp save for backward 325 | params.softmax_lse_ptr = softmax_lse_d; 326 | 327 | params.q_ptr = q.data_ptr(); 328 | params.k_ptr = k.data_ptr(); 329 | params.v_ptr = v.data_ptr(); 330 | params.out_ptr = out.data_ptr(); 331 | } 332 | 333 | 334 | // Shared Storage with Aligned addresses. 335 | template 336 | struct SharedStorage { 337 | // TODO: Aligned的话smem的计算是否有问题 338 | cute::array_aligned> smem_q; 339 | cute::array_aligned> smem_k; 340 | cute::array_aligned> smem_v; 341 | }; 342 | 343 | template 344 | __global__ void flash_attention_v2_cutlass_kernel(const Params params) { 345 | 346 | using namespace cute; 347 | 348 | // m block index 349 | const int m_block = blockIdx.x; 350 | 351 | // bs * head 352 | const int base_id = blockIdx.y; 353 | // The thread index. 354 | const int tidx = threadIdx.x; 355 | 356 | using Element = typename Kernel_traits::Element; 357 | using ElementAccum = typename Kernel_traits::ElementAccum; 358 | // using TiledMMA = typename Kernel_traits::MMA; 359 | using TiledMMA = typename Kernel_traits::TiledMma; 360 | using index_t = typename Kernel_traits::index_t; 361 | using SmemLayoutQ = typename Kernel_traits::SmemLayoutQ; 362 | using SmemLayoutK = typename Kernel_traits::SmemLayoutKV; 363 | using SmemLayoutV = typename Kernel_traits::SmemLayoutKV; 364 | using SmemLayoutVt = typename Kernel_traits::SmemLayoutVtransposed; 365 | using SmemLayoutVtNswizzle = typename Kernel_traits::SmemLayoutVtransposedNoSwizzle; 366 | 367 | constexpr int kNWarps = Kernel_traits::kNWarps; 368 | constexpr int kBlockM = Kernel_traits::kBlockM; 369 | constexpr int kBlockN = Kernel_traits::kBlockN; 370 | constexpr int kHeadDim = Kernel_traits::kHeadDim; 371 | 372 | // Shared memory. 
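  // NOTE: dynamic shared memory; run_flash_fwd passes sizeof(SharedStorage) as the third
  // kernel-launch parameter and raises cudaFuncAttributeMaxDynamicSharedMemorySize when
  // the requested size exceeds 48 KB.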
373 | extern __shared__ char smem_[]; 374 | using SharedStorage = SharedStorage; 375 | SharedStorage &shared_storage = *reinterpret_cast(smem_); 376 | 377 | const int bs_head_offset = base_id * params.head_stride; 378 | 379 | // TODO: base offset for MHA 380 | // NOTE: convert C pointer to Tensor for convenience 381 | Tensor Q = make_tensor( 382 | make_gmem_ptr(reinterpret_cast(params.q_ptr) + bs_head_offset), 383 | make_shape(params.q_seqlen, Int{}), 384 | make_stride(Int{}, Int<1>{})); 385 | Tensor K = make_tensor( 386 | make_gmem_ptr(reinterpret_cast(params.k_ptr) + bs_head_offset), 387 | make_shape(params.k_seqlen, Int{}), 388 | make_stride(Int{}, Int<1>{})); 389 | Tensor V = make_tensor( 390 | make_gmem_ptr(reinterpret_cast(params.v_ptr) + bs_head_offset), 391 | make_shape(params.k_seqlen, Int{}), 392 | make_stride(Int{}, Int<1>{})); 393 | Tensor O = make_tensor( 394 | make_gmem_ptr(reinterpret_cast(params.out_ptr) + bs_head_offset), 395 | make_shape(params.q_seqlen, Int{}), 396 | make_stride(Int{}, Int<1>{})); 397 | 398 | 399 | // 加载Q, K, V分块 400 | // (kBlockM, kHeadDim, num_tile_n) 401 | Tensor gQ = local_tile(Q, make_tile(Int{}, Int{}), make_coord(m_block, _)); 402 | 403 | // (kBlockN, kHeadDim, num_tile_n) 404 | // NOTE: loading流水线, 初次加载所需K, V 405 | Tensor gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(0, _)); 406 | Tensor gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(0, _)); 407 | 408 | // 获取MMA抽象 409 | TiledMMA tiled_mma; 410 | auto thr_mma = tiled_mma.get_slice(tidx); 411 | 412 | // Construct SMEM tensors. 413 | Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); 414 | Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); 415 | Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); 416 | 417 | // Tensor for V Transpose; used in GEMM-II. 
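  // NOTE: sVt / sVtNoSwizzle below are transposed views of the same smem buffer that holds V
  // (shared_storage.smem_v); GEMM-II (P@V) consumes V as its B operand in transposed form,
  // so no physical transpose of the data is performed.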
418 | Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); 419 | Tensor sVtNoSwizzle = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVtNswizzle{}); 420 | 421 | // NOTE: copy抽象 422 | // NOTE: QKV gmem -> smem 拷贝的抽象 423 | typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; 424 | auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); 425 | 426 | // NOTE: 定义gmem -> smem拷贝的src, dst 427 | Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0)); 428 | Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); 429 | Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 430 | Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); 431 | Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 432 | Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); 433 | 434 | // NOTE: 定义smem -> reg 拷贝的dst 435 | // partition_fragment与partition类似, 只是返回的是寄存器表示 436 | Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) 437 | Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) 438 | Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) 439 | 440 | // NOTE: 准备拷贝Q, K 到 reg 的copy对象 (smem --> reg) 441 | auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); 442 | auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); 443 | Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); 444 | 445 | auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); 446 | auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); 447 | Tensor tSsK = smem_thr_copy_K.partition_S(sK); 448 | 449 | // 拷贝时转置 450 | // NOTE: 拷贝Vt smem->reg 451 | auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); 452 | auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); 453 | Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); 454 | 455 | // NOTE: 命名规则, t表示to, s/g表示位置(smem, gmem) 456 | // 从smem加载时做retiling 457 | // tKgK表示gmem中的K, 用作gmem->smem的src 458 | // tKsK表示smem中的K, 用作gmem->smem的dst 459 | // tSsK表示smem中的K, 用作smem->reg的src 460 | 461 | // 流水线加载初始Q, K 462 | // 加载Q到smem 463 | flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ); 464 | // 加载K到smem 465 | flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK); 466 | // 开始执行异步拷贝 467 | cute::cp_async_fence(); 468 | 469 | Tensor rAccOut = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_K) 470 | 471 | // step1: slice-k compute QK block 472 | // Q[BLOCK_M, BLOCK_N] @ K[BLOCK_M, BLOCK_N].T = O[BLOCK_M, BLOCK_M] 473 | // 474 | // step2: 475 | // advance K, V 476 | 477 | // NOTE: K, V分块的数量: 处理的区间 478 | const int n_block_min = 0; 479 | // NOTE: 1. mask between N BLOCKs if is causal mode 480 | int seqlen_start = m_block * kBlockM; 481 | int seqlen_end = (m_block + 1) * kBlockM; 482 | int n_block_max = Is_causal ? 
cute::ceil_div(seqlen_end, kBlockN) : cute::ceil_div(params.k_seqlen, kBlockN); // (2 * MMA_M) 483 | 484 | // NOTE: 需要记录的max 485 | Tensor scores_max = make_tensor(Shape(rAccOut)>>{}); 486 | 487 | // NOTE: 需要记录的denom 488 | Tensor scores_sum = make_fragment_like(scores_max); 489 | 490 | clear(rAccOut); 491 | 492 | for (int nbi = n_block_min; nbi < n_block_max; nbi++) { 493 | auto rAccScore = partition_fragment_C(tiled_mma, make_shape(Int{}, Int{})); // (MMA, MMA_M, MMA_N) 494 | clear(rAccScore); // 初始化为 0 495 | 496 | // 等待Q, K的gmem -> smem拷贝完成, 即Q, K就绪 497 | // wait<0>表示等待还剩0个未完成 498 | flash::cp_async_wait<0>(); 499 | __syncthreads(); 500 | 501 | // gemm的同时异步加载V 502 | gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 503 | tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 504 | // 异步加载V到smem 505 | flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV); 506 | // 发起异步拷贝 507 | cute::cp_async_fence(); 508 | 509 | // O = Q@K.T 510 | // NOTE: 加载smem中的数据到reg再做gemm, **加载期间执行retile** 511 | flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, 512 | smem_thr_copy_Q, smem_thr_copy_K 513 | ); 514 | 515 | Tensor scores = make_tensor(rAccScore.data(), flash::convert_layout_acc_rowcol(rAccScore.layout())); // (MMA, MMA_M, MMA_N) --> ((2, MMA_M), (2, MMA_N)) 516 | 517 | // NOTE: 2. mask within N BLOCKs 518 | if (Is_causal == true && nbi * kBlockN >= seqlen_start) { 519 | flash::mask_within_nblock(scores, m_block, nbi); 520 | } 521 | 522 | // NOTE: 等待V加载完成, 为下个K加载准备初始状态 523 | flash::cp_async_wait<0>(); 524 | __syncthreads(); 525 | 526 | // advance K 527 | if (nbi != n_block_max - 1) { 528 | gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 529 | tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 530 | flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK); 531 | cute::cp_async_fence(); 532 | } 533 | 534 | // 计算softmax 535 | // scores:((2, MMA_M),(2, MMA_N)), Q_i * K_j^T 的值 536 | // scores_max:(2 * MMA_N) 537 | // scores_sum:(2 * MMA_N) 538 | // rAccOut:((2, 2),(MMA_M, MMA_N)),相当于 O_i 539 | nbi == 0 ? 
flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale) : 540 | flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale); 541 | 542 | // 计算完成后, scores 相当于公式中的 P_i^j 543 | // 实际执行 P_i^j @ V 544 | // (score AKA rAccScore): exp(QK[M, N] - m_i^j) @ V[N, dim] 545 | // NOTE: DABC: F32F16F16F32, convert D type(F32) to A type(F16) 546 | Tensor rP = flash::convert_type_f32_to_f16(rAccScore); 547 | 548 | // NOTE: Convert from layout C to layout A; ((2, MMA_M),(2, MMA_N)) --> ((2, 2),(MMA_M, MMA_N)) 549 | Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); 550 | 551 | // if(cute::thread0() && nbi == 0){ 552 | // printf("tOrP: "); print(tOrP); printf("\n"); 553 | // printf("tOrVt: "); print(tOrVt); printf("\n"); 554 | // } 555 | // rAccOut:((2, 2),(MMA_M, MMA_N)) 556 | flash::gemm_A_in_regs(rAccOut, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); 557 | } 558 | 559 | // Epilogue 560 | // NOTE: 最后统一除上分母部分 561 | // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) 562 | // AKA reshape to (nrow, ncol) but with specific MMA layout 563 | Tensor acc_o_rowcol = make_tensor(rAccOut.data(), flash::convert_layout_acc_rowcol(rAccOut.layout())); 564 | 565 | // for row 566 | #pragma unroll 567 | for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { 568 | float sum = scores_sum(mi); 569 | float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; 570 | float scale = inv_sum; 571 | // for col 572 | #pragma unroll 573 | for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { 574 | acc_o_rowcol(mi, ni) *= scale; 575 | } 576 | } 577 | 578 | // Convert acc_o from fp32 to fp16/bf16 579 | Tensor rO = flash::convert_type_f32_to_f16(rAccOut); 580 | // 复用sQ的smem做sO的拷出 581 | Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) 582 | 583 | // Partition sO to match the accumulator partitioning 584 | auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); 585 | auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); 586 | Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) 587 | Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) 588 | 589 | // NOTE: 先拷贝到smem 590 | cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); 591 | 592 | Tensor gO = local_tile(O, make_tile(Int{}, Int{}), make_coord(m_block, _)); 593 | 594 | // 创建到smem -> gmem的拷贝 595 | typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; 596 | auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); 597 | Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) 598 | Tensor tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0)); 599 | 600 | __syncthreads(); 601 | 602 | // NOTE:: 再拷贝到gmem 603 | cute::copy(gmem_tiled_copy_O, tOsO, tOgO); 604 | 605 | } 606 | 607 | template 608 | void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { 609 | using Element = typename Kernel_traits::Element; 610 | using SmemLayoutQ = typename Kernel_traits::SmemLayoutQ; 611 | using SmemLayoutK = typename Kernel_traits::SmemLayoutKV; 612 | using SmemLayoutV = typename Kernel_traits::SmemLayoutKV; 613 | 614 | const int num_m_block = 615 | (params.q_seqlen + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; 616 | 617 | dim3 grid(num_m_block, params.bs * params.head, 1); 618 | dim3 block(Kernel_traits::kNThreads); 619 | 620 | int smem_size = 
int(sizeof(SharedStorage)); 621 | 622 | auto kernel = &flash_attention_v2_cutlass_kernel; 623 | // NOTE: smem过大时需要设置 624 | if (smem_size >= 48 * 1024) { 625 | CUDA_ERROR_CHECK(cudaFuncSetAttribute( 626 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 627 | } 628 | 629 | // TODO: stream 630 | kernel<<>>(params); 631 | } 632 | 633 | template 634 | void run_flash_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); 635 | 636 | // TODO: 挨个写出特化, 目前使用通用模板 637 | // 如, run_flash_fwd_hdim32用于特化hdim=32 638 | // 这样做可以根据实际情况微调kBlockN和kBlockM的组合, 也可以加速编译 639 | template 640 | void run_flash_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { 641 | BOOL_SWITCH(params.is_causal, Is_causal, [&] { 642 | // run_flash_fwd, Is_causal>(params, stream); 643 | 644 | // TODO: kBlockM, kBlockN的组合 645 | run_flash_fwd, Is_causal>(params, stream); 646 | }); 647 | } 648 | 649 | // entry point of flash attention 650 | void run_flash_attn_cutlass(Flash_fwd_params ¶ms, cudaStream_t stream) { 651 | // FP16_SWITCH yield elem_type namespace 652 | FP16_SWITCH(!params.is_bf16, [&] { 653 | // FWD_HEADDIM_SWITCH yield kHeadDim constexpr 654 | FWD_HEADDIM_SWITCH(params.dim, [&] { 655 | run_flash_fwd_(params, stream); 656 | }); 657 | }); 658 | } 659 | 660 | std::vector flash_attention_v2_cutlass(torch::Tensor q, torch::Tensor k, 661 | torch::Tensor v, bool is_causal = false, float softmax_scale=1) { 662 | 663 | CHECK_INPUT(q); 664 | CHECK_INPUT(k); 665 | CHECK_INPUT(v); 666 | 667 | // batch size 668 | int bs = q.size(0); 669 | // head number 670 | int head = q.size(1); 671 | // seqlen 672 | int seqlen = q.size(2); 673 | // dim 674 | int dim = q.size(3); 675 | auto out = torch::empty_like(q); 676 | 677 | Flash_fwd_params params; 678 | set_params_fprop(params, q, k, v, out, 679 | nullptr, softmax_scale, is_causal); 680 | 681 | run_flash_attn_cutlass(params, 0); 682 | 683 | // Wait until kernel finish. 
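  // NOTE: the kernel was launched on the default stream (stream 0) rather than the current
  // torch CUDA stream, so a device-wide synchronize here keeps this simple API safe.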
684 | cudaDeviceSynchronize(); 685 | CUDA_ERROR_CHECK(cudaGetLastError()); 686 | 687 | return {out}; 688 | } 689 | 690 | 691 | 692 | -------------------------------------------------------------------------------- /csrc/kernel_traits.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cute/algorithm/copy.hpp" 4 | 5 | #include "cutlass/cutlass.h" 6 | #include "cutlass/layout/layout.h" 7 | #include 8 | 9 | using namespace cute; 10 | 11 | template 12 | struct Flash_kernel_traits { 13 | 14 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 15 | using Element = elem_type; 16 | static constexpr bool Has_cp_async = true; 17 | #else 18 | using Element = cutlass::half_t; 19 | static constexpr bool Has_cp_async = false; 20 | #endif 21 | 22 | using ElementAccum = float; 23 | using index_t = uint32_t; 24 | 25 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 26 | using MMA_Atom_Arch = std::conditional_t< 27 | std::is_same_v, 28 | MMA_Atom, 29 | MMA_Atom 30 | >; 31 | using ValLayoutMNK = Layout>; 32 | #else 33 | using MMA_Atom_Arch = MMA_Atom; 34 | using ValLayoutMNK = Layout>; 35 | #endif 36 | 37 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 38 | using SmemCopyAtom = Copy_Atom; 39 | using SmemCopyAtomTransposed = Copy_Atom; 40 | #else 41 | using SmemCopyAtom = Copy_Atom; 42 | using SmemCopyAtomTransposed = Copy_Atom; 43 | #endif 44 | }; 45 | 46 | 47 | // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true 48 | template > 50 | struct Flash_fwd_kernel_traits : public Base { 51 | using Element = typename Base::Element; 52 | using ElementAccum = typename Base::ElementAccum; 53 | using index_t = typename Base::index_t; 54 | static constexpr bool Has_cp_async = Base::Has_cp_async; 55 | using SmemCopyAtom = typename Base::SmemCopyAtom; 56 | using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; 57 | 58 | // The number of threads. 59 | static constexpr int kNWarps = kNWarps_; 60 | static constexpr int kNThreads = kNWarps * 32; 61 | 62 | static constexpr int kBlockM = kBlockM_; 63 | static constexpr int kBlockN = kBlockN_; 64 | static constexpr int kHeadDim = kHeadDim_; 65 | 66 | // TODO: review 67 | static_assert(kHeadDim % 32 == 0); 68 | static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; 69 | static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; 70 | 71 | using TiledMma = TiledMMA< 72 | typename Base::MMA_Atom_Arch, 73 | Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group 74 | Tile, _16, _16>>; 75 | 76 | using SmemLayoutAtom = decltype( 77 | composition(Swizzle{}, 78 | Layout>, 79 | Stride, _1>>{})); 80 | using SmemLayoutQ = decltype(tile_to_shape( 81 | SmemLayoutAtom{}, 82 | Shape, Int>{})); 83 | 84 | using SmemLayoutKV = decltype(tile_to_shape( 85 | SmemLayoutAtom{}, 86 | Shape, Int>{})); 87 | 88 | // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 89 | // 这样定义的 SmemLayoutVtransposed 是 SmemLayoutKV 正确转置,可以打印验证 90 | using SmemLayoutVtAtom = decltype( 91 | composition(Swizzle{}, 92 | Layout, Int>, 93 | Stride<_1, Int>>{})); 94 | 95 | using SmemLayoutVtransposed = decltype(tile_to_shape( 96 | SmemLayoutVtAtom{}, 97 | Shape, Int>{})); 98 | 99 | using SmemLayoutVtransposedNoSwizzle = Layout, Int>, 100 | Stride<_1, Int>>; 101 | 102 | using SmemLayoutAtomO = decltype( 103 | composition(Swizzle{}, 104 | Layout, Int>, 105 | Stride, _1>>{})); 106 | using SmemLayoutO = decltype(tile_to_shape( 107 | SmemLayoutAtomO{}, 108 | Shape, Int>{})); 109 | using SmemCopyAtomO = Copy_Atom; 110 | 111 | 112 | static constexpr int kSmemQCount = size(SmemLayoutQ{}); 113 | static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; 114 | static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); 115 | static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); 116 | static constexpr int kSmemSize = kSmemQSize + kSmemKVSize; 117 | 118 | static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); 119 | static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); 120 | 121 | static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; 122 | static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); 123 | using GmemLayoutAtom = Layout, Int>, 124 | Stride, _1>>; 125 | 126 | // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading 127 | // from the same address by the same threadblock. This is slightly faster. 128 | using Gmem_copy_struct = std::conditional_t< 129 | Has_cp_async, 130 | SM80_CP_ASYNC_CACHEGLOBAL, 131 | DefaultCopy 132 | >; 133 | using GmemTiledCopyQKV = decltype( 134 | make_tiled_copy(Copy_Atom{}, 135 | GmemLayoutAtom{}, 136 | Layout>{})); // Val layout, 8 vals per read 137 | using GmemTiledCopyO = decltype( 138 | make_tiled_copy(Copy_Atom{}, 139 | GmemLayoutAtom{}, 140 | Layout>{})); // Val layout, 8 vals per store 141 | }; 142 | -------------------------------------------------------------------------------- /csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 4 | 5 | #pragma once 6 | 7 | /// @param COND - a boolean expression to switch by 8 | /// @param CONST_NAME - a name given for the constexpr bool variable. 9 | /// @param ... - code to execute for true and false 10 | /// 11 | /// Usage: 12 | /// ``` 13 | /// BOOL_SWITCH(flag, BoolConst, [&] { 14 | /// some_function(...); 15 | /// }); 16 | /// ``` 17 | #define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ 18 | [&] { \ 19 | if (COND) { \ 20 | constexpr static bool CONST_NAME = true; \ 21 | return __VA_ARGS__(); \ 22 | } else { \ 23 | constexpr static bool CONST_NAME = false; \ 24 | return __VA_ARGS__(); \ 25 | } \ 26 | }() 27 | 28 | #define FP16_SWITCH(COND, ...) \ 29 | [&] { \ 30 | if (COND) { \ 31 | using elem_type = cutlass::half_t; \ 32 | return __VA_ARGS__(); \ 33 | } else { \ 34 | using elem_type = cutlass::bfloat16_t; \ 35 | return __VA_ARGS__(); \ 36 | } \ 37 | }() 38 | 39 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ 40 | [&] { \ 41 | if (HEADDIM <= 32) { \ 42 | constexpr static int kHeadDim = 32; \ 43 | return __VA_ARGS__(); \ 44 | } else if (HEADDIM <= 64) { \ 45 | constexpr static int kHeadDim = 64; \ 46 | return __VA_ARGS__(); \ 47 | } else if (HEADDIM <= 96) { \ 48 | constexpr static int kHeadDim = 96; \ 49 | return __VA_ARGS__(); \ 50 | } else if (HEADDIM <= 128) { \ 51 | constexpr static int kHeadDim = 128; \ 52 | return __VA_ARGS__(); \ 53 | } else if (HEADDIM <= 160) { \ 54 | constexpr static int kHeadDim = 160; \ 55 | return __VA_ARGS__(); \ 56 | } else if (HEADDIM <= 192) { \ 57 | constexpr static int kHeadDim = 192; \ 58 | return __VA_ARGS__(); \ 59 | } else if (HEADDIM <= 224) { \ 60 | constexpr static int kHeadDim = 224; \ 61 | return __VA_ARGS__(); \ 62 | } else if (HEADDIM <= 256) { \ 63 | constexpr static int kHeadDim = 256; \ 64 | return __VA_ARGS__(); \ 65 | } \ 66 | }() 67 | 68 | 69 | #define WARP_SWITCH(COND, CONST_NAME, ...) \ 70 | [&] { \ 71 | if (COND == 4) { \ 72 | constexpr static int CONST_NAME = 4; \ 73 | return __VA_ARGS__(); \ 74 | } else if (COND == 8) { \ 75 | constexpr static int CONST_NAME = 8; \ 76 | return __VA_ARGS__(); \ 77 | } else { \ 78 | constexpr static int CONST_NAME = 2; \ 79 | return __VA_ARGS__(); \ 80 | } \ 81 | }() 82 | 83 | #define BLOCKM_SWITCH(COND, CONST_NAME, ...) \ 84 | [&] { \ 85 | if (COND == 64) { \ 86 | constexpr static int CONST_NAME = 64; \ 87 | return __VA_ARGS__(); \ 88 | } else if (COND == 128) { \ 89 | constexpr static int CONST_NAME = 128; \ 90 | return __VA_ARGS__(); \ 91 | } else if (COND == 256) { \ 92 | constexpr static int CONST_NAME = 256; \ 93 | return __VA_ARGS__(); \ 94 | } else { \ 95 | constexpr static int CONST_NAME = 64; \ 96 | return __VA_ARGS__(); \ 97 | } \ 98 | }() 99 | 100 | #define BLOCKN_SWITCH(COND, CONST_NAME, ...) \ 101 | [&] { \ 102 | if (COND == 32) { \ 103 | constexpr static int CONST_NAME = 32; \ 104 | return __VA_ARGS__(); \ 105 | } else if (COND == 64) { \ 106 | constexpr static int CONST_NAME = 64; \ 107 | return __VA_ARGS__(); \ 108 | } else if (COND == 128) { \ 109 | constexpr static int CONST_NAME = 128; \ 110 | return __VA_ARGS__(); \ 111 | } else if (COND == 256) { \ 112 | constexpr static int CONST_NAME = 256; \ 113 | return __VA_ARGS__(); \ 114 | } else { \ 115 | constexpr static int CONST_NAME = 64; \ 116 | return __VA_ARGS__(); \ 117 | } \ 118 | }() 119 | 120 | #define STAGE_SWITCH(COND, CONST_NAME, ...) 
\ 121 | [&] { \ 122 | if (COND == 2) { \ 123 | constexpr static int CONST_NAME = 2; \ 124 | return __VA_ARGS__(); \ 125 | } else if (COND == 3) { \ 126 | constexpr static int CONST_NAME = 3; \ 127 | return __VA_ARGS__(); \ 128 | } else if (COND == 4) { \ 129 | constexpr static int CONST_NAME = 4; \ 130 | return __VA_ARGS__(); \ 131 | } else if (COND == 5) { \ 132 | constexpr static int CONST_NAME = 5; \ 133 | return __VA_ARGS__(); \ 134 | } else { \ 135 | constexpr static int CONST_NAME = 2; \ 136 | return __VA_ARGS__(); \ 137 | } \ 138 | }() 139 | -------------------------------------------------------------------------------- /csrc/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | template 3 | struct MaxOp { 4 | __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } 5 | }; 6 | 7 | template <> 8 | struct MaxOp { 9 | // This is slightly faster 10 | __device__ inline float operator()(float const &x, float const &y) { return max(x, y); } 11 | }; 12 | 13 | //////////////////////////////////////////////////////////////////////////////////////////////////// 14 | 15 | template 16 | struct SumOp { 17 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 18 | }; 19 | 20 | //////////////////////////////////////////////////////////////////////////////////////////////////// 21 | 22 | template 23 | struct Allreduce { 24 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 25 | template 26 | static __device__ inline T run(T x, Operator &op) { 27 | constexpr int OFFSET = THREADS / 2; 28 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 29 | return Allreduce::run(x, op); 30 | } 31 | }; 32 | 33 | //////////////////////////////////////////////////////////////////////////////////////////////////// 34 | 35 | template<> 36 | struct Allreduce<2> { 37 | template 38 | static __device__ inline T run(T x, Operator &op) { 39 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 40 | return x; 41 | } 42 | }; 43 | 44 | // tensor:((2, MMA_M),(2, MMA_N)) 45 | // summary:(2 * MMA_N) 46 | template 47 | __device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { 48 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 49 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 50 | CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); 51 | #pragma unroll 52 | for (int mi = 0; mi < size<0>(tensor); mi++) { 53 | summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); 54 | #pragma unroll 55 | for (int ni = 1; ni < size<1>(tensor); ni++) { 56 | summary(mi) = op(summary(mi), tensor(mi, ni)); 57 | } 58 | } 59 | } 60 | 61 | // summary:(2 * MMA_N) 62 | // summary:(2 * MMA_N) 63 | template 64 | __device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { 65 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 66 | #pragma unroll 67 | for (int i = 0; i < size(dst); i++){ 68 | // NOTE: 4表示4个线程, 因为在SM80_16x8x16_F32F16F16F32_TN中, 69 | // 每组每行就是4个线程处理8个value的, 每个线程处理2个value, 70 | dst(i) = Allreduce<4>::run(src(i), op); 71 | } 72 | } 73 | 74 | // tensor:((2, MMA_M),(2, MMA_N)) 75 | // summary:(2 * MMA_N) 76 | template 77 | __device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { 78 | // NOTE: 遍历tensor每行, 记录到summary中 79 | // reduce 当前thread的max 80 | thread_reduce_(tensor, summary, op); 81 | // NOTE: 二分法对summary[]进行reduce 82 | // reduce thread间的max 83 | quad_allreduce_(summary, summary, op); 84 | } 85 | 86 | // scores:((2, MMA_M),(2, MMA_N)) 87 | // scores_max:(2 * MMA_N) 88 | template 89 | __device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ 90 | MaxOp max_op; 91 | reduce_(tensor, max, max_op); 92 | } 93 | 94 | template 95 | __device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ 96 | SumOp sum_op; 97 | reduce_(tensor, sum, sum_op); 98 | } 99 | 100 | -------------------------------------------------------------------------------- /include/attention_api.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // cuda header file that use nvcc to compile, which can recognize the cuda 4 | // keyword like __global__ and __device__ 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | // NOTE:tensor malloc as device before we call 11 | // e.g. 
data.to("cuda") in python 12 | #define CHECK_CUDA(x) \ 13 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | #define CUDA_ERROR_CHECK(condition) \ 21 | do { \ 22 | cudaError_t error = condition; \ 23 | if (error != cudaSuccess) { \ 24 | printf("CUDA_CHECK error in line %d of file %s \ 25 | : %s \n", \ 26 | __LINE__, __FILE__, cudaGetErrorString(error)); \ 27 | exit(EXIT_FAILURE); \ 28 | } \ 29 | } while (0) 30 | 31 | -------------------------------------------------------------------------------- /include/attention_api.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "flash.h" 9 | 10 | std::vector flash_attention_v2_cutlass(torch::Tensor q, torch::Tensor k, 11 | torch::Tensor v, bool is_causal = false, float softmax_scale=1); 12 | 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from packaging.version import parse, Version 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import ( 7 | BuildExtension, 8 | CUDAExtension, 9 | CUDA_HOME, 10 | ) 11 | 12 | # package name managed by pip, which can be remove by `pip uninstall tiny_pkg` 13 | PACKAGE_NAME = "tiny_attention_cutlass" 14 | 15 | ext_modules = [] 16 | generator_flag = [] 17 | cc_flag = [] 18 | cc_flag.append("-gencode") 19 | cc_flag.append("arch=compute_80,code=sm_80") 20 | 21 | 22 | # helper function to get cuda version 23 | def get_cuda_bare_metal_version(cuda_dir): 24 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 25 | output = raw_output.split() 26 | release_idx = output.index("release") + 1 27 | bare_metal_version = parse(output[release_idx].split(",")[0]) 28 | 29 | return raw_output, bare_metal_version 30 | 31 | 32 | if CUDA_HOME is not None: 33 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 34 | if bare_metal_version >= Version("11.8"): 35 | cc_flag.append("-gencode") 36 | cc_flag.append("arch=compute_90,code=sm_90") 37 | 38 | # ninja build does not work unless include_dirs are abs path 39 | this_dir = os.path.dirname(os.path.abspath(__file__)) 40 | 41 | # cuda module 42 | ext_modules.append( 43 | CUDAExtension( 44 | # package name for import 45 | name="attention_cutlass", 46 | sources=[ 47 | "csrc/attention_api.cpp", 48 | "csrc/flash_attention.cu", 49 | "csrc/flash_api.cpp", 50 | ], 51 | extra_compile_args={ 52 | # add c compile flags 53 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 54 | # add nvcc compile flags 55 | "nvcc": [ 56 | "-O3", 57 | "-std=c++17", 58 | "-U__CUDA_NO_HALF_OPERATORS__", 59 | "--use_fast_math", 60 | "-lineinfo", 61 | "--ptxas-options=-v", 62 | "--ptxas-options=-O2", 63 | "-U__CUDA_NO_HALF_OPERATORS__", 64 | "-U__CUDA_NO_HALF_CONVERSIONS__", 65 | "-U__CUDA_NO_HALF2_OPERATORS__", 66 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 67 | "--expt-relaxed-constexpr", 68 | "--expt-extended-lambda", 69 | "--use_fast_math", 70 | 71 | ] 72 | + generator_flag 73 | + cc_flag, 74 | }, 75 | include_dirs=[ 76 | Path(this_dir) / "csrc", 77 | Path(this_dir) / "include", 78 | Path(this_dir) / "deps/cutlass/include", 79 | 
Path(this_dir) / "deps/cutlass/tools/utils/include" , 80 | Path(this_dir) / "deps/cutlass/examples/common" , 81 | # Path(this_dir) / "some" / "thing" / "more", 82 | ], 83 | ) 84 | ) 85 | 86 | setup( 87 | name=PACKAGE_NAME, 88 | packages=find_packages( 89 | exclude=( 90 | "build", 91 | "csrc", 92 | "include", 93 | "tests", 94 | "dist", 95 | "docs", 96 | "benchmarks", 97 | ) 98 | ), 99 | description="Attention mechanism implement by CUDA", 100 | ext_modules=ext_modules, 101 | cmdclass={ "build_ext": BuildExtension}, 102 | python_requires=">=3.7", 103 | install_requires=[ 104 | "torch", 105 | "einops", 106 | "packaging", 107 | "ninja", 108 | ], 109 | ) 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from attention_cutlass import flash_attention_v2_cutlass 4 | import math 5 | import time 6 | # offical flash attention implement 7 | from vllm_flash_attn import flash_attn_func as flash_attn_func_offical 8 | 9 | ''' 10 | simple attention implement without multi head 11 | ''' 12 | 13 | torch.manual_seed(180) 14 | def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16): 15 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 16 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 17 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 18 | return q, k, v 19 | 20 | def self_attention(q, k, v, causal=True, sm_scale=1): 21 | SEQLEN = q.shape[-2] 22 | M = torch.tril(torch.ones((SEQLEN, SEQLEN), device="cuda")) 23 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 24 | if causal: 25 | p[:, :, M == 0] = float("-inf") 26 | p = torch.softmax(p.float(), dim=-1).half() 27 | ref_out = torch.matmul(p, v) 28 | return ref_out 29 | 30 | 31 | def run_benchmark(epoch, warmup, func, *args, **kwargs): 32 | # warmup phase 33 | for _ in range(warmup): 34 | _ = func(*args, **kwargs) 35 | torch.cuda.synchronize() 36 | time_s = time.time() 37 | for _ in range(epoch): 38 | _ = func(*args, **kwargs) 39 | torch.cuda.synchronize() 40 | time_e = time.time() - time_s 41 | return time_e 42 | 43 | 44 | def main(bs=1, head=64, seq_len=4096, dim=64): 45 | BS, HEAD, SEQLEN, DIM = bs, head, seq_len, dim 46 | q,k,v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16) 47 | 48 | warmup = 5 49 | epoch = 20 50 | 51 | is_causal = True 52 | sm_scale = 1.0 / math.sqrt(SEQLEN) 53 | 54 | 55 | base_time = run_benchmark(epoch, warmup, self_attention, q, k, v, causal=is_causal, sm_scale=sm_scale) 56 | baseline = self_attention(q, k, v, causal=is_causal, sm_scale=sm_scale) 57 | 58 | flash2_time = run_benchmark(epoch, warmup, flash_attention_v2_cutlass, q, k, v, is_causal, sm_scale) 59 | flash2_cutlass_ref = flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale)[0] 60 | 61 | fq = q.transpose(1, 2) 62 | fk = k.transpose(1, 2) 63 | fv = v.transpose(1, 2) 64 | official_ref_time = run_benchmark(epoch, warmup, flash_attn_func_offical, fq, fk, fv, causal=is_causal, softmax_scale=sm_scale) 65 | official_result = flash_attn_func_offical(fq, fk, fv, causal=is_causal, softmax_scale=sm_scale) 66 | 67 | print(f"bs:{bs}, head:{head}, seq_len:{seq_len}, dim:{dim} \ 68 | baseline:{base_time * 1000 / epoch} ms \ 69 | flash2_cutlass_fp16:{official_ref_time * 1000 / epoch} ms") 
70 | 71 | assert torch.allclose(baseline, flash2_cutlass_ref, rtol=0, atol=1e-2) 72 | 73 | 74 | if __name__ == "__main__": 75 | epoch = 1 76 | for _ in range(epoch): 77 | for bs in [1, 2]: 78 | for head in [8, 16, 32]: 79 | for seq_len in [64, 1024, 4096]: 80 | for dim in [32, 64]: 81 | 82 | main(bs, head, seq_len, dim) 83 | 84 | 85 | --------------------------------------------------------------------------------