├── flash_attention_cutlass ├── deps │ └── .gitignore ├── .gitignore ├── README.md ├── csrc │ ├── flash_api.cpp │ ├── attention_api.cpp │ ├── flash.h │ ├── utils.h │ ├── static_switch.h │ └── kernel_traits.h ├── makefile ├── include │ ├── attention_api.h │ └── attention_api.cuh ├── CMakeLists.txt ├── standalone_src │ ├── flash.h │ ├── utils.h │ └── kernel_traits.h ├── test.py └── build.py ├── flash_attention_py ├── .gitignore ├── makefile ├── main.py ├── tiny_flash_attn.py └── tiny_flash_attn_triton.py ├── flash_attention_cuda ├── makefile ├── .gitignore ├── standalone_src │ ├── makefile │ ├── helper.h │ ├── self_attention_standalone.cu │ ├── flash_attention_v2_standalone.cu │ └── flash_attention_v1_standalone.cu ├── include │ ├── attention_api.h │ └── attention_api.cuh ├── .clang-format ├── csrc │ ├── attention_api.cpp │ ├── static_switch.h │ ├── self_attention.cu │ └── flash_attention.cu ├── README.md ├── build.py ├── self_attention.py └── flash_attn_triton.py ├── .gitmodules ├── flash_attention_c ├── makefile ├── csrc │ ├── attn.h │ ├── ops.cu │ ├── ops.h │ ├── utils.h │ └── attn.cpp ├── CMakeLists.txt └── test.py ├── README.md ├── LICENSE ├── README_zh.md ├── cutlass_cute_tutorial_zh.md └── cutlass_cute_tutorial_en.md /flash_attention_cutlass/deps/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attention_py/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /flash_attention_py/makefile: -------------------------------------------------------------------------------- 1 | test: 2 | pytest -s tiny_flash_attn_triton.py 3 | -------------------------------------------------------------------------------- /flash_attention_cuda/makefile: -------------------------------------------------------------------------------- 1 | build: 2 | python build.py install 3 | 4 | .PHONY: build 5 | -------------------------------------------------------------------------------- /flash_attention_cutlass/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | .cache/ 3 | dist/ 4 | tiny_attention_cutlass.egg-info/ 5 | -------------------------------------------------------------------------------- /flash_attention_cutlass/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ``` 4 | # build standalone 5 | make 6 | 7 | # build python binding 8 | make build 9 | ``` 10 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/flash_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_api.h" 5 | #include "flash.h" 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "flash_attention_cutlass/deps/cutlass"] 2 | path = flash_attention_cutlass/deps/cutlass 3 | url = https://github.com/NVIDIA/cutlass 4 | branch = main 5 | -------------------------------------------------------------------------------- /flash_attention_cuda/.gitignore: -------------------------------------------------------------------------------- 1 | flash_attn_org/ 2 | 
tiny_attention_cuda.egg-info/ 3 | .clangd 4 | tiny_flash_attn.egg-info/ 5 | pytorch-cuda-binding-tutorial 6 | build/ 7 | dist/ 8 | -------------------------------------------------------------------------------- /flash_attention_c/makefile: -------------------------------------------------------------------------------- 1 | build: 2 | cmake -B build -D CMAKE_EXPORT_COMPILE_COMMANDS=ON \ 3 | -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` \ 4 | -Dpybind11_DIR=`pybind11-config --cmakedir` -DCMAKE_BUILD_TYPE=DEBUG 5 | cmake --build build 6 | 7 | .PHONY: build 8 | -------------------------------------------------------------------------------- /flash_attention_cutlass/makefile: -------------------------------------------------------------------------------- 1 | run_alone: build_standalone 2 | ./build/flash_attention_cutlass_standalone 3 | 4 | build_standalone: 5 | cmake -B build 6 | cmake --build build 7 | 8 | build: 9 | python build.py install 10 | 11 | dbg: build_standalone 12 | cuda-gdb ./build/main 13 | 14 | .PHONY: build 15 | 16 | -------------------------------------------------------------------------------- /flash_attention_c/csrc/attn.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | template 10 | struct Naive_fwd_traits { 11 | using elem_type = T; 12 | }; 13 | -------------------------------------------------------------------------------- /flash_attention_c/csrc/ops.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ops.h" 3 | 4 | PYBIND11_MODULE(_kernels, m) { 5 | m.def("hello_world", &hello_world, "hello_world placeholder"); 6 | m.def("naive_attn", &naive_attn, "Naive attention implementation on CPU"); 7 | m.def("flash_attn", &flash_attn, "Flash attention implementation on CPU"); 8 | } 9 | 10 | -------------------------------------------------------------------------------- /flash_attention_cutlass/include/attention_api.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "flash.h" 9 | 10 | std::vector flash_attention_v2_cutlass(torch::Tensor q, torch::Tensor k, 11 | torch::Tensor v, bool is_causal = false, float softmax_scale=1); 12 | 13 | -------------------------------------------------------------------------------- /flash_attention_cuda/standalone_src/makefile: -------------------------------------------------------------------------------- 1 | run: build 2 | ./self_attention_standalone 3 | ./flash_attention_v1_standalone 4 | ./flash_attention_v2_standalone 5 | 6 | build: 7 | nvcc -o self_attention_standalone self_attention_standalone.cu 8 | nvcc -o flash_attention_v1_standalone flash_attention_v1_standalone.cu 9 | nvcc -o flash_attention_v2_standalone flash_attention_v2_standalone.cu 10 | 11 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/attention_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_api.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // m.def("package_name", &function_name, "function_docstring"") 8 | m.def("flash_attention_v2_cutlass", &flash_attention_v2_cutlass, 9 | "Flash attention v2 implement in cutlass"); 10 | } 11 | 
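Once the extension is installed (`make build`, which runs `python build.py install`), the module defined above is importable as `attention_cutlass` (the extension name set in `flash_attention_cutlass/build.py`). Below is a minimal usage sketch modeled on `flash_attention_cutlass/test.py`; the tensor shapes and the `1/sqrt(head_dim)` softmax scale are illustrative choices, not requirements of the API.

```python
import math
import torch
from attention_cutlass import flash_attention_v2_cutlass

# Half-precision CUDA tensors with layout (batch, heads, seqlen, head_dim),
# the layout used by flash_attention_cutlass/test.py.
q = torch.randn(2, 8, 2048, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

is_causal = True
softmax_scale = 1.0 / math.sqrt(q.shape[-1])  # illustrative; test.py uses its own scale

# The binding returns a list of tensors; the first entry is the attention output.
out, *_ = flash_attention_v2_cutlass(q, k, v, is_causal, softmax_scale)
print(out.shape)  # torch.Size([2, 8, 2048, 64])
```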
-------------------------------------------------------------------------------- /flash_attention_cuda/include/attention_api.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | torch::Tensor self_attention_cuda(torch::Tensor q, torch::Tensor k, torch::Tensor v); 8 | torch::Tensor flash_attention_v1_cuda(torch::Tensor q, torch::Tensor k, torch::Tensor v); 9 | torch::Tensor flash_attention_v2_cuda(torch::Tensor q, torch::Tensor k, torch::Tensor v); 10 | 11 | -------------------------------------------------------------------------------- /flash_attention_c/csrc/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void hello_world() { 6 | std::cout << "Hello, World!" << std::endl; 7 | } 8 | 9 | // Naive attention impl 10 | torch::Tensor naive_attn(torch::Tensor q, torch::Tensor k, torch::Tensor v, bool is_causal = false, float softmax_scale=1); 11 | 12 | // Flash attention impl 13 | torch::Tensor flash_attn(torch::Tensor q, torch::Tensor k, torch::Tensor v, bool is_causal = false, float softmax_scale=1); 14 | 15 | -------------------------------------------------------------------------------- /flash_attention_cuda/.clang-format: -------------------------------------------------------------------------------- 1 | # BasedOnStyle: Google 2 | # DerivePointerAlignment: false 3 | # PointerAlignment: Right 4 | # ColumnLimit: 120 5 | 6 | # # Default for clang-8, changed in later clangs. Set explicitly for forwards 7 | # # compatibility for students with modern clangs 8 | # IncludeBlocks: Preserve 9 | --- 10 | # We'll use defaults from the LLVM style, but with 4 columns indentation. 11 | BasedOnStyle: LLVM 12 | IndentWidth: 4 13 | Language: Cpp 14 | # # Force pointers to the type for C++. 
15 | DerivePointerAlignment: false 16 | PointerAlignment: Left 17 | -------------------------------------------------------------------------------- /flash_attention_cuda/csrc/attention_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_api.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // m.def("function_name", &function, "function docstring") 8 | m.def("self_attention_cuda", &self_attention_cuda, 9 | "Naive self attention implemented in CUDA"); 10 | m.def("flash_attention_v1_cuda", &flash_attention_v1_cuda, 11 | "Flash attention v1 implemented in CUDA"); 12 | m.def("flash_attention_v2_cuda", &flash_attention_v2_cuda, 13 | "Flash attention v2 implemented in CUDA"); 14 | } 15 | -------------------------------------------------------------------------------- /flash_attention_cuda/standalone_src/helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | bool all_close(float *A, float *B, int m, int n) { 4 | for (int i = 0; i < m * n; i++) { 5 | if (fabs(A[i] - B[i]) > 1e-5) { 6 | printf("A[%d] = %f, B[%d] = %f\n", i, A[i], i, B[i]); 7 | return false; 8 | } 9 | } 10 | return true; 11 | } 12 | 13 | // print matrix 14 | void print_host_matrix(float *matrix, int m, int n) { 15 | for (int i = 0; i < m; i++) { 16 | for (int j = 0; j < n; j++) { 17 | printf("%f, ", matrix[i * n + j]); 18 | } 19 | printf("\n"); 20 | } 21 | } 22 | 23 | void print_device_matrix(float *dev_ptr, int m, int n) { 24 | float *host_ptr = new float[m * n]; 25 | cudaMemcpy(host_ptr, dev_ptr, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 26 | 27 | for (int i = 0; i < m; i++) { 28 | for (int j = 0; j < n; j++) { 29 | printf("%f, ", host_ptr[i * n + j]); 30 | } 31 | printf("\n"); 32 | } 33 | delete[] host_ptr; 34 | } 35 | 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tiny FlashAttention 2 | 3 | WIP 4 | 5 | A tiny [flash attention](https://github.com/Dao-AILab/flash-attention) implementation in Python, Rust, CUDA and C, for learning purposes. 6 | 7 | - [python version](#flash-attention-2) 8 | * [x] [naive pure python code](./flash_attention_py/tiny_flash_attn.py) 9 | - [triton version](#triton-flash-attention-2) 10 | * [x] [triton code](./flash_attention_py/tiny_flash_attn_triton.py) 11 | - [c version] 12 | * [x] [naive pure c code](./flash_attention_c/csrc/attn.cpp) 13 | * [x] [naive cuda code standalone](./flash_attention_cuda/standalone_src) 14 | * [x] [naive cuda code python binding](./flash_attention_cutlass/csrc/flash_attention.cu) 15 | * [x] [cutlass cuda code](./flash_attention_cutlass/csrc/flash_attention.cu) 16 | - [rust version] 17 | 18 | ## cutlass cute flash attention in action 19 | 20 | My environment: CUTLASS v3.4, torch 1.14, CUDA 12.4 21 | 22 | - [en tutorial](./cutlass_cute_tutorial_en.md) 23 | - [zh tutorial](./cutlass_cute_tutorial_zh.md) 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /flash_attention_cuda/README.md: -------------------------------------------------------------------------------- 1 | # flash attention implemented in CUDA 2 | 3 | NOTE: a specific PyTorch version is required to support the deprecated API. Just use the standalone version or the CUTLASS version.
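If the binding does build in your environment, the extension is imported as `attention_cuda` (the name registered in `build.py`). A minimal usage sketch, following the call pattern in `self_attention.py`; the batch and head sizes below are illustrative, while `(seqlen, dim) = (64, 64)` matches the configuration exercised in that script.

```python
import math
import torch
from attention_cuda import flash_attention_v2_cuda

# Half-precision CUDA tensors with layout (batch, heads, seqlen, head_dim),
# as created in self_attention.py.
q = torch.randn(4, 1, 64, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attention_v2_cuda(q, k, v)

# Plain PyTorch baseline for a sanity check. self_attention.py compares the kernels
# against a baseline that applies the usual 1/sqrt(dim) scaling, so do the same here.
p = torch.softmax((q / math.sqrt(q.shape[-1])) @ k.transpose(-2, -1), dim=-1)
assert torch.allclose(out, p @ v, rtol=0, atol=1e-2)
```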
4 | 5 | ## roadmap 6 | 7 | - [x] naive self attention python 8 | - [x] naive self attention cuda 9 | - [x] naive self attention python API binding 10 | - TODO: 11 | * half support 12 | * make template data type more general 13 | * thread balance and too many thread may cause crash 14 | * clean deprecated warning 15 | - [x] flash attention 1 cuda 16 | - [x] flash attention 2 cuda 17 | - [x] flash attention 1/2 python binding 18 | - [ ] split template and more general template(like dim and block size) 19 | - [x] MHA support 20 | - [ ] causal mode support 21 | - [ ] flash attention cute 22 | - [x] checkout `static_switch.h` in flash attention 23 | 24 | 25 | ## result 26 | 27 | - You need **result-oriented programming** in CUDA 28 | * e.g. for `C[x, y]` should from thread (x, y) 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /flash_attention_c/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26.4) # Specify your minimum CMake version 2 | 3 | set(CMAKE_C_STANDARD 17) 4 | set(CMAKE_CXX_STANDARD 17) 5 | set(CMAKE_CUDA_STANDARD 17) 6 | set(CMAKE_CUDA_ARCHITECTURES 89) 7 | 8 | project(_kernels LANGUAGES CUDA CXX) 9 | 10 | find_package(Python REQUIRED COMPONENTS Interpreter Development) 11 | find_package(Torch REQUIRED) 12 | find_package(pybind11 REQUIRED) 13 | find_package(OpenMP REQUIRED) 14 | find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") 15 | 16 | file(GLOB PYTORCH_SOURCES "csrc/*.cu" "csrc/*.c" "csrc/*.cpp") 17 | pybind11_add_module(_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES}) 18 | 19 | target_compile_definitions(_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check 20 | target_compile_options(_kernels PRIVATE $<$:--expt-extended-lambda --expt-relaxed-constexpr>) 21 | target_link_libraries(_kernels PRIVATE ${TORCH_LIBRARIES} Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY} OpenMP::OpenMP_CXX) 22 | 23 | -------------------------------------------------------------------------------- /flash_attention_cutlass/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26) 2 | project(cutlass CUDA CXX) 3 | 4 | # set environment PATH for cmake 5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 6 | set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0) 7 | set(CMAKE_COLOR_DIAGNOSTICS ON) 8 | 9 | find_package(CUDAToolkit REQUIRED) 10 | include_directories( 11 | "deps/cutlass/include" 12 | "include" 13 | ) 14 | 15 | file(GLOB CUDA_SOURCE_FILES "./standalone_src/*.cu") 16 | foreach(CUDA_SOURCE_FILE ${CUDA_SOURCE_FILES}) 17 | # NOTE: NAME_WE: name without extension 18 | # Extract the filename ${CUDA_SOURCE_FILE} without the extension to EXECUTABLE_NAME 19 | get_filename_component(EXECUTABLE_NAME ${CUDA_SOURCE_FILE} NAME_WE) 20 | 21 | # Create an executable for each source file 22 | add_executable(${EXECUTABLE_NAME} ${CUDA_SOURCE_FILE}) 23 | set_target_properties(${EXECUTABLE_NAME} PROPERTIES CXX_STANDARD 17 CUDA_ARCHITECTURES 80) 24 | # target_compile_options(${EXECUTABLE_NAME} PRIVATE -G -g) 25 | target_compile_options(${EXECUTABLE_NAME} PRIVATE -O3 -lineinfo) 26 | endforeach() 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 66RING 4 | 5 | Permission 
is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /flash_attention_cutlass/standalone_src/flash.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // TODO: 特种约束字段, e.g. __restrict__ 的效果 6 | struct Qkv_params { 7 | using index_t = uint32_t; 8 | // The QKV matrices. 9 | void *__restrict__ q_ptr; 10 | void *__restrict__ k_ptr; 11 | void *__restrict__ v_ptr; 12 | 13 | // // The stride between rows of the Q, K and V matrices. 14 | // index_t q_batch_stride; 15 | // index_t k_batch_stride; 16 | // index_t v_batch_stride; 17 | // // TODO: 18 | // index_t q_row_stride; 19 | // index_t k_row_stride; 20 | // index_t v_row_stride; 21 | // index_t q_head_stride; 22 | // index_t k_head_stride; 23 | // index_t v_head_stride; 24 | 25 | // The number of heads. 26 | int h, h_k; 27 | // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be 28 | // different from nheads (query). 29 | int h_h_k_ratio; // precompute h / h_k, 30 | 31 | bool is_bf16 = false; 32 | }; 33 | 34 | 35 | struct Flash_fwd_params : public Qkv_params { 36 | size_t bs; 37 | size_t head; 38 | size_t seqlen; 39 | size_t dim; 40 | 41 | size_t bs_stride; 42 | size_t head_stride; 43 | size_t seqlen_stride; 44 | size_t dim_stride; 45 | 46 | float softmax_scale; 47 | void *__restrict__ out_ptr; 48 | 49 | bool is_causal; 50 | }; 51 | 52 | -------------------------------------------------------------------------------- /flash_attention_cuda/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...) 
\ 4 | [&] { \ 5 | if (HEADDIM <= 32) { \ 6 | constexpr static int kHeadDim = 32; \ 7 | return __VA_ARGS__(); \ 8 | } else if (HEADDIM <= 64) { \ 9 | constexpr static int kHeadDim = 64; \ 10 | return __VA_ARGS__(); \ 11 | } else if (HEADDIM <= 96) { \ 12 | constexpr static int kHeadDim = 96; \ 13 | return __VA_ARGS__(); \ 14 | } else if (HEADDIM <= 128) { \ 15 | constexpr static int kHeadDim = 128; \ 16 | return __VA_ARGS__(); \ 17 | } else if (HEADDIM <= 160) { \ 18 | constexpr static int kHeadDim = 160; \ 19 | return __VA_ARGS__(); \ 20 | } else if (HEADDIM <= 192) { \ 21 | constexpr static int kHeadDim = 192; \ 22 | return __VA_ARGS__(); \ 23 | } else if (HEADDIM <= 224) { \ 24 | constexpr static int kHeadDim = 224; \ 25 | return __VA_ARGS__(); \ 26 | } else if (HEADDIM <= 256) { \ 27 | constexpr static int kHeadDim = 256; \ 28 | return __VA_ARGS__(); \ 29 | } \ 30 | }() 31 | 32 | -------------------------------------------------------------------------------- /flash_attention_c/csrc/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define DEBUG 1 4 | 5 | #ifdef DEBUG 6 | 7 | // NOTE:tensor malloc as device before we call 8 | // e.g. data.to("cuda") in python 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | #define CUDA_ERROR_CHECK(condition) \ 13 | do { \ 14 | cudaError_t error = condition; \ 15 | if (error != cudaSuccess) { \ 16 | printf("CUDA_CHECK error in line %d of file %s \ 17 | : %s \n", \ 18 | __LINE__, __FILE__, cudaGetErrorString(error)); \ 19 | exit(EXIT_FAILURE); \ 20 | } \ 21 | } while (0) 22 | 23 | #else // ifdef DEBUG 24 | 25 | #define CHECK_CUDA(x) do { } while (0) 26 | #define CHECK_CONTIGUOUS(x) do { } while (0) 27 | #define CHECK_INPUT(x) do { } while (0) 28 | #define CUDA_ERROR_CHECK(condition) do { condition; } while (0) 29 | 30 | #endif // !ifdef DEBUG 31 | 32 | 33 | -------------------------------------------------------------------------------- /flash_attention_cutlass/include/attention_api.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // cuda header file that use nvcc to compile, which can recognize the cuda 4 | // keyword like __global__ and __device__ 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | // NOTE:tensor malloc as device before we call 11 | // e.g. 
data.to("cuda") in python 12 | #define CHECK_CUDA(x) \ 13 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | #define CUDA_ERROR_CHECK(condition) \ 21 | do { \ 22 | cudaError_t error = condition; \ 23 | if (error != cudaSuccess) { \ 24 | printf("CUDA_CHECK error in line %d of file %s \ 25 | : %s \n", \ 26 | __LINE__, __FILE__, cudaGetErrorString(error)); \ 27 | exit(EXIT_FAILURE); \ 28 | } \ 29 | } while (0) 30 | 31 | -------------------------------------------------------------------------------- /flash_attention_cuda/include/attention_api.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // cuda header file that use nvcc to compile, which can recognize the cuda 4 | // keyword like __global__ and __device__ 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | // NOTE:tensor malloc as device before we call 11 | // e.g. data.to("cuda") in python 12 | #define CHECK_CUDA(x) \ 13 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | #define CUDA_ERROR_CHECK(condition) \ 21 | do { \ 22 | cudaError_t error = condition; \ 23 | if (error != cudaSuccess) { \ 24 | printf("CUDA_CHECK error in line %d of file %s \ 25 | : %s \n", \ 26 | __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError())); \ 27 | exit(EXIT_FAILURE); \ 28 | } \ 29 | } while (0) 30 | 31 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/flash.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // TODO: 特种约束字段, e.g. __restrict__ 的效果 6 | struct Qkv_params { 7 | using index_t = uint32_t; 8 | // The QKV matrices. 9 | void *__restrict__ q_ptr; 10 | void *__restrict__ k_ptr; 11 | void *__restrict__ v_ptr; 12 | 13 | // // The stride between rows of the Q, K and V matrices. 14 | // index_t q_batch_stride; 15 | // index_t k_batch_stride; 16 | // index_t v_batch_stride; 17 | // // TODO: 18 | // index_t q_row_stride; 19 | // index_t k_row_stride; 20 | // index_t v_row_stride; 21 | // index_t q_head_stride; 22 | // index_t k_head_stride; 23 | // index_t v_head_stride; 24 | 25 | bool is_bf16; 26 | }; 27 | 28 | 29 | struct Flash_fwd_params : public Qkv_params { 30 | size_t bs; 31 | size_t head; 32 | size_t q_seqlen; 33 | size_t dim; 34 | 35 | size_t k_head; 36 | size_t k_seqlen; 37 | 38 | // TODO: review the impl of flash 39 | // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be 40 | // different from nheads (query). 
41 | size_t h_h_k_ratio; // precompute head / k_head, 42 | size_t flat_seqlen; 43 | size_t kv_head_stride; 44 | size_t qo_head_stride; 45 | 46 | 47 | size_t bs_stride; 48 | size_t head_stride; 49 | size_t seqlen_stride; 50 | size_t dim_stride; 51 | 52 | float softmax_scale; 53 | float softmax_scale_log2; 54 | void *__restrict__ out_ptr; 55 | void *__restrict__ softmax_lse_ptr; 56 | void *__restrict__ score_max; 57 | void *__restrict__ score_sum; 58 | 59 | bool is_causal; 60 | }; 61 | 62 | -------------------------------------------------------------------------------- /flash_attention_c/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import math 4 | from flash_attn import flash_attn_func 5 | 6 | from build._kernels import naive_attn, flash_attn 7 | 8 | @torch.inference_mode() 9 | def ref_attn(q, k, v, causal=True, sm_scale=1): 10 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 11 | qlen = q.shape[-2] 12 | klen = k.shape[-2] 13 | if causal: 14 | gap = klen - qlen 15 | for i in range(qlen): 16 | p[:, :, i, i + gap + 1:] = float("-inf") 17 | p = torch.softmax(p.float(), dim=-1).to(q.dtype) 18 | ref_out = torch.matmul(p, v) 19 | return ref_out 20 | 21 | def run_benchmark(epoch, warmup, func, *args, **kwargs): 22 | # warmup phase 23 | for _ in range(warmup): 24 | _ = func(*args, **kwargs) 25 | torch.cuda.synchronize() 26 | time_s = time.time() 27 | for _ in range(epoch): 28 | _ = func(*args, **kwargs) 29 | torch.cuda.synchronize() 30 | time_e = time.time() - time_s 31 | return time_e 32 | 33 | 34 | def main(): 35 | torch.manual_seed(0) 36 | 37 | bs, head_num, seqlen, head_dim = 3, 32, 128, 128 38 | q = torch.rand(bs, head_num, seqlen, head_dim, dtype=torch.float32, device="cpu") 39 | k = torch.rand(bs, head_num, seqlen, head_dim, dtype=torch.float32, device="cpu") 40 | v = torch.rand(bs, head_num, seqlen, head_dim, dtype=torch.float32, device="cpu") 41 | is_causal = True 42 | softmax_scale = 1 / math.sqrt(head_dim) 43 | 44 | warmup = 10 45 | epoch = 10 46 | 47 | naive_out = naive_attn(q, k, v, is_causal, softmax_scale) 48 | fa_out = flash_attn(q, k, v, is_causal, softmax_scale) 49 | 50 | naive_time = run_benchmark(epoch, warmup, naive_attn, q, k, v, is_causal, softmax_scale) 51 | fa_time = run_benchmark(epoch, warmup, flash_attn, q, k, v, is_causal, softmax_scale) 52 | 53 | # warmup 54 | q = q.to("cuda") 55 | k = k.to("cuda") 56 | v = v.to("cuda") 57 | ref_time = run_benchmark(epoch, warmup, ref_attn, q, k, v, is_causal, softmax_scale) 58 | ref_out = ref_attn(q, k, v, is_causal, softmax_scale).cpu() 59 | 60 | 61 | q = q.to(torch.bfloat16) 62 | k = k.to(torch.bfloat16) 63 | v = v.to(torch.bfloat16) 64 | q = q.transpose(1, 2) 65 | k = k.transpose(1, 2) 66 | v = v.transpose(1, 2) 67 | fa_ref = flash_attn_func(q, k, v, causal=is_causal, softmax_scale=softmax_scale).transpose(1, 2) 68 | fa_ref_time = run_benchmark(epoch, warmup, flash_attn_func, q, k, v, causal=is_causal, softmax_scale=softmax_scale) 69 | 70 | print(f"naive CPU time: {naive_time:.3f} s") 71 | print(f"flash CPU time: {fa_time:.3f} s") 72 | print(f"torch NAIVE time: {ref_time:.3f} s") 73 | 74 | print(f"naive_out: {naive_out} {naive_out.shape}") 75 | print("-----") 76 | print(f"fa_out: {fa_out} {fa_out.shape}") 77 | print("-----") 78 | print(f"fa_ref: {fa_ref} {fa_ref.shape}") 79 | print("-----") 80 | print(f"ref_out: {ref_out} {ref_out.shape}") 81 | 82 | assert torch.allclose(fa_out, ref_out, atol=1e-2) 83 | assert torch.allclose(naive_out, ref_out, 
atol=1e-2) 84 | 85 | 86 | 87 | 88 | 89 | pass 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /flash_attention_cuda/build.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from packaging.version import parse, Version 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import ( 7 | BuildExtension, 8 | CUDAExtension, 9 | CUDA_HOME, 10 | ) 11 | 12 | # package name managed by pip, which can be remove by `pip uninstall tiny_pkg` 13 | PACKAGE_NAME = "tiny_attention_cuda" 14 | 15 | ext_modules = [] 16 | generator_flag = [] 17 | cc_flag = [] 18 | cc_flag.append("-gencode") 19 | cc_flag.append("arch=compute_80,code=sm_80") 20 | 21 | 22 | # helper function to get cuda version 23 | def get_cuda_bare_metal_version(cuda_dir): 24 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 25 | output = raw_output.split() 26 | release_idx = output.index("release") + 1 27 | bare_metal_version = parse(output[release_idx].split(",")[0]) 28 | 29 | return raw_output, bare_metal_version 30 | 31 | 32 | if CUDA_HOME is not None: 33 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 34 | if bare_metal_version >= Version("11.8"): 35 | cc_flag.append("-gencode") 36 | cc_flag.append("arch=compute_90,code=sm_90") 37 | 38 | # ninja build does not work unless include_dirs are abs path 39 | this_dir = os.path.dirname(os.path.abspath(__file__)) 40 | 41 | # cuda module 42 | ext_modules.append( 43 | CUDAExtension( 44 | # package name for import 45 | name="attention_cuda", 46 | sources=[ 47 | "csrc/attention_api.cpp", 48 | "csrc/flash_attention.cu", 49 | "csrc/self_attention.cu", 50 | ], 51 | extra_compile_args={ 52 | # add c compile flags 53 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 54 | # add nvcc compile flags 55 | "nvcc": [ 56 | "-O3", 57 | "-std=c++17", 58 | # TODO: half支持是因为这个? 
59 | "-U__CUDA_NO_HALF_OPERATORS__", 60 | "--use_fast_math", 61 | "-lineinfo", 62 | # "--ptxas-options=-v", 63 | # "--ptxas-options=-O2", 64 | ] 65 | + generator_flag 66 | + cc_flag, 67 | }, 68 | include_dirs=[ 69 | Path(this_dir) / "csrc", 70 | Path(this_dir) / "include", 71 | # Path(this_dir) / "some" / "thing" / "more", 72 | ], 73 | ) 74 | ) 75 | 76 | setup( 77 | name=PACKAGE_NAME, 78 | packages=find_packages( 79 | exclude=( 80 | "build", 81 | "csrc", 82 | "include", 83 | "tests", 84 | "dist", 85 | "docs", 86 | "benchmarks", 87 | ) 88 | ), 89 | description="Attention mechanism implement by CUDA", 90 | ext_modules=ext_modules, 91 | cmdclass={ "build_ext": BuildExtension}, 92 | python_requires=">=3.7", 93 | install_requires=[ 94 | "torch", 95 | "einops", 96 | "packaging", 97 | "ninja", 98 | ], 99 | ) 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /flash_attention_cutlass/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from attention_cutlass import flash_attention_v2_cutlass 4 | import math 5 | import time 6 | # offical flash attention implement 7 | from flash_attn import flash_attn_func as flash_attn_func_offical 8 | 9 | ''' 10 | simple attention implement without multi head 11 | ''' 12 | 13 | def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16): 14 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 15 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 16 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 17 | return q, k, v 18 | 19 | def self_attention(q, k, v, causal=True, sm_scale=1): 20 | SEQLEN = q.shape[-2] 21 | M = torch.tril(torch.ones((SEQLEN, SEQLEN), device="cuda")) 22 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 23 | if causal: 24 | p[:, :, M == 0] = float("-inf") 25 | p = torch.softmax(p.float(), dim=-1).half() 26 | ref_out = torch.matmul(p, v) 27 | return ref_out 28 | 29 | 30 | def run_benchmark(epoch, warmup, func, *args, **kwargs): 31 | # warmup phase 32 | for _ in range(warmup): 33 | _ = func(*args, **kwargs) 34 | torch.cuda.synchronize() 35 | time_s = time.time() 36 | for _ in range(epoch): 37 | _ = func(*args, **kwargs) 38 | torch.cuda.synchronize() 39 | time_e = time.time() - time_s 40 | return time_e 41 | 42 | 43 | def main(): 44 | # classic config 45 | # batch_size = 4 46 | # num head = [32, 16, 8] 47 | # seqlen = 4096 48 | # head dim = [64, 128, 256] 49 | 50 | # BS, HEAD, SEQLEN, DIM = 1, 2, 4 * 1024, 64 51 | BS, HEAD, SEQLEN, DIM = 2, 8, 2 * 1024, 64 52 | 53 | q,k,v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16) 54 | # q = (torch.arange(SEQLEN * DIM, device="cuda").reshape(BS, HEAD, SEQLEN, DIM) * 0.0001).half() 55 | # k = (torch.arange(SEQLEN * DIM, device="cuda").reshape(BS, HEAD, SEQLEN, DIM) * 0.0001).half() 56 | # v = (torch.arange(SEQLEN * DIM, device="cuda").reshape(BS, HEAD, SEQLEN, DIM) * 0.0001).half() 57 | 58 | warmup = 10 59 | epoch = 10 60 | 61 | debug_mode = False 62 | is_causal = True 63 | sm_scale = 1.0 / math.sqrt(SEQLEN); 64 | # sm_scale = 1.0 65 | 66 | base_time = run_benchmark(epoch, warmup, self_attention, q, k, v, causal=is_causal, sm_scale=sm_scale) 67 | print("baseline: \n", base_time * 1000 / epoch) 68 | flash2_time = run_benchmark(epoch, warmup, flash_attention_v2_cutlass, 
q, k, v, is_causal, sm_scale) 69 | print("flash2_cutlass_ref: \n", flash2_time * 1000 / epoch) 70 | 71 | fq = q.transpose(1, 2) 72 | fk = k.transpose(1, 2) 73 | fv = v.transpose(1, 2) 74 | 75 | official_ref_time = run_benchmark(epoch, warmup, flash_attn_func_offical, fq, fk, fv, causal=is_causal, softmax_scale=sm_scale) 76 | print("official_ref: \n", official_ref_time * 1000 / epoch) 77 | 78 | 79 | baseline = self_attention(q, k, v, causal=is_causal, sm_scale=sm_scale) 80 | flash2_cutlass_ref, _ = flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale) 81 | official_result = flash_attn_func_offical(fq, fk, fv, causal=is_causal, softmax_scale=sm_scale) 82 | 83 | # print(baseline) 84 | # print(flash2_cutlass_ref) 85 | # print(official_result) 86 | 87 | assert torch.allclose(baseline, flash2_cutlass_ref, rtol=0, atol=1e-2) 88 | 89 | 90 | if __name__ == "__main__": 91 | epoch = 1 92 | for _ in range(epoch): 93 | main() 94 | 95 | 96 | -------------------------------------------------------------------------------- /flash_attention_cutlass/build.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from packaging.version import parse, Version 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import ( 7 | BuildExtension, 8 | CUDAExtension, 9 | CUDA_HOME, 10 | ) 11 | 12 | # package name managed by pip, which can be remove by `pip uninstall tiny_pkg` 13 | PACKAGE_NAME = "tiny_attention_cutlass" 14 | 15 | ext_modules = [] 16 | generator_flag = [] 17 | cc_flag = [] 18 | cc_flag.append("-gencode") 19 | cc_flag.append("arch=compute_80,code=sm_80") 20 | 21 | 22 | # helper function to get cuda version 23 | def get_cuda_bare_metal_version(cuda_dir): 24 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 25 | output = raw_output.split() 26 | release_idx = output.index("release") + 1 27 | bare_metal_version = parse(output[release_idx].split(",")[0]) 28 | 29 | return raw_output, bare_metal_version 30 | 31 | 32 | if CUDA_HOME is not None: 33 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 34 | if bare_metal_version >= Version("11.8"): 35 | cc_flag.append("-gencode") 36 | cc_flag.append("arch=compute_90,code=sm_90") 37 | 38 | # ninja build does not work unless include_dirs are abs path 39 | this_dir = os.path.dirname(os.path.abspath(__file__)) 40 | 41 | # cuda module 42 | ext_modules.append( 43 | CUDAExtension( 44 | # package name for import 45 | name="attention_cutlass", 46 | sources=[ 47 | "csrc/attention_api.cpp", 48 | "csrc/flash_attention.cu", 49 | "csrc/flash_api.cpp", 50 | ], 51 | extra_compile_args={ 52 | # add c compile flags 53 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 54 | # add nvcc compile flags 55 | "nvcc": [ 56 | "-O3", 57 | "-std=c++17", 58 | "-U__CUDA_NO_HALF_OPERATORS__", 59 | "--use_fast_math", 60 | "-lineinfo", 61 | "--ptxas-options=-v", 62 | "--ptxas-options=-O2", 63 | "-U__CUDA_NO_HALF_OPERATORS__", 64 | "-U__CUDA_NO_HALF_CONVERSIONS__", 65 | "-U__CUDA_NO_HALF2_OPERATORS__", 66 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 67 | "--expt-relaxed-constexpr", 68 | "--expt-extended-lambda", 69 | "--use_fast_math", 70 | 71 | ] 72 | + generator_flag 73 | + cc_flag, 74 | }, 75 | include_dirs=[ 76 | Path(this_dir) / "csrc", 77 | Path(this_dir) / "include", 78 | Path(this_dir) / "deps/cutlass/include", 79 | Path(this_dir) / "deps/cutlass/tools/utils/include" , 80 | Path(this_dir) / 
"deps/cutlass/examples/common" , 81 | # Path(this_dir) / "some" / "thing" / "more", 82 | ], 83 | ) 84 | ) 85 | 86 | setup( 87 | name=PACKAGE_NAME, 88 | packages=find_packages( 89 | exclude=( 90 | "build", 91 | "csrc", 92 | "include", 93 | "tests", 94 | "dist", 95 | "docs", 96 | "benchmarks", 97 | ) 98 | ), 99 | description="Attention mechanism implement by CUDA", 100 | ext_modules=ext_modules, 101 | cmdclass={ "build_ext": BuildExtension}, 102 | python_requires=">=3.7", 103 | install_requires=[ 104 | "torch", 105 | "einops", 106 | "packaging", 107 | "ninja", 108 | ], 109 | ) 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | template 3 | struct MaxOp { 4 | __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } 5 | }; 6 | 7 | template <> 8 | struct MaxOp { 9 | // This is slightly faster 10 | __device__ inline float operator()(float const &x, float const &y) { return max(x, y); } 11 | }; 12 | 13 | //////////////////////////////////////////////////////////////////////////////////////////////////// 14 | 15 | template 16 | struct SumOp { 17 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 18 | }; 19 | 20 | //////////////////////////////////////////////////////////////////////////////////////////////////// 21 | 22 | template 23 | struct Allreduce { 24 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 25 | template 26 | static __device__ inline T run(T x, Operator &op) { 27 | constexpr int OFFSET = THREADS / 2; 28 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 29 | return Allreduce::run(x, op); 30 | } 31 | }; 32 | 33 | //////////////////////////////////////////////////////////////////////////////////////////////////// 34 | 35 | template<> 36 | struct Allreduce<2> { 37 | template 38 | static __device__ inline T run(T x, Operator &op) { 39 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 40 | return x; 41 | } 42 | }; 43 | 44 | template 45 | __device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { 46 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 47 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 48 | CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); 49 | #pragma unroll 50 | for (int mi = 0; mi < size<0>(tensor); mi++) { 51 | summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); 52 | #pragma unroll 53 | for (int ni = 1; ni < size<1>(tensor); ni++) { 54 | summary(mi) = op(summary(mi), tensor(mi, ni)); 55 | } 56 | } 57 | } 58 | 59 | template 60 | __device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { 61 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 62 | #pragma unroll 63 | for (int i = 0; i < size(dst); i++){ 64 | // NOTE: 4表示4个线程, 因为在SM80_16x8x16_F32F16F16F32_TN中, 65 | // 每组每行就是4个线程处理8个value的, 每个线程处理2个value 66 | dst(i) = Allreduce<4>::run(src(i), op); 67 | } 68 | } 69 | 70 | template 71 | __device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { 72 | // NOTE: 遍历tensor每行, 记录到summary中 73 | // reduce 当前thread的max 74 | thread_reduce_(tensor, summary, op); 75 | // NOTE: 二分法对summary[]进行reduce 76 | // reduce thread间的max 77 | quad_allreduce_(summary, summary, op); 78 | } 79 | 80 | 81 | template 82 | __device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ 83 | MaxOp max_op; 84 | reduce_(tensor, max, max_op); 85 | } 86 | 87 | template 88 | __device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ 89 | SumOp sum_op; 90 | reduce_(tensor, sum, sum_op); 91 | } 92 | 93 | -------------------------------------------------------------------------------- /flash_attention_cutlass/standalone_src/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | template 3 | struct MaxOp { 4 | __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } 5 | }; 6 | 7 | template <> 8 | struct MaxOp { 9 | // This is slightly faster 10 | __device__ inline float operator()(float const &x, float const &y) { return max(x, y); } 11 | }; 12 | 13 | //////////////////////////////////////////////////////////////////////////////////////////////////// 14 | 15 | template 16 | struct SumOp { 17 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 18 | }; 19 | 20 | //////////////////////////////////////////////////////////////////////////////////////////////////// 21 | 22 | template 23 | struct Allreduce { 24 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 25 | template 26 | static __device__ inline T run(T x, Operator &op) { 27 | constexpr int OFFSET = THREADS / 2; 28 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 29 | return Allreduce::run(x, op); 30 | } 31 | }; 32 | 33 | //////////////////////////////////////////////////////////////////////////////////////////////////// 34 | 35 | template<> 36 | struct Allreduce<2> { 37 | template 38 | static __device__ inline T run(T x, Operator &op) { 39 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 40 | return x; 41 | } 42 | }; 43 | 44 | template 45 | __device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { 46 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 47 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 48 | CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); 49 | #pragma unroll 50 | for (int mi = 0; mi < size<0>(tensor); mi++) { 51 | summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); 52 | #pragma unroll 53 | for (int ni = 1; ni < size<1>(tensor); ni++) { 54 | summary(mi) = op(summary(mi), tensor(mi, ni)); 55 | } 56 | } 57 | } 58 | 59 | template 60 | __device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { 61 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 62 | #pragma unroll 63 | for (int i = 0; i < size(dst); i++){ 64 | // NOTE: 4表示4个线程, 因为在SM80_16x8x16_F32F16F16F32_TN中, 65 | // 每组每行就是4个线程处理8个value的, 每个线程处理2个value 66 | dst(i) = Allreduce<4>::run(src(i), op); 67 | } 68 | } 69 | 70 | template 71 | __device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { 72 | // NOTE: 遍历tensor每行, 记录到summary中 73 | // reduce 当前thread的max 74 | thread_reduce_(tensor, summary, op); 75 | // NOTE: 二分法对summary[]进行reduce 76 | // reduce thread间的max 77 | quad_allreduce_(summary, summary, op); 78 | } 79 | 80 | 81 | template 82 | __device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ 83 | MaxOp max_op; 84 | reduce_(tensor, max, max_op); 85 | } 86 | 87 | template 88 | __device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ 89 | SumOp sum_op; 90 | reduce_(tensor, sum, sum_op); 91 | } 92 | 93 | -------------------------------------------------------------------------------- /flash_attention_py/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import tiny_flash_attn 4 | from tiny_flash_attn_triton import flash_attn_triton, ref_attn 5 | 6 | from flash_attn import flash_attn_func as flash_attn_func_cuda 7 | 8 | class BaseAttention: 9 | def __init__(self): 10 | pass 11 | 12 | def attention(self, q, k, v): 13 | s = self.softmax(q @ k.T, dim=1) 14 | return s @ v 15 | 16 | def softmax(self, input, dim): 17 | raise "unimplement" 18 | 19 | class NativeAttention(BaseAttention): 20 | def softmax(self, input, dim): 21 | return torch.softmax(input, dim) 22 | 23 | class SafeAttention(BaseAttention): 24 | def softmax(self, input, dim): 25 | ''' 26 | softmax with safe 27 | ''' 28 | row_max = torch.max(input, dim=dim).values[:, None] 29 | # read++ 30 | input_safe = input - row_max 31 | softmax_numerator = torch.exp(input_safe) 32 | # read++ 33 | softmax_denominator = torch.sum(softmax_numerator, dim=1)[:, None] 34 | # read++ 35 | return softmax_numerator / softmax_denominator 36 | 37 | class OnlineSafeAttention(BaseAttention): 38 | ''' 39 | The tiny flash attention implement 40 | ''' 41 | def __init__(self): 42 | self.BLOCK_M = 4 43 | 44 | def attention(self, q, k, v, device='cuda'): 45 | return tiny_flash_attn.flash_attn(q, k, v, device, self.BLOCK_M) 46 | 47 | def attention_v1(self, q, k, v, device='cuda'): 48 | return tiny_flash_attn.flash_attn_v1(q, k, v, device, self.BLOCK_M) 49 | 50 | def attention_v2(self, q, k, v, device='cuda'): 51 | return tiny_flash_attn.flash_attn_v2(q, k, v, device, self.BLOCK_M) 52 | 53 | def attention_v2_multihead(self, q, k, v, device='cuda'): 54 | return tiny_flash_attn.flash_attn_v2_multihead(q, k, v, device, self.BLOCK_M) 55 | 56 | def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16): 57 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 58 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 59 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 60 | return q, k, v 61 | 62 | def 
main(): 63 | BS, HEAD, SEQLEN, DIM = 1, 1, 128, 64 64 | 65 | q,k,v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16) 66 | # softmax_scale = math.sqrt(q.shape[-1]) 67 | # q /= softmax_scale 68 | 69 | native = NativeAttention() 70 | safe = SafeAttention() 71 | online = OnlineSafeAttention() 72 | 73 | native_result = native.attention(torch.squeeze(q), torch.squeeze(k), torch.squeeze(v)) 74 | safe_result = safe.attention(torch.squeeze(q), torch.squeeze(k), torch.squeeze(v)) 75 | online_result1 = online.attention_v1(torch.squeeze(q), torch.squeeze(k), torch.squeeze(v)) 76 | online_result2 = online.attention_v2(torch.squeeze(q), torch.squeeze(k), torch.squeeze(v)) 77 | online_result2_multi = online.attention_v2_multihead(q, k, v) 78 | 79 | causal=False 80 | ref_result = ref_attn(q, k, v, causal=causal) 81 | triton_result = flash_attn_triton(q, k, v, causal=causal) 82 | official_result = flash_attn_func_cuda(q, k, v, causal=causal) 83 | 84 | # print(native_result) 85 | # print(safe_result) 86 | # print(online_result1) 87 | # print(online_result2) 88 | # print(online_result2_multi) 89 | print(ref_result) 90 | print(triton_result) 91 | print(official_result) 92 | 93 | # Assert attention output is same. 94 | # But it may have precision loss compared with native. 95 | assert torch.allclose(native_result, safe_result, rtol=0, atol=1e-2) 96 | assert torch.allclose(safe_result, online_result1.half(), rtol=0, atol=1e-2) 97 | assert torch.allclose(online_result2, online_result1, rtol=0, atol=1e-2) 98 | assert torch.allclose(online_result2_multi, online_result1, rtol=0, atol=1e-2) 99 | assert torch.allclose(triton_result, ref_result, rtol=0, atol=1e-2) 100 | 101 | # assert torch.allclose(official_result, triton_result, rtol=0, atol=1e-2) 102 | # assert torch.allclose(official_result, ref_result, rtol=0, atol=1e-2) 103 | 104 | if __name__ == "__main__": 105 | main() 106 | 107 | -------------------------------------------------------------------------------- /flash_attention_cuda/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from attention_cuda import self_attention_cuda, flash_attention_v1_cuda, flash_attention_v2_cuda 4 | from flash_attn_triton import flash_attn_triton, ref_attn 5 | import math 6 | import time 7 | # offical flash attention implement 8 | from flash_attn import flash_attn_func as flash_attn_func_offical 9 | 10 | ''' 11 | simple attention implement without multi head 12 | ''' 13 | 14 | def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16): 15 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 16 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 17 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 18 | return q, k, v 19 | 20 | def self_attention(q, k, v): 21 | # TODO: extract sm_scale 22 | q = q / math.sqrt(q.shape[-1]) 23 | score = torch.matmul(q, k.transpose(-2, -1)) 24 | s = torch.softmax(score, dim=-1) 25 | return s @ v 26 | 27 | 28 | def main(): 29 | BS, HEAD, SEQLEN, DIM = 1000, 1, 64, 64 30 | 31 | q,k,v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16) 32 | # q = torch.arange(16, dtype=torch.float32, device="cuda").reshape(1, 1, 4,4) 33 | # k = torch.arange(16, dtype=torch.float32, device="cuda").reshape(1, 1, 4,4) 34 | # v = torch.arange(16, dtype=torch.float32, 
device="cuda").reshape(1, 1, 4,4) 35 | warmup = 10 36 | epoch = 10 37 | sm_scale = 1.0 / math.sqrt(q.shape[-1]) 38 | 39 | for _ in range(warmup): 40 | # _ = self_attention_cuda(q, k, v) 41 | _ = flash_attention_v2_cuda(q, k, v) 42 | 43 | torch.cuda.synchronize() 44 | base_time = time.time() 45 | for _ in range(epoch): 46 | baseline = self_attention(q, k, v) 47 | torch.cuda.synchronize() 48 | base_time = time.time() - base_time 49 | 50 | # naive_time = time.time() 51 | # for _ in range(epoch): 52 | # naive_cuda_ref = self_attention_cuda(q, k, v) 53 | # torch.cuda.synchronize() 54 | # naive_time = time.time() - naive_time 55 | 56 | # flash1_time = time.time() 57 | # for _ in range(epoch): 58 | # flash1_cuda_ref = flash_attention_v1_cuda(q, k, v) 59 | # torch.cuda.synchronize() 60 | # flash1_time = time.time() - flash1_time 61 | 62 | torch.cuda.synchronize() 63 | flash2_time = time.time() 64 | for _ in range(epoch): 65 | flash2_cuda_ref = flash_attention_v2_cuda(q, k, v) 66 | torch.cuda.synchronize() 67 | flash2_time = time.time() - flash2_time 68 | 69 | flash_triton_time = time.time() 70 | for _ in range(epoch): 71 | flash_triton_ref = flash_attn_triton(q, k, v, causal=False, sm_scale=sm_scale).half() 72 | torch.cuda.synchronize() 73 | flash_triton_time = time.time() - flash_triton_time 74 | 75 | torch.cuda.synchronize() 76 | official_ref_time = time.time() 77 | for _ in range(epoch): 78 | official_result = flash_attn_func_offical(q, k, v, causal=False, softmax_scale=sm_scale) 79 | torch.cuda.synchronize() 80 | official_ref_time = time.time() - official_ref_time 81 | 82 | 83 | # print time in ms 84 | print("baseline: \n", base_time * 1000 / epoch) 85 | # print("naive_cuda_ref: \n", naive_time * 1000 / epoch) 86 | # print("flash1_cuda_ref: \n", flash1_time * 1000 / epoch) 87 | print("flash2_cuda_ref: \n", flash2_time * 1000 / epoch) 88 | print("flash_triton_ref: \n", flash_triton_time * 1000 / epoch) 89 | print("official_ref: \n", official_ref_time * 1000 / epoch) 90 | 91 | print("baseline: \n", baseline) 92 | # print("naive_cuda_ref: \n", naive_cuda_ref) 93 | # print("flash1_cuda_ref: \n", flash1_cuda_ref) 94 | print("flash2_cuda_ref: \n", flash2_cuda_ref) 95 | # print("flash_triton_ref: \n", flash_triton_ref) 96 | print("official_ref: \n", official_result) 97 | 98 | # assert torch.allclose(baseline, naive_cuda_ref, rtol=0, atol=1e-2) 99 | # assert torch.allclose(baseline, flash1_cuda_ref, rtol=0, atol=1e-2) 100 | assert torch.allclose(baseline, flash2_cuda_ref, rtol=0, atol=1e-2) 101 | # assert torch.allclose(baseline, official_result, rtol=0, atol=1e-2) 102 | 103 | if __name__ == "__main__": 104 | epoch = 1 105 | for _ in range(epoch): 106 | main() 107 | 108 | -------------------------------------------------------------------------------- /flash_attention_cuda/csrc/self_attention.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "attention_api.cuh" 10 | 11 | template 12 | __global__ void naive_nrow_gemm(scalar_t *A, scalar_t *B, scalar_t *C, scalar_t a, scalar_t b, 13 | int64_t M, int64_t N, int64_t K, int64_t mBlock); 14 | 15 | template 16 | __global__ void naive_pv(scalar_t *P, scalar_t *V, scalar_t *O, int M, int N, int mBlock); 17 | 18 | template 19 | __global__ void row_softmax(scalar_t *input, scalar_t *output, int n); 20 | 21 | // TODO: Add support for half 22 | torch::Tensor self_attention_cuda(torch::Tensor q, torch::Tensor k, 
torch::Tensor v) { 23 | CHECK_INPUT(q); 24 | CHECK_INPUT(k); 25 | CHECK_INPUT(v); 26 | 27 | auto out = torch::zeros_like(q); 28 | // TODO: multihead 29 | // seqlen 30 | auto m = q.size(0); 31 | // dim 32 | auto n = q.size(1); 33 | 34 | int64_t mBlock = 2; 35 | assert(m % mBlock == 0 && "mBlock should align"); 36 | 37 | float sm_scale = 1.f / sqrtf(static_cast(n)); 38 | 39 | // Create intermediate date for the new shape of qk 40 | torch::TensorOptions options = q.options(); 41 | std::vector shape = {m, m}; 42 | torch::Tensor qk = torch::empty(shape, options); 43 | 44 | dim3 qk_block(m / mBlock, 1, 1); 45 | // NOTE: AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) 46 | // We need a way of determining at runtime what type a tensor is and then 47 | // selectively call functions with the corresponding correct type signature. 48 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(q.scalar_type(), "QK", ([&] { 49 | naive_nrow_gemm<<<1, qk_block>>>( 50 | q.data_ptr(), k.data_ptr(), 51 | qk.data_ptr(), sm_scale, 0.f, m, m, n, mBlock); 52 | })); 53 | // Wait until kernel finish. 54 | cudaDeviceSynchronize(); 55 | CUDA_ERROR_CHECK(cudaGetLastError()); 56 | 57 | 58 | // QK[M, M] 59 | // TODO: too much thread may cause CUDA crash. 60 | dim3 sm_block(m, 1, 1); 61 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(qk.scalar_type(), "softmax(QK)", ([&] { 62 | row_softmax<<<1, sm_block>>>( 63 | qk.data_ptr(), qk.data_ptr(), m); 64 | })); 65 | cudaDeviceSynchronize(); 66 | CUDA_ERROR_CHECK(cudaGetLastError()); 67 | 68 | 69 | // QK[M, M] @ V[M, N] 70 | dim3 qkv_block(m / mBlock, 1, 1); 71 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "softmax(QK)V", ([&] { 72 | naive_pv<<<1, qkv_block>>>( 73 | qk.data_ptr(), v.data_ptr(), 74 | out.data_ptr(), m, n, mBlock); 75 | })); 76 | // We can remove this sync and let user call torch.cuda.synchronize() 77 | cudaDeviceSynchronize(); 78 | CUDA_ERROR_CHECK(cudaGetLastError()); 79 | 80 | return out; 81 | } 82 | 83 | // naive gemm implement with slice-k 84 | // perform C = aA@B + bC 85 | // A[M, K] x B[K, N] = C[M, N] 86 | // each thread process mblock rows of A 87 | // TODO: how to make data type more general 88 | template 89 | __global__ void naive_nrow_gemm(scalar_t *A, scalar_t *B, scalar_t *C, scalar_t a, scalar_t b, 90 | int64_t M, int64_t N, int64_t K, int64_t mBlock) { 91 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 92 | 93 | // each thread process a range of rows 94 | idx *= mBlock; 95 | 96 | // A[mBlock, K] x B[N, K].T = C[mBlock, N] 97 | for (int i = idx; i < idx + mBlock; i++) { 98 | for (int j = 0; j < N; j++) { 99 | scalar_t sum = 0.f; 100 | for (int k = 0; k < K; k++) { 101 | sum += A[i * K + k] * B[j * K + k]; 102 | } 103 | // C[M, N] 104 | // C = aA@B + bC 105 | C[i * N + j] = a * sum + b * C[i * N + j]; 106 | } 107 | } 108 | } 109 | 110 | // perform QK[M, M] @ V[M, N] 111 | template 112 | __global__ void naive_pv(scalar_t *P, scalar_t *V, scalar_t *O, int M, int N, 113 | int mBlock) { 114 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 115 | 116 | // each thread process a range of rows 117 | idx *= mBlock; 118 | 119 | int K = M; 120 | // P[mBlock, M] x V[M, N] = O[mBlock, N] 121 | for (int i = idx; i < idx + mBlock; i++) { 122 | for (int j = 0; j < N; j++) { 123 | scalar_t sum = 0.f; 124 | for (int k = 0; k < K; k++) { 125 | sum += P[i * K + k] * V[k * N + j]; 126 | } 127 | // C[M, N] 128 | O[i * N + j] = sum; 129 | } 130 | } 131 | } 132 | 133 | // each thread process one row of softmax 134 | template 135 | __global__ void row_softmax(scalar_t *input, scalar_t 
*output, int n) { 136 | // assume id will not exceed row number of input 137 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 138 | 139 | scalar_t row_max = -INFINITY; 140 | scalar_t sum = 0.f; 141 | 142 | // Find max 143 | for (int i = 0; i < n; i++) { 144 | row_max = max(input[idx * n + i], row_max); 145 | } 146 | 147 | // Compute numerator and denominator 148 | for (int i = 0; i < n; i++) { 149 | output[idx * n + i] = exp(input[idx * n + i] - row_max); 150 | sum += output[idx * n + i]; 151 | } 152 | 153 | // Compute softmax 154 | for (int i = 0; i < n; i++) { 155 | output[idx * n + i] /= sum; 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # Tiny FlashAttention 2 | 3 | WIP 4 | 5 | 一个简易的[flash attention](https://github.com/Dao-AILab/flash-attention)实现。 6 | 7 | - [python version](#flash-attention-2) 8 | * [x] [naive pure python code](./flash_attention_py/tiny_flash_attn.py) 9 | - [triton version](#triton-flash-attention-2) 10 | * [x] [triton code](./flash_attention_py/tiny_flash_attn_triton.py) 11 | - [c version] 12 | * TODO: [naive pure c code]() 13 | * [x] [naive cuda code standalone](./flash_attention_cuda/standalone_src) 14 | * [x] [naive cuda code python binding](./flash_attention_cutlass/csrc/flash_attention.cu) 15 | * [x] [cutlass cuda code](./flash_attention_cutlass/csrc/flash_attention.cu) 16 | - [rust version] 17 | 18 | 19 | ## algo 20 | 21 | - attention 22 | - softmax 23 | * $s(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}$ 24 | * 指数容易溢出导致精度损失 25 | - safe softmax 26 | * $s(x_i) = \frac{e^{x_i - max(x)}}{\sum_j{e^{x_j - max(x)}}} = \frac{e^{-max(x)} \times e^{x_i}}{e^{-max(x)} \times \sum_j{e^x_j}}$ 27 | * 指数部分减去一个最大值 28 | - online softmax 29 | * 上述softmax的问题在于, 分子处的max和分母的sum都需要读取整个向量以获取max和sum值, 缓存(SRAM)不够友好 30 | * online softmax的算法是分子分母分开算, 最后再整合 31 | 1. 分块计算max, 并迭代出分母的sum, 得出normalization factor 32 | - TODO 33 | 2. scaling 34 | - flash attention 1 35 | * tiling 36 | * SRAM 37 | - flash attention 2 38 | 39 | ## tips 40 | 41 | - `softmax_lse` 42 | * lse表示LogSumExp? 43 | * [lse主要解决计算Softmax或CrossEntropy2上溢(overflow)或下溢(underflow)的问题](https://www.apispace.com/news/post/13827.html) 44 | - `softmax_scale` 45 | * `q @ k.T / softmax_scale` 46 | * 添加`softmax_scale`后精度损失没有那么严重了(但依然存在) 47 | 48 | ## flow 49 | 50 | - softmax的online方法: scale(更新) + 累加 51 | - s@v的online方法: scale(更新) + 累加 52 | 1. 更新(旧O) + 新O 53 | 2. **更新方法: 更新max, 更新分母** 54 | 1. 更新max: 分子分母乘上$e^{max_old - max_new}$ 55 | 2. **更新分母: 先乘旧分母, 再除新分母** 56 | 57 | 1. 想清楚是怎么分块计算的 58 | 2. 再考虑块的值是怎么来的 59 | 60 | 不分块的情况, 设Q, K, V的shape=(N, d) 61 | 62 | softmax结果和V矩阵乘: 63 | 64 | ``` 65 | s = Q @ K.T = (N, d) @ (d, N) = (N, N) 66 | attn = s @ V = (N, N) @ (N, d) = (N, d) 67 | ``` 68 | 69 | 分块native, softmax的部分和V的部分相乘, 70 | 71 | ``` 72 | si = Qi @ Kj.T = (N/2, d) @ (d, N/2) = (N/2, N/2) 73 | attni[N/2, :] = si @ Vj = (N/2, N/2) @ (N/2, d) = (N/2, d) 74 | ``` 75 | 76 | 77 | 分块online 78 | 79 | TODO: img 80 | 81 | 所以output是要相加的! 82 | 83 | 84 | ## Keynote 85 | 86 | - the matmul 87 | - the shape 88 | - the algo 89 | 90 | ## The algo 91 | 92 | - attention 93 | - softmax 94 | * $s(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}$ 95 | - safe softmax 96 | - online softmax 97 | * algo1 98 | * impl algo2 99 | - flash attention 1 100 | * tiling 101 | * SRAM 102 | - flash attention 2 103 | 104 | ### 3 pass online softmax 105 | 106 | 1. pass1: 分块统计max 107 | 2. pass2: 分块求分母的sum 108 | 3. 
pass3: 执行softmax(xi) 109 | 110 | ### 2 pass online softmax 111 | 112 | 1. pass1: 分块统计max的同时动态更新分母的sum 113 | 2. pass2: 执行softmax(xi) 114 | 115 | $$d'_i = d'_{i-1}e^{m_{i-1} - m_{i}} + e^{x_i - m_{i}}$$ 116 | 117 | $d'_{i-1}e^{m_{i-1} - m_{i}}$就能将过时的max给替换掉了 118 | 119 | ### 2 pass online attention 120 | 121 | 矩阵乘法满足结合律 122 | 123 | ### 1 pass online attention 124 | 125 | ### 分块OV的计算 126 | 127 | 对于相同位置的O 128 | 129 | 130 | 131 | ## sum 132 | 133 | - online处理是会导致**精度损失**的(至少在tiny版本上) 134 | 135 | ## flash attention 2 136 | 137 | - flash attention 1的问题 138 | * 频繁的li, mi, oi更新 139 | + 一方面是频繁的非矩阵乘法 140 | + oi最后更新 141 | + 一方面是频繁的写 142 | + 内外循环顺序 143 | 144 | 1. 减少非矩阵乘法(non-matmul)操作 145 | 2. 并行计算attn, 即使是单头 146 | 3. 考虑多在thread block内计算, 减少跨组通信 147 | 148 | - flow 149 | * 与flash attention1对比 150 | + 局部值(oi, mi, li)就不用多次更新了, 一轮外部循环一行就能处理完成 151 | 152 | - tips 153 | * flash attention 2中分块的形状要特别注意 154 | 155 | ```python 156 | # flash attention 1 的循环 157 | for j in range(k_block_num): 158 | kj = K_BLOCKS[j] 159 | vj = V_BLOCKS[j] 160 | 161 | for i in range(q_block_num): 162 | qi = Q_BLOCKS[i] 163 | 164 | # flash attention 2 的循环 165 | for j in range(k_block_num): 166 | qi = Q_BLOCKS[i] 167 | 168 | for i in range(q_block_num): 169 | kj = K_BLOCKS[j] 170 | vj = V_BLOCKS[j] 171 | ``` 172 | 173 | ## triton flash attention 2 174 | 175 | [source code](./flash_attention-py/tiny_flash_attn_triton.py) 176 | 177 | 用triton实现一个shape为`bs, head, seqlen, dim`的qkv的attention。 178 | 179 | 1. 考虑计算所需的thread blocks, 即grid 180 | - 对于flash attn 2, 可以将外层的q循环并行处理, 及每个thread执行的是一部分q和其他所有kv的attention 181 | - 对于Q的分块处理(即分seqlen, 即分token), 如果一次处理`BLOCK_M`个token, 那么一次完整的attention计算需要`cdiv(seqlen, BLOCK_M)`个thread, cdiv表示除法向上取整 182 | - 每次kernel计算只需要后两维度, 即(seqlen, dim), 那么前两个维度有多少就需要多少thread来处理。因此将`grid[1]`置为`bs * head` 183 | - 因此最终grid为`[cdiv(seqlen, BLOCKM), bs * head]` 184 | 2. kernel设计, 设计并行程序 185 | - 计算thread处理各自负责的数据 186 | - 计算`(bs, head, seqlen, dim)`访问`head+1`时需要的offset 187 | * 可以使用`Tensor.stride(dim)`计算访问dim这个维度的下一个元素时所需跳过的元素数 188 | * 根据`grid[1]`记录而`bs*head`的大小和`q.stride(1)`, thread找到自己负责的范围 189 | - **使用`tl.make_block_ptr()`API分块读取qkv**, q根据`BLOCK_M`分块, kv根据`BLOCK_N`分块 190 | * 使用base参数找到正确的(bs, head)位置 191 | * 使用shape和order参数定义内存布局 192 | + 取`shape=(seqlen, dim)`, `order=(1, 0)`的q, v块, `order=(1, 0)`表示第二个维度在存储中的内侧 193 | + 取`shape=(dim, seqlen)`, `order=(0, 1)`的k块, `order=(0, 1)`表示第二个维度在存储中的外侧, 相当于对k做转置 194 | + API会根据order指定的顺序去构造所需的shape 195 | * 使用`block_shape`定义整个块的shape, `shape`参数则是每次读取整个块中的一部分的大小 196 | * 使用`strides`参数定义每次q, k, v块指针递增时的步长 197 | - Q根据`BLOCK_M`分块, K和V根据`BLOCK_N`分块 198 | 3. flash attention 2算法 199 | - 因为CSE(common subexpression elimination), LICM(loop invariant code motion)不支持`exp()`所以使用`exp2()`代替, 即`2^x` 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | ## ref 217 | 218 | - [Online normalizer calculation for softmax](https://arxiv.org/abs/1805.02867) 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 4 | 5 | #pragma once 6 | 7 | /// @param COND - a boolean expression to switch by 8 | /// @param CONST_NAME - a name given for the constexpr bool variable. 9 | /// @param ... 
- code to execute for true and false 10 | /// 11 | /// Usage: 12 | /// ``` 13 | /// BOOL_SWITCH(flag, BoolConst, [&] { 14 | /// some_function(...); 15 | /// }); 16 | /// ``` 17 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 18 | [&] { \ 19 | if (COND) { \ 20 | constexpr static bool CONST_NAME = true; \ 21 | return __VA_ARGS__(); \ 22 | } else { \ 23 | constexpr static bool CONST_NAME = false; \ 24 | return __VA_ARGS__(); \ 25 | } \ 26 | }() 27 | 28 | #define FP16_SWITCH(COND, ...) \ 29 | [&] { \ 30 | if (COND) { \ 31 | using elem_type = cutlass::half_t; \ 32 | return __VA_ARGS__(); \ 33 | } else { \ 34 | using elem_type = cutlass::bfloat16_t; \ 35 | return __VA_ARGS__(); \ 36 | } \ 37 | }() 38 | 39 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ 40 | [&] { \ 41 | if (HEADDIM <= 32) { \ 42 | constexpr static int kHeadDim = 32; \ 43 | return __VA_ARGS__(); \ 44 | } else if (HEADDIM <= 64) { \ 45 | constexpr static int kHeadDim = 64; \ 46 | return __VA_ARGS__(); \ 47 | } else if (HEADDIM <= 96) { \ 48 | constexpr static int kHeadDim = 96; \ 49 | return __VA_ARGS__(); \ 50 | } else if (HEADDIM <= 128) { \ 51 | constexpr static int kHeadDim = 128; \ 52 | return __VA_ARGS__(); \ 53 | } else if (HEADDIM <= 160) { \ 54 | constexpr static int kHeadDim = 160; \ 55 | return __VA_ARGS__(); \ 56 | } else if (HEADDIM <= 192) { \ 57 | constexpr static int kHeadDim = 192; \ 58 | return __VA_ARGS__(); \ 59 | } else if (HEADDIM <= 224) { \ 60 | constexpr static int kHeadDim = 224; \ 61 | return __VA_ARGS__(); \ 62 | } else if (HEADDIM <= 256) { \ 63 | constexpr static int kHeadDim = 256; \ 64 | return __VA_ARGS__(); \ 65 | } \ 66 | }() 67 | 68 | 69 | #define WARP_SWITCH(COND, CONST_NAME, ...) \ 70 | [&] { \ 71 | if (COND == 4) { \ 72 | constexpr static int CONST_NAME = 4; \ 73 | return __VA_ARGS__(); \ 74 | } else if (COND == 8) { \ 75 | constexpr static int CONST_NAME = 8; \ 76 | return __VA_ARGS__(); \ 77 | } else { \ 78 | constexpr static int CONST_NAME = 2; \ 79 | return __VA_ARGS__(); \ 80 | } \ 81 | }() 82 | 83 | #define BLOCKM_SWITCH(COND, CONST_NAME, ...) \ 84 | [&] { \ 85 | if (COND == 64) { \ 86 | constexpr static int CONST_NAME = 64; \ 87 | return __VA_ARGS__(); \ 88 | } else if (COND == 128) { \ 89 | constexpr static int CONST_NAME = 128; \ 90 | return __VA_ARGS__(); \ 91 | } else if (COND == 256) { \ 92 | constexpr static int CONST_NAME = 256; \ 93 | return __VA_ARGS__(); \ 94 | } else { \ 95 | constexpr static int CONST_NAME = 64; \ 96 | return __VA_ARGS__(); \ 97 | } \ 98 | }() 99 | 100 | #define BLOCKN_SWITCH(COND, CONST_NAME, ...) \ 101 | [&] { \ 102 | if (COND == 32) { \ 103 | constexpr static int CONST_NAME = 32; \ 104 | return __VA_ARGS__(); \ 105 | } else if (COND == 64) { \ 106 | constexpr static int CONST_NAME = 64; \ 107 | return __VA_ARGS__(); \ 108 | } else if (COND == 128) { \ 109 | constexpr static int CONST_NAME = 128; \ 110 | return __VA_ARGS__(); \ 111 | } else if (COND == 256) { \ 112 | constexpr static int CONST_NAME = 256; \ 113 | return __VA_ARGS__(); \ 114 | } else { \ 115 | constexpr static int CONST_NAME = 64; \ 116 | return __VA_ARGS__(); \ 117 | } \ 118 | }() 119 | 120 | #define STAGE_SWITCH(COND, CONST_NAME, ...) 
\ 121 | [&] { \ 122 | if (COND == 2) { \ 123 | constexpr static int CONST_NAME = 2; \ 124 | return __VA_ARGS__(); \ 125 | } else if (COND == 3) { \ 126 | constexpr static int CONST_NAME = 3; \ 127 | return __VA_ARGS__(); \ 128 | } else if (COND == 4) { \ 129 | constexpr static int CONST_NAME = 4; \ 130 | return __VA_ARGS__(); \ 131 | } else if (COND == 5) { \ 132 | constexpr static int CONST_NAME = 5; \ 133 | return __VA_ARGS__(); \ 134 | } else { \ 135 | constexpr static int CONST_NAME = 2; \ 136 | return __VA_ARGS__(); \ 137 | } \ 138 | }() 139 | -------------------------------------------------------------------------------- /flash_attention_py/tiny_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def flash_attn_v1(q, k, v, device='cuda', BLOCK_M=4): 5 | ''' 6 | The tiny flash attention implement 7 | ''' 8 | assert q.shape == k.shape 9 | assert q.shape == v.shape 10 | 11 | # Create output buffer in HBM. 12 | output_buffer = torch.zeros(v.shape, device=device) 13 | # Create denominator buffer in HBM. 14 | l = torch.zeros(v.shape[:-1], device=device)[..., None] 15 | # Create max(x) buffer in HBM. 16 | m = torch.ones(v.shape[:-1], device=device)[..., None] * -torch.inf 17 | 18 | Q_BLOCKS = torch.split(q, BLOCK_M, dim=-2) 19 | K_BLOCKS = torch.split(k, BLOCK_M, dim=-2) 20 | V_BLOCKS = torch.split(v, BLOCK_M, dim=-2) 21 | O_BLOCKS = list(torch.split(output_buffer, BLOCK_M, dim=-2)) 22 | L_BLOCKS = list(torch.split(l, BLOCK_M, dim=-2)) 23 | M_BLOCKS = list(torch.split(m, BLOCK_M, dim=-2)) 24 | 25 | k_block_num = k.shape[-2] // BLOCK_M 26 | for j in range(k_block_num): 27 | kj = K_BLOCKS[j] 28 | vj = V_BLOCKS[j] 29 | 30 | q_block_num = q.shape[-2] // BLOCK_M 31 | for i in range(q_block_num): 32 | qi = Q_BLOCKS[i] 33 | old_o = O_BLOCKS[i] 34 | old_d = L_BLOCKS[i] 35 | old_m = M_BLOCKS[i] 36 | 37 | # Compute qk. 38 | x_qkt = (qi @ kj.T) 39 | # Get local max of qk. 40 | # keepdim to avoid auto squeeze. 41 | local_m = torch.max(x_qkt, dim=1, keepdim=True).values 42 | 43 | 44 | # # MatMul operator optimization version. 45 | # # Compute new max. 46 | # new_m = torch.maximum(old_m, local_m) 47 | # # Compute numerator. e^{x - max(x)}. 48 | # safe_e = torch.exp(x_qkt - new_m) 49 | # # Compute new part of denominator. 50 | # curr_d = torch.sum(safe_e, dim=1)[:, None] 51 | # # Update denominator. 52 | # new_d = old_d * torch.exp(old_m - new_m) + curr_d 53 | # # Update old output and accumulate new output. 54 | # new_o = old_o * torch.exp(old_m - new_m) * old_d / new_d + safe_e / new_d @ vj 55 | 56 | 57 | # Flash attention 1 with many redundant mul 58 | # Compute numerator. e^{x - max(x)} 59 | safe_e = torch.exp(x_qkt - local_m) 60 | # Compute new part of denominator. 61 | curr_d = torch.sum(safe_e, dim=1, keepdim=True) 62 | 63 | # Update max. 64 | new_m = torch.maximum(local_m, old_m) 65 | # Update denominator. 66 | new_d = old_d * torch.exp(old_m - new_m) + curr_d * torch.exp(local_m - new_m) 67 | # Update old output and accumulate new output. 68 | new_o = (old_d * torch.exp(old_m - new_m) * old_o / new_d) + (torch.exp(local_m - new_m) * safe_e / new_d @ vj.float()) 69 | 70 | 71 | # Store new value. 
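            # Recap of the update rule implemented above (FlashAttention-1 style:
            # every block of O is renormalized on each kj/vj iteration):
            #   new_m = max(old_m, local_m)
            #   new_d = old_d * exp(old_m - new_m) + curr_d * exp(local_m - new_m)
            #   new_o = old_d * exp(old_m - new_m) * old_o / new_d
            #           + exp(local_m - new_m) * safe_e / new_d @ vj
            # Persisting new_d and new_m per block is what lets the next kj/vj
            # block rescale this partial output correctly.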
72 | # NOTE:O_BLOCKS, L_BLOCKS, M_BLOCKS here will malloc addition memory 73 | L_BLOCKS[i] = new_d 74 | M_BLOCKS[i] = new_m 75 | O_BLOCKS[i] = new_o 76 | 77 | output_buffer = torch.cat(O_BLOCKS, dim=-2) 78 | 79 | return output_buffer 80 | 81 | def flash_attn_v2(q, k, v, device='cuda', BLOCK_M=4): 82 | ''' 83 | The tiny flash attention implement 84 | ''' 85 | assert q.shape == k.shape 86 | assert q.shape == v.shape 87 | 88 | # Create output buffer in HBM. 89 | output_buffer = torch.zeros(v.shape, device=device) 90 | 91 | Q_BLOCKS = torch.split(q, BLOCK_M, dim=-2) 92 | K_BLOCKS = torch.split(k, BLOCK_M, dim=-2) 93 | V_BLOCKS = torch.split(v, BLOCK_M, dim=-2) 94 | O_BLOCKS = list(torch.split(output_buffer, BLOCK_M, dim=-2)) 95 | 96 | q_block_num = q.shape[-2] // BLOCK_M 97 | for j in range(q_block_num): 98 | qi = Q_BLOCKS[j] 99 | old_o = O_BLOCKS[j] 100 | # Create denominator buffer in HBM. 101 | old_d = torch.zeros((BLOCK_M, 1), device=device) 102 | # Create max(x) buffer in HBM. 103 | old_m = torch.full((BLOCK_M, 1), -torch.inf, device=device) 104 | 105 | k_block_num = k.shape[-2] // BLOCK_M 106 | for i in range(k_block_num): 107 | kj = K_BLOCKS[i] 108 | vj = V_BLOCKS[i] 109 | 110 | # Compute qk. 111 | x_qkt = (qi @ kj.T) 112 | # Get local max of qk. 113 | local_m = torch.max(x_qkt, dim=1, keepdim=True).values 114 | 115 | # Compute new max. 116 | new_m = torch.maximum(old_m, local_m) 117 | # Compute numerator. i.e.: e^{x - max(x)}. 118 | safe_e = torch.exp(x_qkt - new_m) 119 | # Compute new part of denominator. 120 | curr_d = torch.sum(safe_e, dim=1, keepdim=True) 121 | # Update denominator. 122 | new_d = old_d * torch.exp(old_m - new_m) + curr_d 123 | # Update old output and accumulate new output. 124 | new_o = old_o * torch.exp(old_m - new_m) + safe_e @ vj.float() 125 | 126 | old_m = new_m 127 | old_d = new_d 128 | old_o = new_o 129 | 130 | # NOTE:O_BLOCKS here will malloc addition memory 131 | O_BLOCKS[j] = old_o / old_d 132 | 133 | output_buffer = torch.cat(O_BLOCKS, dim=-2) 134 | 135 | return output_buffer 136 | 137 | def flash_attn_v2_multihead(q, k, v, device='cpu', BLOCK_M=4): 138 | ''' 139 | The tiny flash attention implement 140 | ''' 141 | assert q.shape == k.shape 142 | assert q.shape == v.shape 143 | 144 | # NOTE: q, v, k location should not change in here 145 | q = q.to(device=device) 146 | k = k.to(device=device) 147 | v = v.to(device=device) 148 | # Create output buffer in HBM. 149 | output_buffer = torch.zeros(v.shape, device=device) 150 | 151 | Q_BLOCKS = torch.split(q, BLOCK_M, dim=-2) 152 | K_BLOCKS = torch.split(k, BLOCK_M, dim=-2) 153 | V_BLOCKS = torch.split(v, BLOCK_M, dim=-2) 154 | 155 | bs, head, seqlen, headdim = q.shape 156 | 157 | seqlen = q.shape[-2] // BLOCK_M 158 | for j in range(seqlen): 159 | qi = Q_BLOCKS[j] 160 | old_o = output_buffer[...,j * BLOCK_M: (j+1) * BLOCK_M, :] 161 | # Create denominator buffer in HBM. 162 | old_d = torch.zeros((bs, head, BLOCK_M, 1), device=device) 163 | # Create max(x) buffer in HBM. 164 | old_m = torch.full((bs, head, BLOCK_M, 1), -torch.inf, device=device) 165 | 166 | k_block_num = k.shape[-2] // BLOCK_M 167 | for i in range(k_block_num): 168 | kj = K_BLOCKS[i] 169 | vj = V_BLOCKS[i] 170 | 171 | # Compute qk. 172 | # NOTE: we need softmax_scale here in real world 173 | x_qkt = (qi @ kj.transpose(2, 3)) 174 | # Get local max of qk. 175 | # keepdim to avoid auto squeeze. 176 | # torch.max() return (max, max_index) 177 | local_m = torch.max(x_qkt, dim=-1, keepdim=True).values 178 | 179 | # Compute new max. 
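            # FlashAttention-2 difference from v1 above: the accumulator below is
            # kept unnormalized (no division by new_d inside this inner loop); the
            # single division by the final denominator happens once per q-block
            # after the k/v loop (old_o / old_d), which removes most of the
            # non-matmul work from the inner loop.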
180 | new_m = torch.maximum(old_m, local_m) 181 | # Compute numerator. i.e.: e^{x - max(x)}. 182 | safe_e = torch.exp(x_qkt - new_m) 183 | # Compute new part of denominator. 184 | curr_d = torch.sum(safe_e, dim=-1, keepdim=True) 185 | # Update denominator. 186 | new_d = old_d * torch.exp(old_m - new_m) + curr_d 187 | # Update old output and accumulate new output. 188 | new_o = old_o * torch.exp(old_m - new_m) + safe_e @ vj.float() 189 | 190 | old_m = new_m 191 | old_d = new_d 192 | old_o = new_o 193 | 194 | output_buffer[...,j * BLOCK_M: (j+1) * BLOCK_M, :] = old_o / old_d 195 | 196 | return output_buffer 197 | 198 | def flash_attn(q, k, v, device='cpu', BLOCK_M=4): 199 | ''' 200 | Memory-efficient flash attention implementation 201 | ''' 202 | return flash_attn_v2_multihead(q, k, v, device, BLOCK_M=BLOCK_M) 203 | 204 | 205 | -------------------------------------------------------------------------------- /flash_attention_cuda/standalone_src/self_attention_standalone.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CUDA_CHECK(condition) \ 7 | do { \ 8 | cudaError_t error = condition; \ 9 | if (error != cudaSuccess) { \ 10 | printf("CUDA_CHECK error in line %d of file %s \ 11 | : %s \n", \ 12 | __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError())); \ 13 | exit(EXIT_FAILURE); \ 14 | } \ 15 | } while (0) 16 | 17 | #define DEBUG 18 | 19 | #ifdef DEBUG 20 | #define DEBUG_BLOCK(expr) \ 21 | do { \ 22 | expr \ 23 | } while (0) 24 | #else 25 | #define DEBUG_BLOCK(...) \ 26 | do { \ 27 | } while (0) 28 | #endif 29 | 30 | // seqlen 31 | const int input_seq = 4; 32 | // dim 33 | const int dim = 4; 34 | 35 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 36 | int M, int N, int K, int mBlock); 37 | __global__ void row_softmax(float *input, float *output, int n); 38 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 39 | int mBlock); 40 | void print_host_matrix(float *matrix, int m, int n); 41 | void print_device_matrix(float *matrix, int m, int n); 42 | 43 | void self_attention_cuda(float *Q, float *K, float *V, float *O, int m, int n) { 44 | int mBlock = 2; 45 | assert(m % mBlock == 0 && "mBlock should align"); 46 | 47 | float sm_scale = 1.f / sqrtf(static_cast(n)); 48 | float *sm_o; 49 | cudaMalloc((void **)&sm_o, sizeof(float) * m * m); 50 | 51 | dim3 qk_block(m / mBlock, 1, 1); 52 | naive_nrow_gemm<<<1, qk_block>>>(Q, K, sm_o, sm_scale, 0, m, m, n, mBlock); 53 | cudaDeviceSynchronize(); 54 | DEBUG_BLOCK( 55 | CUDA_CHECK(cudaGetLastError()); 56 | print_device_matrix(sm_o, m, m); 57 | ); 58 | 59 | // QK[M, M] 60 | dim3 sm_block(m, 1, 1); 61 | row_softmax<<<1, sm_block>>>(sm_o, sm_o, m); 62 | cudaDeviceSynchronize(); 63 | DEBUG_BLOCK( 64 | CUDA_CHECK(cudaGetLastError()); 65 | print_device_matrix(sm_o, m, m); 66 | ); 67 | 68 | // QK[M, M] @ V[M, N] 69 | dim3 qkv_block(m / mBlock, 1, 1); 70 | naive_pv<<<1, qkv_block>>>(sm_o, V, O, m, n, mBlock); 71 | cudaDeviceSynchronize(); 72 | DEBUG_BLOCK( 73 | CUDA_CHECK(cudaGetLastError()); 74 | print_device_matrix(O, m, n); 75 | ); 76 | 77 | cudaFree(sm_o); 78 | } 79 | 80 | // naive gemm implementation with slice-k 81 | // perform C = aA@B + bC 82 | // A[M, K] x B[K, N] = C[M, N] 83 | // each thread processes mBlock rows of A 84 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 85 | int M, int N, int K, int mBlock) { 86 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 87 | 88 | // each 
thread process a range of rows 89 | idx *= mBlock; 90 | 91 | // A[mBlock, K] x B[N, K].T = C[mBlock, N] 92 | for (int i = idx; i < idx + mBlock; i++) { 93 | for (int j = 0; j < N; j++) { 94 | float sum = 0.f; 95 | for (int k = 0; k < K; k++) { 96 | sum += A[i * K + k] * B[j * K + k]; 97 | } 98 | // C[M, N] 99 | // C = aA@B + bC 100 | C[i * N + j] = a * sum + b * C[i * N + j]; 101 | } 102 | } 103 | } 104 | 105 | // perform QK[M, M] @ V[M, N] 106 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 107 | int mBlock) { 108 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 109 | 110 | // each thread process a range of rows 111 | idx *= mBlock; 112 | 113 | int K = M; 114 | // P[mBlock, M] x V[M, N] = O[mBlock, N] 115 | for (int i = idx; i < idx + mBlock; i++) { 116 | for (int j = 0; j < N; j++) { 117 | float sum = 0.f; 118 | for (int k = 0; k < K; k++) { 119 | sum += P[i * K + k] * V[k * N + j]; 120 | } 121 | // C[M, N] 122 | O[i * N + j] = sum; 123 | } 124 | } 125 | } 126 | 127 | // each thread process one row of softmax 128 | __global__ void row_softmax(float *input, float *output, int n) { 129 | // assume id will not exceed row number of input 130 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 131 | 132 | float max = -INFINITY; 133 | float sum = 0.f; 134 | 135 | // Find max 136 | for (int i = 0; i < n; i++) { 137 | if (input[idx * n + i] > max) { 138 | max = input[idx * n + i]; 139 | } 140 | } 141 | 142 | // Compute numerator and denominator 143 | for (int i = 0; i < n; i++) { 144 | output[idx * n + i] = exp(input[idx * n + i] - max); 145 | sum += output[idx * n + i]; 146 | } 147 | 148 | // Compute softmax 149 | for (int i = 0; i < n; i++) { 150 | output[idx * n + i] /= sum; 151 | } 152 | } 153 | 154 | // print matrix 155 | void print_host_matrix(float *matrix, int m, int n) { 156 | for (int i = 0; i < m; i++) { 157 | for (int j = 0; j < n; j++) { 158 | printf("%f, ", matrix[i * n + j]); 159 | } 160 | printf("\n"); 161 | } 162 | } 163 | 164 | void print_device_matrix(float *dev_ptr, int m, int n) { 165 | float *host_ptr = new float[m * n]; 166 | cudaMemcpy(host_ptr, dev_ptr, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 167 | 168 | for (int i = 0; i < m; i++) { 169 | for (int j = 0; j < n; j++) { 170 | printf("%f, ", host_ptr[i * n + j]); 171 | } 172 | printf("\n"); 173 | } 174 | free(host_ptr); 175 | } 176 | 177 | void test_attention() { 178 | // seqlen 179 | int m = input_seq; 180 | // dim 181 | int n = dim; 182 | 183 | // Host pointer 184 | float *h_K = new float[m * n]; 185 | float *h_Q = new float[m * n]; 186 | float *h_V = new float[m * n]; 187 | float *h_O = new float[m * n]; 188 | 189 | // 初始化 K, Q, V 190 | for (int i = 0; i < m * n; ++i) { 191 | // h_K[i] = static_cast(rand()) / RAND_MAX; 192 | // h_Q[i] = static_cast(rand()) / RAND_MAX; 193 | // h_V[i] = static_cast(rand()) / RAND_MAX; 194 | h_K[i] = static_cast(i); 195 | h_Q[i] = static_cast(i); 196 | h_V[i] = static_cast(i); 197 | } 198 | 199 | float *d_K, *d_Q, *d_V, *d_O; 200 | // Malloc device memory 201 | cudaMalloc((void **)&d_K, sizeof(float) * m * n); 202 | cudaMalloc((void **)&d_Q, sizeof(float) * m * n); 203 | cudaMalloc((void **)&d_V, sizeof(float) * m * n); 204 | cudaMalloc((void **)&d_O, sizeof(float) * m * n); 205 | 206 | // Copy data from host to device 207 | cudaMemcpy(d_K, h_K, sizeof(float) * m * n, cudaMemcpyHostToDevice); 208 | cudaMemcpy(d_Q, h_Q, sizeof(float) * m * n, cudaMemcpyHostToDevice); 209 | cudaMemcpy(d_V, h_V, sizeof(float) * m * n, cudaMemcpyHostToDevice); 210 | 211 | 
cudaEvent_t start, stop; 212 | cudaEventCreate(&start); 213 | cudaEventCreate(&stop); 214 | cudaEventRecord(start, 0); 215 | 216 | // Run test 217 | for (int i = 0; i < 1; i++) { 218 | // Launch kernel 219 | self_attention_cuda(d_Q, d_K, d_V, d_O, m, n); 220 | 221 | CUDA_CHECK(cudaGetLastError()); 222 | } 223 | 224 | cudaEventRecord(stop, 0); 225 | cudaEventSynchronize(stop); 226 | float milliseconds = 0; 227 | cudaEventElapsedTime(&milliseconds, start, stop); 228 | printf("Time for kernel execution: %.3f ms \n", milliseconds / 100); 229 | cudaEventDestroy(start); 230 | cudaEventDestroy(stop); 231 | 232 | // Result back to host 233 | cudaMemcpy(h_O, d_O, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 234 | 235 | cudaFree(d_K); 236 | cudaFree(d_Q); 237 | cudaFree(d_V); 238 | cudaFree(d_O); 239 | free(h_Q); 240 | free(h_K); 241 | free(h_V); 242 | free(h_O); 243 | } 244 | 245 | int main() { 246 | test_attention(); 247 | 248 | return 0; 249 | } 250 | -------------------------------------------------------------------------------- /flash_attention_cutlass/csrc/kernel_traits.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cute/algorithm/copy.hpp" 4 | 5 | #include "cutlass/cutlass.h" 6 | #include "cutlass/layout/layout.h" 7 | #include 8 | 9 | using namespace cute; 10 | 11 | template 12 | struct Flash_kernel_traits { 13 | 14 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 15 | using Element = elem_type; 16 | static constexpr bool Has_cp_async = true; 17 | #else 18 | using Element = cutlass::half_t; 19 | static constexpr bool Has_cp_async = false; 20 | #endif 21 | 22 | using ElementAccum = float; 23 | using index_t = uint32_t; 24 | 25 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 26 | using MMA_Atom_Arch = std::conditional_t< 27 | std::is_same_v, 28 | MMA_Atom, 29 | MMA_Atom 30 | >; 31 | using ValLayoutMNK = Layout>; 32 | #else 33 | using MMA_Atom_Arch = MMA_Atom; 34 | using ValLayoutMNK = Layout>; 35 | #endif 36 | 37 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 38 | using SmemCopyAtom = Copy_Atom; 39 | using SmemCopyAtomTransposed = Copy_Atom; 40 | #else 41 | using SmemCopyAtom = Copy_Atom; 42 | using SmemCopyAtomTransposed = Copy_Atom; 43 | #endif 44 | }; 45 | 46 | 47 | // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true 48 | template > 50 | struct Flash_fwd_kernel_traits : public Base { 51 | using Element = typename Base::Element; 52 | using ElementAccum = typename Base::ElementAccum; 53 | using index_t = typename Base::index_t; 54 | static constexpr bool Has_cp_async = Base::Has_cp_async; 55 | using SmemCopyAtom = typename Base::SmemCopyAtom; 56 | using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; 57 | 58 | // The number of threads. 59 | static constexpr int kNWarps = kNWarps_; 60 | static constexpr int kNThreads = kNWarps * 32; 61 | 62 | static constexpr int kBlockM = kBlockM_; 63 | static constexpr int kBlockN = kBlockN_; 64 | static constexpr int kHeadDim = kHeadDim_; 65 | 66 | // TODO: review 67 | static_assert(kHeadDim % 32 == 0); 68 | static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; 69 | static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); 70 | static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; 71 | 72 | using TiledMma = TiledMMA< 73 | typename Base::MMA_Atom_Arch, 74 | Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group 75 | // NOTE: cutlass v3.3 76 | // typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM 77 | // cutlass v3.4 78 | Tile, _16, _16>>; 79 | 80 | using SmemLayoutAtomQ = decltype( 81 | composition(Swizzle{}, 82 | // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 83 | Layout>, 84 | Stride, _1>>{})); 85 | using SmemLayoutQ = decltype(tile_to_shape( 86 | SmemLayoutAtomQ{}, 87 | Shape, Int>{})); 88 | 89 | using SmemLayoutKV = decltype(tile_to_shape( 90 | SmemLayoutAtomQ{}, 91 | Shape, Int>{})); 92 | 93 | // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 94 | using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, 95 | Stride<_1, Int>>; 96 | using SmemLayoutAtomVtransposed = decltype( 97 | composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); 98 | using SmemLayoutVtransposed = decltype(tile_to_shape( 99 | SmemLayoutAtomVtransposed{}, 100 | Shape, Int>{})); 101 | // Maybe the VtransposeNoSwizzle just needs to have the right shape 102 | // And the strides don't matter? 103 | using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( 104 | SmemLayoutAtomVtransposedNoSwizzle{}, 105 | Shape, Int>{})); 106 | // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); 107 | 108 | using SmemLayoutAtomO = decltype( 109 | composition(Swizzle{}, 110 | Layout, Int>, 111 | Stride, _1>>{})); 112 | using SmemLayoutO = decltype(tile_to_shape( 113 | SmemLayoutAtomO{}, 114 | Shape, Int>{})); 115 | using SmemCopyAtomO = Copy_Atom; 116 | using SmemCopyAtomOaccum = Copy_Atom; 117 | 118 | static constexpr int kSmemQCount = size(SmemLayoutQ{}); 119 | static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; 120 | static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); 121 | static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); 122 | // TODO: 123 | static constexpr int kSmemSize = kSmemQSize + kSmemKVSize; 124 | 125 | static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); 126 | static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); 127 | 128 | // TODO: review 129 | 130 | // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. 131 | // For example, for d=128, smem is split into 2 "pages", each page takes care of columns 132 | // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, 133 | // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, 134 | // to the same banks. 135 | static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; 136 | static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); 137 | using GmemLayoutAtom = Layout, Int>, 138 | Stride, _1>>; 139 | 140 | // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading 141 | // from the same address by the same threadblock. This is slightly faster. 
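    // Note: when Has_cp_async is false (pre-SM80 builds), the conditional below
    // falls back to DefaultCopy, i.e. ordinary synchronous loads instead of
    // cp.async, so the same traits still compile for older architectures.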
142 | using Gmem_copy_struct = std::conditional_t< 143 | Has_cp_async, 144 | SM80_CP_ASYNC_CACHEGLOBAL, 145 | DefaultCopy 146 | >; 147 | using GmemTiledCopyQKV = decltype( 148 | make_tiled_copy(Copy_Atom{}, 149 | GmemLayoutAtom{}, 150 | Layout>{})); // Val layout, 8 vals per read 151 | using GmemTiledCopyO = decltype( 152 | make_tiled_copy(Copy_Atom{}, 153 | GmemLayoutAtom{}, 154 | Layout>{})); // Val layout, 8 vals per store 155 | static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; 156 | static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); 157 | using GmemLayoutAtomP = Layout, Int>, 158 | Stride, _1>>; 159 | 160 | using GmemTiledCopyP = decltype( 161 | make_tiled_copy(Copy_Atom{}, 162 | GmemLayoutAtomP{}, 163 | Layout>{})); // Val layout, 8 vals per store 164 | 165 | using GmemLayoutAtomOaccum = std::conditional_t< 166 | kBlockKSmem == 32, 167 | Layout, // Thread layout, 8 threads per row 168 | Stride< _8, _1>>, 169 | Layout, // Thread layout, 16 threads per row 170 | Stride< _16, _1>> 171 | >; 172 | using GmemTiledCopyOaccum = decltype( 173 | make_tiled_copy(Copy_Atom{}, 174 | GmemLayoutAtomOaccum{}, 175 | Layout>{})); // Val layout, 4 vals per store 176 | using GmemLayoutAtomRotcossin = GmemLayoutAtom; 177 | using GmemTiledCopyRotcossin = decltype( 178 | make_tiled_copy(Copy_Atom, Element>{}, 179 | GmemLayoutAtomRotcossin{}, 180 | Layout>{})); // Val layout, 4 vals per load 181 | using GmemTiledCopyRotcossinCont = decltype( 182 | make_tiled_copy(Copy_Atom{}, 183 | GmemLayoutAtomRotcossin{}, 184 | Layout>{})); // Val layout, 8 vals per load 185 | 186 | }; 187 | -------------------------------------------------------------------------------- /flash_attention_cutlass/standalone_src/kernel_traits.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cute/algorithm/copy.hpp" 4 | 5 | #include "cutlass/cutlass.h" 6 | #include "cutlass/layout/layout.h" 7 | #include 8 | 9 | using namespace cute; 10 | 11 | template 12 | struct Flash_kernel_traits { 13 | 14 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 15 | using Element = elem_type; 16 | static constexpr bool Has_cp_async = true; 17 | #else 18 | using Element = cutlass::half_t; 19 | static constexpr bool Has_cp_async = false; 20 | #endif 21 | 22 | using ElementAccum = float; 23 | using index_t = uint32_t; 24 | 25 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 26 | using MMA_Atom_Arch = std::conditional_t< 27 | std::is_same_v, 28 | MMA_Atom, 29 | MMA_Atom 30 | >; 31 | using ValLayoutMNK = Layout>; 32 | #else 33 | using MMA_Atom_Arch = MMA_Atom; 34 | using ValLayoutMNK = Layout>; 35 | #endif 36 | 37 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 38 | using SmemCopyAtom = Copy_Atom; 39 | using SmemCopyAtomTransposed = Copy_Atom; 40 | #else 41 | using SmemCopyAtom = Copy_Atom; 42 | using SmemCopyAtomTransposed = Copy_Atom; 43 | #endif 44 | }; 45 | 46 | 47 | // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true 48 | template > 50 | struct Flash_fwd_kernel_traits : public Base { 51 | using Element = typename Base::Element; 52 | using ElementAccum = typename Base::ElementAccum; 53 | using index_t = typename Base::index_t; 54 | static constexpr bool Has_cp_async = Base::Has_cp_async; 55 | using SmemCopyAtom = typename Base::SmemCopyAtom; 56 | using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; 57 | 58 | // The number of threads. 
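    // Worked example (assuming the common kNWarps_ = 4 configuration):
    //   kNThreads = kNWarps * 32 = 4 * 32 = 128 threads per thread block.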
59 | static constexpr int kNWarps = kNWarps_; 60 | static constexpr int kNThreads = kNWarps * 32; 61 | 62 | static constexpr int kBlockM = kBlockM_; 63 | static constexpr int kBlockN = kBlockN_; 64 | static constexpr int kHeadDim = kHeadDim_; 65 | 66 | // TODO: review 67 | static_assert(kHeadDim % 32 == 0); 68 | static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; 69 | static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); 70 | static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; 71 | 72 | using TiledMma = TiledMMA< 73 | typename Base::MMA_Atom_Arch, 74 | Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group 75 | // NOTE: cutlass v3.3 76 | // typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM 77 | // cutlass v3.4 78 | Tile, _16, _16>>; 79 | 80 | using SmemLayoutAtomQ = decltype( 81 | composition(Swizzle{}, 82 | // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 83 | Layout>, 84 | Stride, _1>>{})); 85 | using SmemLayoutQ = decltype(tile_to_shape( 86 | SmemLayoutAtomQ{}, 87 | Shape, Int>{})); 88 | 89 | using SmemLayoutKV = decltype(tile_to_shape( 90 | SmemLayoutAtomQ{}, 91 | Shape, Int>{})); 92 | 93 | // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 94 | using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, 95 | Stride<_1, Int>>; 96 | using SmemLayoutAtomVtransposed = decltype( 97 | composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); 98 | using SmemLayoutVtransposed = decltype(tile_to_shape( 99 | SmemLayoutAtomVtransposed{}, 100 | Shape, Int>{})); 101 | // Maybe the VtransposeNoSwizzle just needs to have the right shape 102 | // And the strides don't matter? 103 | using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( 104 | SmemLayoutAtomVtransposedNoSwizzle{}, 105 | Shape, Int>{})); 106 | // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); 107 | 108 | using SmemLayoutAtomO = decltype( 109 | composition(Swizzle{}, 110 | Layout, Int>, 111 | Stride, _1>>{})); 112 | using SmemLayoutO = decltype(tile_to_shape( 113 | SmemLayoutAtomO{}, 114 | Shape, Int>{})); 115 | using SmemCopyAtomO = Copy_Atom; 116 | using SmemCopyAtomOaccum = Copy_Atom; 117 | 118 | static constexpr int kSmemQCount = size(SmemLayoutQ{}); 119 | static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; 120 | static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); 121 | static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); 122 | // TODO: 123 | static constexpr int kSmemSize = kSmemQSize + kSmemKVSize; 124 | 125 | static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); 126 | static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); 127 | 128 | // TODO: review 129 | 130 | // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. 131 | // For example, for d=128, smem is split into 2 "pages", each page takes care of columns 132 | // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, 133 | // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, 134 | // to the same banks. 
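    // Worked example (assuming a 2-byte element type such as cutlass::half_t):
    //   kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element) = 16 / 2 = 8,
    //   i.e. each thread moves 8 elements (128 bits) per global-memory access;
    //   with kBlockKSmem = 64 a row is then covered by 64 / 8 = 8 threads.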
135 | static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; 136 | static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); 137 | using GmemLayoutAtom = Layout, Int>, 138 | Stride, _1>>; 139 | 140 | // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading 141 | // from the same address by the same threadblock. This is slightly faster. 142 | using Gmem_copy_struct = std::conditional_t< 143 | Has_cp_async, 144 | SM80_CP_ASYNC_CACHEGLOBAL, 145 | DefaultCopy 146 | >; 147 | using GmemTiledCopyQKV = decltype( 148 | make_tiled_copy(Copy_Atom{}, 149 | GmemLayoutAtom{}, 150 | Layout>{})); // Val layout, 8 vals per read 151 | using GmemTiledCopyO = decltype( 152 | make_tiled_copy(Copy_Atom{}, 153 | GmemLayoutAtom{}, 154 | Layout>{})); // Val layout, 8 vals per store 155 | static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; 156 | static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); 157 | using GmemLayoutAtomP = Layout, Int>, 158 | Stride, _1>>; 159 | 160 | using GmemTiledCopyP = decltype( 161 | make_tiled_copy(Copy_Atom{}, 162 | GmemLayoutAtomP{}, 163 | Layout>{})); // Val layout, 8 vals per store 164 | 165 | using GmemLayoutAtomOaccum = std::conditional_t< 166 | kBlockKSmem == 32, 167 | Layout, // Thread layout, 8 threads per row 168 | Stride< _8, _1>>, 169 | Layout, // Thread layout, 16 threads per row 170 | Stride< _16, _1>> 171 | >; 172 | using GmemTiledCopyOaccum = decltype( 173 | make_tiled_copy(Copy_Atom{}, 174 | GmemLayoutAtomOaccum{}, 175 | Layout>{})); // Val layout, 4 vals per store 176 | using GmemLayoutAtomRotcossin = GmemLayoutAtom; 177 | using GmemTiledCopyRotcossin = decltype( 178 | make_tiled_copy(Copy_Atom, Element>{}, 179 | GmemLayoutAtomRotcossin{}, 180 | Layout>{})); // Val layout, 4 vals per load 181 | using GmemTiledCopyRotcossinCont = decltype( 182 | make_tiled_copy(Copy_Atom{}, 183 | GmemLayoutAtomRotcossin{}, 184 | Layout>{})); // Val layout, 8 vals per load 185 | 186 | }; 187 | -------------------------------------------------------------------------------- /flash_attention_c/csrc/attn.cpp: -------------------------------------------------------------------------------- 1 | #include "attn.h" 2 | #include "utils.h" 3 | #include 4 | #include 5 | 6 | struct attn_fwd_params { 7 | size_t bs; 8 | size_t head_num; 9 | // TODO: GQA support 10 | size_t q_seqlen; 11 | size_t head_dim; 12 | size_t k_seqlen; 13 | size_t kv_head_num; 14 | 15 | size_t stride_q_bs; 16 | size_t stride_q_head_num; 17 | size_t stride_q_seqlen; 18 | size_t stride_q_head_dim; 19 | 20 | size_t stride_kv_bs; 21 | size_t stride_kv_head_num; 22 | size_t stride_kv_seqlen; 23 | size_t stride_kv_head_dim; 24 | 25 | void *q_ptr; 26 | void *k_ptr; 27 | void *v_ptr; 28 | void *o_ptr; 29 | 30 | bool is_causal; 31 | float softmax_scale; 32 | }; 33 | 34 | 35 | template void run_naive_attn(attn_fwd_params ¶ms, typename Attn_traits::elem_type* attn_score, size_t stride_score_l1) { 36 | /* 37 | q k v.shape (bs, head_num, seqlen, head_dim) 38 | attn_score.shape = (seqlen, seqlen), compute one by one 39 | */ 40 | using elem_type = typename Attn_traits::elem_type; 41 | 42 | for (int bid = 0; bid < params.bs; bid++) { 43 | for (int hid = 0; hid < params.head_num; hid++) { 44 | #pragma omp parallel for 45 | for (int i = 0; i < params.q_seqlen; i++) { 46 | elem_type* q = static_cast(params.q_ptr) + bid * params.stride_q_bs + hid * params.stride_q_head_num + 
i * params.stride_q_seqlen; 47 | 48 | float maxval = -INFINITY; 49 | 50 | int kv_len = params.k_seqlen; 51 | if (params.is_causal) { 52 | kv_len = i + 1 + (params.k_seqlen - params.q_seqlen); 53 | } 54 | 55 | // qk dot product 56 | for (int j = 0; j < kv_len; j++) { 57 | elem_type* k = static_cast(params.k_ptr) + bid * params.stride_kv_bs + hid * params.stride_kv_head_num + j * params.stride_kv_seqlen; 58 | elem_type val = 0.0f; 59 | for (int dim = 0; dim < params.head_dim; dim++) { 60 | val += q[dim] * k[dim]; 61 | } 62 | 63 | val *= params.softmax_scale; 64 | if (val > maxval) { 65 | maxval = val; 66 | } 67 | // set score[i, j] 68 | attn_score[i * stride_score_l1 + j] = val; 69 | } 70 | 71 | // NOTE: softmax 72 | float score_sum = 0.0f; 73 | for (int j = 0; j < kv_len; j++) { 74 | auto exp = expf(attn_score[i * stride_score_l1 + j] - maxval); 75 | score_sum += exp; 76 | attn_score[i * stride_score_l1 + j] = exp; 77 | } 78 | for (int j = 0; j < kv_len; j++) { 79 | attn_score[i * stride_score_l1 + j] /= score_sum; 80 | } 81 | 82 | // NOTE: compute qk @ v 83 | // (seqlen, seqlen) @ (seqlen, head_dim) 84 | elem_type* out = static_cast(params.o_ptr) + bid * params.stride_q_bs + hid * params.stride_q_head_num + i * params.stride_q_seqlen; 85 | // init accumulators 86 | for (int dim = 0; dim < params.head_dim; dim++) { 87 | out[dim] = 0.0f; 88 | } 89 | for (int j = 0; j < kv_len; j++) { 90 | elem_type* v = static_cast(params.v_ptr) + bid * params.stride_kv_bs + hid * params.stride_kv_head_num + j * params.stride_kv_seqlen; 91 | for (int dim = 0; dim < params.head_dim; dim++) { 92 | out[dim] += attn_score[i * stride_score_l1 + j] * v[dim]; 93 | } 94 | } 95 | } 96 | } 97 | } 98 | } 99 | 100 | 101 | template void run_flash_attn(attn_fwd_params ¶ms) { 102 | /* 103 | q k v.shape (bs, head_num, seqlen, head_dim) 104 | attn_score.shape = (seqlen, seqlen), compute one by one 105 | */ 106 | using elem_type = typename Attn_traits::elem_type; 107 | 108 | #pragma omp parallel for collapse(3) 109 | for (int bid = 0; bid < params.bs; bid++) { 110 | for (int hid = 0; hid < params.head_num; hid++) { 111 | for (int i = 0; i < params.q_seqlen; i++) { 112 | elem_type* q = static_cast(params.q_ptr) + bid * params.stride_q_bs + hid * params.stride_q_head_num + i * params.stride_q_seqlen; 113 | // init accumulators with zero allocate 114 | elem_type* out = static_cast(params.o_ptr) + bid * params.stride_q_bs + hid * params.stride_q_head_num + i * params.stride_q_seqlen; 115 | // history max 116 | float maxval = -INFINITY; 117 | // div delay till the end (only div once) 118 | float score_sum = 0.0f; 119 | // qk dot product 120 | // NOTE: and online softmax 121 | int kv_len = params.k_seqlen; 122 | if (params.is_causal) { 123 | kv_len = i + 1 + (params.k_seqlen - params.q_seqlen); 124 | } 125 | for (int j = 0; j < kv_len; j++) { 126 | float local_maxval = -INFINITY; 127 | elem_type* k = static_cast(params.k_ptr) + bid * params.stride_kv_bs + hid * params.stride_kv_head_num + j * params.stride_kv_seqlen; 128 | // TODO: val need should be higher precision 129 | elem_type val = 0.0f; 130 | 131 | // q @ k 132 | for (int dim = 0; dim < params.head_dim; dim++) { 133 | val += q[dim] * k[dim]; 134 | } 135 | val *= params.softmax_scale; 136 | 137 | // local_maxval always the real max 138 | local_maxval = std::max(maxval, val); 139 | 140 | // TODO: skip scale if no update? 141 | // TODO: exp2f? 
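                    // Online-softmax step: local_maxval = max(maxval, val) is the
                    // running max so far, so
                    //   exp   = e^{val - local_maxval}    (weight of the new column)
                    //   scale = e^{maxval - local_maxval} (<= 1, rescales the history)
                    // Multiplying score_sum and out[] by `scale` keeps both expressed
                    // relative to the current max; the division by score_sum is
                    // deferred until after this j-loop.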
142 | auto exp = expf(val - local_maxval); 143 | auto scale = expf(maxval - local_maxval); 144 | 145 | // rescale score sum 146 | score_sum *= scale; 147 | score_sum += exp; 148 | 149 | // NOTE: online softmax rescale, update 150 | // and compute qk @ v: (seqlen, seqlen) @ (seqlen, head_dim) 151 | elem_type* v = static_cast(params.v_ptr) + bid * params.stride_kv_bs + hid * params.stride_kv_head_num + j * params.stride_kv_seqlen; 152 | for (int dim = 0; dim < params.head_dim; dim++) { 153 | // rescale score 154 | out[dim] *= scale; 155 | out[dim] += exp * v[dim]; 156 | } 157 | 158 | // update max 159 | maxval = local_maxval; 160 | } 161 | 162 | // TODO: online rescale or delay till the end? 163 | for (int dim = 0; dim < params.head_dim; dim++) { 164 | out[dim] /= score_sum; 165 | } 166 | } 167 | } 168 | } 169 | } 170 | 171 | void set_params_fprop(attn_fwd_params ¶ms, 172 | // device pointers 173 | const torch::Tensor q, 174 | const torch::Tensor k, 175 | const torch::Tensor v, 176 | torch::Tensor out, 177 | bool is_causal, 178 | float softmax_scale) { 179 | params.bs = q.size(0); 180 | params.head_num = q.size(1); 181 | params.kv_head_num = k.size(1); 182 | params.q_seqlen = q.size(2); 183 | params.k_seqlen = k.size(2); 184 | params.head_dim = q.size(3); 185 | 186 | params.stride_q_bs = q.stride(0); 187 | params.stride_q_head_num = q.stride(1); 188 | params.stride_q_seqlen = q.stride(2); 189 | params.stride_q_head_dim = q.stride(3); 190 | 191 | params.stride_kv_bs = k.stride(0); 192 | params.stride_kv_head_num = k.stride(1); 193 | params.stride_kv_seqlen = k.stride(2); 194 | params.stride_kv_head_dim = k.stride(3); 195 | 196 | params.q_ptr = q.data_ptr(); 197 | params.k_ptr = k.data_ptr(); 198 | params.v_ptr = v.data_ptr(); 199 | params.o_ptr = out.data_ptr(); 200 | 201 | params.is_causal = is_causal; 202 | params.softmax_scale = softmax_scale; 203 | } 204 | 205 | 206 | 207 | torch::Tensor naive_attn(torch::Tensor q, torch::Tensor k, 208 | torch::Tensor v, bool is_causal = false, float softmax_scale=1) { 209 | TORCH_CHECK(q.device().is_cpu(), "q must be on CPU"); 210 | TORCH_CHECK(k.device().is_cpu(), "k must be on CPU"); 211 | TORCH_CHECK(v.device().is_cpu(), "v must be on CPU"); 212 | 213 | // batch size 214 | int bs = q.size(0); 215 | // head number 216 | int head = q.size(1); 217 | // seqlen 218 | int seqlen = q.size(2); 219 | int kv_seqlen = k.size(2); 220 | // dim 221 | int dim = q.size(3); 222 | 223 | auto attn_score = torch::empty({seqlen, kv_seqlen}, q.options()); 224 | auto out = torch::zeros_like(q); 225 | 226 | attn_fwd_params params; 227 | set_params_fprop(params, q, k, v, out, 228 | is_causal, softmax_scale); 229 | 230 | // TODO: hard code float 231 | run_naive_attn>(params, (float*)attn_score.data_ptr(), attn_score.stride(0)); 232 | 233 | return out; 234 | } 235 | 236 | 237 | torch::Tensor flash_attn(torch::Tensor q, torch::Tensor k, 238 | torch::Tensor v, bool is_causal = false, float softmax_scale=1) { 239 | TORCH_CHECK(q.device().is_cpu(), "q must be on CPU"); 240 | TORCH_CHECK(k.device().is_cpu(), "k must be on CPU"); 241 | TORCH_CHECK(v.device().is_cpu(), "v must be on CPU"); 242 | 243 | // batch size 244 | int bs = q.size(0); 245 | // head number 246 | int head = q.size(1); 247 | // seqlen 248 | int seqlen = q.size(2); 249 | // dim 250 | int dim = q.size(3); 251 | 252 | auto out = torch::zeros_like(q); 253 | 254 | attn_fwd_params params; 255 | set_params_fprop(params, q, k, v, out, 256 | is_causal, softmax_scale); 257 | 258 | // TODO: hard code float 259 | 
run_flash_attn>(params); 260 | 261 | return out; 262 | } 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /flash_attention_py/tiny_flash_attn_triton.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py 2 | # 3 | # https://github.com/kyegomez/FlashAttention20Triton 4 | 5 | from torch import float32 6 | import torch 7 | import time 8 | import triton 9 | import triton.language as tl 10 | 11 | def flash_attn_triton(q, k, v, causal=True, sm_scale=1): 12 | # shape constraints 13 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 14 | assert Lq == Lk and Lk == Lv 15 | assert Lk in {16, 32, 64, 128} 16 | 17 | o = torch.empty_like(q) 18 | 19 | BLOCK_M = 128 20 | BLOCK_N = 64 21 | # NOTE: 对于flash attention 2, 外层循环的q可以并行处理, 因此每个thread需要计算正确的offset 22 | # 一个q, k, v的shape往往是(bs, head, seqlen, dim) 23 | # 对于(bs, head)中的每个元素都分配一个thread 24 | # 对于seqlen / BLOCK_M个的q分块, 每个分块再分配一个thread 25 | grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) 26 | # NOTE: 27 | # L.shape = (bs * head, seqlen) 28 | # L记录了所有的分母和mi(m_i + tl.math.log2(l_i)), 用于后续的backward 29 | L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) 30 | # 设置适当的wrap以提升性能 31 | num_warps = 4 if Lk <= 64 else 8 32 | _fwd_kernel[grid]( 33 | q, k, v, sm_scale, 34 | L, 35 | o, 36 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 37 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 38 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 39 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 40 | q.shape[0], q.shape[1], q.shape[2], 41 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, DIM=Lk, 42 | IS_CAUSAL=causal, 43 | num_warps=num_warps, 44 | num_stages=4) 45 | 46 | return o 47 | 48 | 49 | @triton.jit 50 | def _fwd_kernel( 51 | Q, K, V, sm_scale, 52 | # L记录了所有的分母和mi, 用于后续的backward 53 | L, 54 | O, 55 | stride_q_bs, stride_q_head, stride_q_seqlen, stride_q_dim, 56 | stride_k_bs, stride_k_head, stride_k_seqlen, stride_k_dim, 57 | stride_v_bs, stride_v_head, stride_v_seqlen, stride_v_dim, 58 | stride_o_bs, stride_o_head, stride_o_seqlen, stride_o_dim, 59 | BS, HEAD, SEQLEN, 60 | # BLOCK_M用于做Q的分块 61 | BLOCK_M: tl.constexpr, 62 | DIM: tl.constexpr, 63 | # BLOCK_N用于做K和V的分块 64 | BLOCK_N: tl.constexpr, 65 | IS_CAUSAL: tl.constexpr, 66 | ): 67 | # grid = (cdiv(seqlen, BLOCK_M), bs * head) 68 | # triton.language.program_id(axis) axis is The axis of the 3D launch grid 69 | # Q分块的起始地址 70 | start_m = tl.program_id(0) 71 | # 跳过(bs, head)的偏移 72 | off_bs_head = tl.program_id(1) 73 | 74 | # NOTE: 75 | # base = off_bs_head * stride_q_head找到正确的(bs, head)位置 76 | # strides: 步长, advance时直接使用步数, 会自动根据步长计算跳过的元素 77 | # offsets表示parent block (seqlen, dim)中怎么偏移来获取小块 78 | # block_shape=(BLOCK_M, DIM)表示parent block的shape 79 | # order表示用什么顺序读取存储来构造所需的shape 80 | qkv_base_offset = off_bs_head * stride_q_head 81 | Q_block_ptr = tl.make_block_ptr( 82 | # base offset to skip to the right (bs, head) 83 | base=Q + qkv_base_offset, 84 | # the shape of parent 85 | shape=(SEQLEN, DIM), 86 | strides=(stride_q_seqlen, stride_q_dim), 87 | # offset of the block inside of parent block 88 | offsets=(start_m * BLOCK_M, 0), 89 | block_shape=(BLOCK_M, DIM), 90 | order=(1, 0), 91 | ) 92 | K_block_ptr = tl.make_block_ptr( 93 | # base offset to skip to the right (bs, head) 94 | base=K + qkv_base_offset, 95 | # the shape of parent 96 | # NOTE: make_block_ptr读入时将K转置了 97 | 
shape=(DIM, SEQLEN), 98 | strides=(stride_k_dim, stride_k_seqlen), 99 | # 每个Q需要遍历整个的k和v 100 | offsets=(0, 0), 101 | # K根据BLOCK_N分块 102 | block_shape=(DIM, BLOCK_N), 103 | # 读入K的转置 104 | order=(0, 1), 105 | ) 106 | V_block_ptr = tl.make_block_ptr( 107 | # base offset to skip to the right (bs, head) 108 | base=V + qkv_base_offset, 109 | # the shape of parent 110 | shape=(SEQLEN, DIM), 111 | strides=(stride_k_seqlen, stride_v_dim), 112 | # 每个Q需要遍历整个的k和v 113 | offsets=(0, 0), 114 | # K根据BLOCK_N分块 115 | block_shape=(BLOCK_N, DIM), 116 | order=(1, 0), 117 | ) 118 | # initialize offsets 119 | # NOTE: BLOCK_M表示Q的分块大小, BLOCK_N表示k, v的分块大小 120 | off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 121 | off_n = tl.arange(0, BLOCK_N) 122 | # initialize pointers 123 | # NOTE: 一次处理一个(BLOCK_M, dim)的q, 而max和分母的sum都只需要一维, 即(BLOCK_M, 1) 124 | max = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') 125 | # 分母累加的sum, 每行的sum是一样的, 所以只需要一维然后广播即可 126 | denom = tl.zeros([BLOCK_M], dtype=tl.float32) 127 | out_buffer = tl.zeros([BLOCK_M, DIM], dtype=tl.float32) 128 | # NOTE: 129 | # scale sm_scale by log_2(e) and use 130 | # 2^x instead of exp in the loop because CSE and LICM 131 | # don't work as expected with `exp` in the loop 132 | # CSE(common subexpression elimination), LICM(loop invariant code motion)是编译器里的东西 133 | qk_scale = sm_scale * 1.44269504 134 | # load q: stay in SRAM throughout 135 | q = tl.load(Q_block_ptr) 136 | q = (q * qk_scale).to(tl.float16) 137 | # loop over k, v and update accumulator 138 | lo = 0 139 | # NOTE:: CAUSAL就是常说的不能看到后面的文本的自回归模型 140 | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else SEQLEN 141 | # NOTE: 142 | # 当前q和0..seqlen的kv算attention 143 | # 每次批处理BLOCK_N个k, v(即k, v以BLOCK_N分块) 144 | for start_n in range(lo, hi, BLOCK_N): 145 | k = tl.load(K_block_ptr) 146 | v = tl.load(V_block_ptr) 147 | 148 | # compute qk 149 | # NOTE: q.shape = (BLOCK_M, dim), k.shape(已转置) = (dim, BLOCK_N) 150 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 151 | if IS_CAUSAL: 152 | qk = tl.where(off_m[:, None] >= (start_n + off_n[None, :]), qk, float("-inf")) 153 | # NOTE: 执行矩阵乘法(matrix product), k在make_block_ptr时已经转置 154 | # qk init as zero 155 | qk += tl.dot(q, k) 156 | 157 | # compute scaling constant 158 | 159 | # NOTE: 160 | # max.shape = [BLOCK_M], aka [BLOCK_M, 1] 161 | # qk.shape = [BLOCK_M, BLOCK_N] 162 | # tl.max(block, axis) 163 | # tl.maximum(block, block) 164 | max_new = tl.maximum(max, tl.max(qk, 1)) 165 | # 保存exp的值, 节省exp操作 166 | alpha = tl.math.exp2(max - max_new) 167 | # NOTE: 168 | # nume = e^{x - max(x)} 169 | # max.shape = [BLOCK_M], max_new[:, None]扩展成[BLOCK_M, 1]来做广播操作 170 | nume = tl.math.exp2(qk - max_new[:, None]) 171 | # scale and update acc 172 | # NOTE: 利用广播来快速构建scale用于更新分母 173 | out_scale = denom * 0 + alpha 174 | # NOTE: 175 | # out_scale.shape = l_i.shape = [BLOCK_M] 176 | # out_scale[:, None]扩展成[BLOCK_M, 1]来做广播操作 177 | # out_buffer = old_out * scale来更新分子 178 | out_buffer *= out_scale[:, None] 179 | out_buffer += tl.dot(nume.to(tl.float16), v) 180 | # update max and denominator 181 | denom = denom * alpha + tl.sum(nume, 1) 182 | max = max_new 183 | # update k v pointer 184 | # NOTE: 计算下一个k, v的分块 185 | # 因为k已经转置(dim, seqlen), 所以算下一批seq的k时是增加k的第二个维度 186 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 187 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 188 | 189 | # write back l and m for backward 190 | # 最后统一更新output buffer, 除上完整的分母 191 | out_buffer = out_buffer / denom[:, None] 192 | # NOTE: 将分母和mi保存到L中, 用于后续的backward 193 | # L.shape = (bs * head, seqlen), 
因为每一行的分母和mi是相同的 194 | # off_bs_head = bs * head 195 | l_ptr = L + off_bs_head * SEQLEN + off_m 196 | # write [BLOCK_M] of data to L 197 | tl.store(l_ptr, max + tl.math.log2(denom)) 198 | # write back O 199 | O_block_ptr = tl.make_block_ptr( 200 | base=O + qkv_base_offset, 201 | shape=(SEQLEN, DIM), 202 | strides=(stride_o_seqlen, stride_o_dim), 203 | offsets=(start_m * BLOCK_M, 0), 204 | block_shape=(BLOCK_M, DIM), 205 | order=(1, 0), 206 | ) 207 | tl.store(O_block_ptr, out_buffer.to(tl.float16)) 208 | 209 | def ref_attn(q, k, v, causal=True, sm_scale=1): 210 | SEQLEN = q.shape[-2] 211 | M = torch.tril(torch.ones((SEQLEN, SEQLEN), device="cuda")) 212 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 213 | if causal: 214 | p[:, :, M == 0] = float("-inf") 215 | p = torch.softmax(p.float(), dim=-1).half() 216 | ref_out = torch.matmul(p, v) 217 | return ref_out 218 | 219 | def causal_test(BS, HEAD, SEQLEN, DIM, causal): 220 | dtype = torch.float16 221 | torch.manual_seed(20) 222 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 223 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 224 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 225 | sm_scale = 0.5 226 | 227 | # reference implementation 228 | time_ref = time.time() 229 | ref_out = ref_attn(q, k, v, causal=causal, sm_scale=sm_scale) 230 | time_ref = time.time() - time_ref 231 | 232 | # triton implementation 233 | time_tri = time.time() 234 | tri_out = flash_attn_triton(q, k, v, causal=causal, sm_scale=sm_scale).half() 235 | time_tri = time.time() - time_tri 236 | 237 | # compare 238 | assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) 239 | print("causal = {} ref time: {:.4f} ms, tri time: {:.4f}".format(causal, time_ref * 1000, time_tri * 1000)) 240 | 241 | def test_attention(): 242 | BS, HEAD, SEQLEN, DIM = 1, 2, 1024, 64 243 | causal_test(BS, HEAD, SEQLEN, DIM, causal=False) 244 | causal_test(BS, HEAD, SEQLEN, DIM, causal=True) 245 | -------------------------------------------------------------------------------- /flash_attention_cuda/flash_attn_triton.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py 2 | # 3 | # https://github.com/kyegomez/FlashAttention20Triton 4 | 5 | from torch import float32 6 | import torch 7 | import time 8 | import triton 9 | import triton.language as tl 10 | 11 | def flash_attn_triton(q, k, v, causal=True, sm_scale=1): 12 | # shape constraints 13 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 14 | assert Lq == Lk and Lk == Lv 15 | assert Lk in {16, 32, 64, 128} 16 | 17 | o = torch.empty_like(q) 18 | 19 | BLOCK_M = 128 20 | BLOCK_N = 64 21 | # NOTE: 对于flash attention 2, 外层循环的q可以并行处理, 因此每个thread需要计算正确的offset 22 | # 一个q, k, v的shape往往是(bs, head, seqlen, dim) 23 | # 对于(bs, head)中的每个元素都分配一个thread 24 | # 对于seqlen / BLOCK_M个的q分块, 每个分块再分配一个thread 25 | grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) 26 | # NOTE: 27 | # L.shape = (bs * head, seqlen) 28 | # L记录了所有的分母和mi(m_i + tl.math.log2(l_i)), 用于后续的backward 29 | L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) 30 | # 设置适当的wrap以提升性能 31 | num_warps = 4 if Lk <= 64 else 8 32 | _fwd_kernel[grid]( 33 | q, k, v, sm_scale, 34 | L, 35 | o, 36 | q.stride(0), q.stride(1), 
q.stride(2), q.stride(3), 37 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 38 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 39 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 40 | q.shape[0], q.shape[1], q.shape[2], 41 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, DIM=Lk, 42 | IS_CAUSAL=causal, 43 | num_warps=num_warps, 44 | num_stages=4) 45 | 46 | return o 47 | 48 | 49 | @triton.jit 50 | def _fwd_kernel( 51 | Q, K, V, sm_scale, 52 | # L记录了所有的分母和mi, 用于后续的backward 53 | L, 54 | O, 55 | stride_q_bs, stride_q_head, stride_q_seqlen, stride_q_dim, 56 | stride_k_bs, stride_k_head, stride_k_seqlen, stride_k_dim, 57 | stride_v_bs, stride_v_head, stride_v_seqlen, stride_v_dim, 58 | stride_o_bs, stride_o_head, stride_o_seqlen, stride_o_dim, 59 | BS, HEAD, SEQLEN, 60 | # BLOCK_M用于做Q的分块 61 | BLOCK_M: tl.constexpr, 62 | DIM: tl.constexpr, 63 | # BLOCK_N用于做K和V的分块 64 | BLOCK_N: tl.constexpr, 65 | IS_CAUSAL: tl.constexpr, 66 | ): 67 | # grid = (cdiv(seqlen, BLOCK_M), bs * head) 68 | # triton.language.program_id(axis) axis is The axis of the 3D launch grid 69 | # Q分块的起始地址 70 | start_m = tl.program_id(0) 71 | # 跳过(bs, head)的偏移 72 | off_bs_head = tl.program_id(1) 73 | 74 | # NOTE: 75 | # base = off_bs_head * stride_q_head找到正确的(bs, head)位置 76 | # strides: 步长, advance时直接使用步数, 会自动根据步长计算跳过的元素 77 | # offsets表示parent block (seqlen, dim)中怎么偏移来获取小块 78 | # block_shape=(BLOCK_M, DIM)表示parent block的shape 79 | # order表示用什么顺序读取存储来构造所需的shape 80 | qkv_base_offset = off_bs_head * stride_q_head 81 | Q_block_ptr = tl.make_block_ptr( 82 | # base offset to skip to the right (bs, head) 83 | base=Q + qkv_base_offset, 84 | # the shape of parent 85 | shape=(SEQLEN, DIM), 86 | strides=(stride_q_seqlen, stride_q_dim), 87 | # offset of the block inside of parent block 88 | offsets=(start_m * BLOCK_M, 0), 89 | block_shape=(BLOCK_M, DIM), 90 | order=(1, 0), 91 | ) 92 | K_block_ptr = tl.make_block_ptr( 93 | # base offset to skip to the right (bs, head) 94 | base=K + qkv_base_offset, 95 | # the shape of parent 96 | # NOTE: make_block_ptr读入时将K转置了 97 | shape=(DIM, SEQLEN), 98 | strides=(stride_k_dim, stride_k_seqlen), 99 | # 每个Q需要遍历整个的k和v 100 | offsets=(0, 0), 101 | # K根据BLOCK_N分块 102 | block_shape=(DIM, BLOCK_N), 103 | # 读入K的转置 104 | order=(0, 1), 105 | ) 106 | V_block_ptr = tl.make_block_ptr( 107 | # base offset to skip to the right (bs, head) 108 | base=V + qkv_base_offset, 109 | # the shape of parent 110 | shape=(SEQLEN, DIM), 111 | strides=(stride_k_seqlen, stride_v_dim), 112 | # 每个Q需要遍历整个的k和v 113 | offsets=(0, 0), 114 | # K根据BLOCK_N分块 115 | block_shape=(BLOCK_N, DIM), 116 | order=(1, 0), 117 | ) 118 | # initialize offsets 119 | # NOTE: BLOCK_M表示Q的分块大小, BLOCK_N表示k, v的分块大小 120 | off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 121 | off_n = tl.arange(0, BLOCK_N) 122 | # initialize pointers 123 | # NOTE: 一次处理一个(BLOCK_M, dim)的q, 而max和分母的sum都只需要一维, 即(BLOCK_M, 1) 124 | max = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') 125 | # 分母累加的sum, 每行的sum是一样的, 所以只需要一维然后广播即可 126 | denom = tl.zeros([BLOCK_M], dtype=tl.float32) 127 | out_buffer = tl.zeros([BLOCK_M, DIM], dtype=tl.float32) 128 | # NOTE: 129 | # scale sm_scale by log_2(e) and use 130 | # 2^x instead of exp in the loop because CSE and LICM 131 | # don't work as expected with `exp` in the loop 132 | # CSE(common subexpression elimination), LICM(loop invariant code motion)是编译器里的东西 133 | qk_scale = sm_scale * 1.44269504 134 | # load q: stay in SRAM throughout 135 | q = tl.load(Q_block_ptr) 136 | q = (q * qk_scale).to(tl.float16) 137 | # loop over k, v and update accumulator 
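    # Loop invariant for the online softmax below: after processing columns
    # [0, start_n), out_buffer holds sum_j 2^(qk_ij - max_i) * v_j and denom holds
    # sum_j 2^(qk_ij - max_i). Whenever the running max grows, both are rescaled by
    # alpha = 2^(max_old - max_new). Since q was pre-multiplied by
    # sm_scale * log2(e) above, exp2 of qk here equals exp of the scaled scores.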
138 | lo = 0 139 | # NOTE:: CAUSAL就是常说的不能看到后面的文本的自回归模型 140 | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else SEQLEN 141 | # NOTE: 142 | # 当前q和0..seqlen的kv算attention 143 | # 每次批处理BLOCK_N个k, v(即k, v以BLOCK_N分块) 144 | for start_n in range(lo, hi, BLOCK_N): 145 | k = tl.load(K_block_ptr) 146 | v = tl.load(V_block_ptr) 147 | 148 | # compute qk 149 | # NOTE: q.shape = (BLOCK_M, dim), k.shape(已转置) = (dim, BLOCK_N) 150 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 151 | if IS_CAUSAL: 152 | qk = tl.where(off_m[:, None] >= (start_n + off_n[None, :]), qk, float("-inf")) 153 | # NOTE: 执行矩阵乘法(matrix product), k在make_block_ptr时已经转置 154 | # qk init as zero 155 | qk += tl.dot(q, k) 156 | 157 | # compute scaling constant 158 | 159 | # NOTE: 160 | # max.shape = [BLOCK_M], aka [BLOCK_M, 1] 161 | # qk.shape = [BLOCK_M, BLOCK_N] 162 | # tl.max(block, axis) 163 | # tl.maximum(block, block) 164 | max_new = tl.maximum(max, tl.max(qk, 1)) 165 | # 保存exp的值, 节省exp操作 166 | alpha = tl.math.exp2(max - max_new) 167 | # NOTE: 168 | # nume = e^{x - max(x)} 169 | # max.shape = [BLOCK_M], max_new[:, None]扩展成[BLOCK_M, 1]来做广播操作 170 | nume = tl.math.exp2(qk - max_new[:, None]) 171 | # scale and update acc 172 | # NOTE: 利用广播来快速构建scale用于更新分母 173 | out_scale = denom * 0 + alpha 174 | # NOTE: 175 | # out_scale.shape = l_i.shape = [BLOCK_M] 176 | # out_scale[:, None]扩展成[BLOCK_M, 1]来做广播操作 177 | # out_buffer = old_out * scale来更新分子 178 | out_buffer *= out_scale[:, None] 179 | out_buffer += tl.dot(nume.to(tl.float16), v) 180 | # update max and denominator 181 | denom = denom * alpha + tl.sum(nume, 1) 182 | max = max_new 183 | # update k v pointer 184 | # NOTE: 计算下一个k, v的分块 185 | # 因为k已经转置(dim, seqlen), 所以算下一批seq的k时是增加k的第二个维度 186 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 187 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 188 | 189 | # write back l and m for backward 190 | # 最后统一更新output buffer, 除上完整的分母 191 | out_buffer = out_buffer / denom[:, None] 192 | # NOTE: 将分母和mi保存到L中, 用于后续的backward 193 | # L.shape = (bs * head, seqlen), 因为每一行的分母和mi是相同的 194 | # off_bs_head = bs * head 195 | l_ptr = L + off_bs_head * SEQLEN + off_m 196 | # write [BLOCK_M] of data to L 197 | tl.store(l_ptr, max + tl.math.log2(denom)) 198 | # write back O 199 | O_block_ptr = tl.make_block_ptr( 200 | base=O + qkv_base_offset, 201 | shape=(SEQLEN, DIM), 202 | strides=(stride_o_seqlen, stride_o_dim), 203 | offsets=(start_m * BLOCK_M, 0), 204 | block_shape=(BLOCK_M, DIM), 205 | order=(1, 0), 206 | ) 207 | tl.store(O_block_ptr, out_buffer.to(tl.float16)) 208 | 209 | def ref_attn(q, k, v, causal=True, sm_scale=1): 210 | SEQLEN = q.shape[-2] 211 | M = torch.tril(torch.ones((SEQLEN, SEQLEN), device="cuda")) 212 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 213 | if causal: 214 | p[:, :, M == 0] = float("-inf") 215 | p = torch.softmax(p.float(), dim=-1).half() 216 | ref_out = torch.matmul(p, v) 217 | return ref_out 218 | 219 | def causal_test(BS, HEAD, SEQLEN, DIM, causal): 220 | dtype = torch.float16 221 | torch.manual_seed(20) 222 | q = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 223 | k = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 224 | v = (torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 225 | sm_scale = 0.5 226 | 227 | # reference implementation 228 | time_ref = time.time() 229 | ref_out = ref_attn(q, k, v, causal=causal, sm_scale=sm_scale) 230 | 
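    # NOTE: these wall-clock timings are only a rough sanity check; the first
    # call to the Triton kernel below also pays its one-time JIT compilation
    # cost, so the numbers should not be read as a benchmark.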
time_ref = time.time() - time_ref 231 | 232 | # triton implementation 233 | time_tri = time.time() 234 | tri_out = flash_attn_triton(q, k, v, causal=causal, sm_scale=sm_scale).half() 235 | time_tri = time.time() - time_tri 236 | 237 | # compare 238 | assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) 239 | print("causal = {} ref time: {:.4f} ms, tri time: {:.4f}".format(causal, time_ref * 1000, time_tri * 1000)) 240 | 241 | def test_attention(): 242 | BS, HEAD, SEQLEN, DIM = 1, 2, 1024, 64 243 | causal_test(BS, HEAD, SEQLEN, DIM, causal=False) 244 | causal_test(BS, HEAD, SEQLEN, DIM, causal=True) 245 | 246 | test_attention() 247 | -------------------------------------------------------------------------------- /flash_attention_cuda/standalone_src/flash_attention_v2_standalone.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "helper.h" 6 | 7 | #define CUDA_CHECK(condition) \ 8 | do { \ 9 | cudaError_t error = condition; \ 10 | if (error != cudaSuccess) { \ 11 | printf("CUDA_CHECK error in line %d of file %s \ 12 | : %s \n", \ 13 | __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError())); \ 14 | exit(EXIT_FAILURE); \ 15 | } \ 16 | } while (0) 17 | 18 | // #define DEBUG 19 | 20 | #ifdef DEBUG 21 | #define DEBUG_BLOCK(expr) \ 22 | do { \ 23 | expr \ 24 | } while (0) 25 | #else 26 | #define DEBUG_BLOCK(...) \ 27 | do { \ 28 | } while (0) 29 | #endif 30 | 31 | 32 | // data type to test 33 | using FP = float; 34 | // BLOCK_M(Br, Brow), BLOCK_N(Bc, Bcol) can be determined at compile time 35 | // just like offical implementation which use a template kernel to do that 36 | // Block row size 37 | const int Br = 2; 38 | // Block column size 39 | const int Bc = 2; 40 | // seqlen 41 | const int input_seq = 4; 42 | // dim 43 | const int dim = 4; 44 | 45 | 46 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 47 | int M, int N, int K, int mBlock); 48 | __global__ void row_softmax(float *input, float *output, int n); 49 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 50 | int mBlock); 51 | 52 | __global__ void flash_attention_v2_kernel(FP *Q, FP* K, FP* V, FP* O, int seqlen, FP smScale); 53 | 54 | void flash_attention_v2_cuda(FP *Q, FP *K, FP *V, FP *O, int m, int n) { 55 | FP sm_scale = 1.f / sqrtf(static_cast(n)); 56 | int BS = 1; 57 | int HEAD = 1; 58 | int SEQLEN = m; 59 | int DIM = n; 60 | 61 | int Gc = 1; 62 | int Gr = (SEQLEN + Br - 1) / Br; 63 | 64 | // NOTE: each block process a range row of Q 65 | dim3 grid = dim3(Gc, Gr); 66 | // NOTE: each thread process a tile of Q 67 | dim3 block = dim3(Bc, Br); 68 | flash_attention_v2_kernel<<>>(Q, K, V, O, SEQLEN, sm_scale); 69 | 70 | DEBUG_BLOCK( 71 | printf("== v2: O ==\n"); 72 | print_device_matrix(O, SEQLEN, DIM); 73 | ); 74 | } 75 | 76 | __global__ void flash_attention_v2_kernel(FP *Q, FP* K, FP* V, FP* O, int seqlen, FP smScale) { 77 | // block size for K, V 78 | // group of row(seqlen) 79 | int groupSeq = (seqlen + Bc - 1) / Bc; 80 | // parallel process for V[Br, d] 81 | // group of column 82 | int groupTx = (dim + Bc - 1) / Bc; 83 | int groupTy = (dim + Br - 1) / Br; 84 | 85 | // load slice from global memory(HBM) 86 | __shared__ FP sQ[Br][dim]; 87 | __shared__ FP sK[Bc][dim]; 88 | __shared__ FP sV[Bc][dim]; 89 | // tmp o 90 | __shared__ FP sO[Br][dim]; 91 | __shared__ FP sQK[Br][Bc]; 92 | // e^{x - max} 93 | __shared__ FP sSafeE[Br][Bc]; 94 | // s stand for shared and local 95 | 
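  // Per-row online-softmax running state, shared by all threads in the block:
  // sMax[r]   = running max of the scores seen so far for row r of the Q tile
  // sDenom[r] = running softmax denominator for row r, rescaled whenever sMax[r] grows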
__shared__ FP sDenom[Br]; 96 | __shared__ FP sMax[Br]; 97 | 98 | // TODO: multihead 99 | 100 | // [0, Bc] 101 | int tx = threadIdx.x; 102 | // [0, Br] 103 | int ty = threadIdx.y; 104 | 105 | int row = ty + blockIdx.y * blockDim.y; 106 | 107 | if (row >= seqlen) { 108 | } 109 | // load q, o, max, denom from global memory to shared memory 110 | // Q[Br, dim] 111 | for (int i = 0; i < groupTx; i++) { 112 | sQ[ty][i * Bc + tx] = Q[row * dim + i * Bc + tx]; 113 | // NOTE:: accumulator zero init here 114 | sO[ty][i * Bc + tx] = 0; 115 | } 116 | 117 | sMax[ty] = -INFINITY; 118 | sDenom[ty] = 0; 119 | 120 | // load K, V block 121 | // Q[Br][dim] @ K[0..seqlen.step(Bc), dim] 122 | // compute partial sum of O[ty][dim] each iteration 123 | for (int j = 0; j < groupSeq; j++) { 124 | if ((j * Bc + tx) < seqlen) { 125 | // load k, v from global memory to shared memory 126 | // K[seqlen, dim], V[seqlen, dim] 127 | for (int i = 0; i < groupTy; i++) { 128 | // NOTE: 129 | // each thread.x copy a row of K to K.T 130 | // row0, t0: 131 | // row1, t1: 132 | // row2, t0: 133 | // row3, t2: 134 | sK[tx][i * Br + ty] = K[j * Bc * dim + tx * dim + i * Br + ty]; 135 | sV[tx][i * Br + ty] = V[j * Bc * dim + tx * dim + i * Br + ty]; 136 | } 137 | } 138 | 139 | // wait until g2s done 140 | __syncthreads(); 141 | 142 | // compute qk 143 | FP sum = 0.f; 144 | // result oriented: qk[y][x] from q[y] @ k[x] 145 | for (int i = 0; i < dim; i++) { 146 | sum += sQ[ty][i] * sK[tx][i]; 147 | } 148 | // sQK[Br, Bc] 149 | sQK[ty][tx] = sum * smScale; 150 | 151 | // wait until qk done 152 | __syncthreads(); 153 | 154 | // compute local max of each row of qk 155 | FP localMax = -INFINITY; 156 | for (int i = 0; i < Bc; i++) { 157 | localMax = max(localMax, sQK[ty][i]); 158 | } 159 | __syncthreads(); 160 | // compute the max of each row 161 | FP newMax = max(sMax[ty], localMax); 162 | 163 | // compute safe e(e^{x - max}) of each qk element 164 | sSafeE[ty][tx] = exp(sQK[ty][tx] - newMax); 165 | __syncthreads(); 166 | 167 | // accumulate local denom of each row of qk with local max 168 | FP localDenom = 0.f; 169 | for (int i = 0; i < Bc; i++) { 170 | localDenom += sSafeE[ty][i]; 171 | } 172 | __syncthreads(); 173 | 174 | // rescale history result 175 | FP rescaleOld = exp(sMax[ty] - newMax); 176 | // rescale denom 177 | FP newDenom = sDenom[ty] * rescaleOld + localDenom; 178 | 179 | // NOTE: 180 | // QK[Br, Bc] @ V[Bc, d] = O[Br, d] 181 | // tx in [0, Bc], ty in [0, Br] 182 | // slice-Bc and each O[ty, group.x] as accumulator 183 | for (int i = 0; i < groupTx; i++) { 184 | // NOTE: rescale old_o(numerator only for now) once: old_nume * rescale 185 | sO[ty][i * Bc + tx] = (sO[ty][i * Bc + tx] * rescaleOld); 186 | for (int k = 0; k < Bc; k++) { 187 | // NOTE: 188 | // accumulate numerator 189 | // new_nume = old_nume' + local_nume (Softmax(QK)@V) 190 | sO[ty][i * Bc + tx] += sSafeE[ty][k] * sV[k][i * Bc + tx]; 191 | } 192 | } 193 | 194 | // update global max and denom 195 | sMax[ty] = newMax; 196 | sDenom[ty] = newDenom; 197 | __syncthreads(); 198 | } 199 | 200 | // rescale O in the end 201 | for (int i = 0; i < groupTx; i++) { 202 | // copy sO[row, dim] to gO[row, dim] 203 | O[row * dim + i * Bc + tx] = sO[ty][i * Bc + tx] / sDenom[ty]; 204 | } 205 | } 206 | 207 | void self_attention_cuda(float *Q, float *K, float *V, float *O, int m, int n) { 208 | int mBlock = 2; 209 | assert(m % mBlock == 0 && "mBlock should align"); 210 | 211 | float sm_scale = 1.f / sqrtf(static_cast(n)); 212 | float *sm_o; 213 | cudaMalloc((void **)&sm_o, 
sizeof(float) * m * m); 214 | 215 | dim3 qk_block(m / mBlock, 1, 1); 216 | naive_nrow_gemm<<<1, qk_block>>>(Q, K, sm_o, sm_scale, 0, m, m, n, mBlock); 217 | cudaDeviceSynchronize(); 218 | DEBUG_BLOCK( 219 | CUDA_CHECK(cudaGetLastError()); 220 | printf("== naive QK ==\n"); 221 | print_device_matrix(sm_o, m, m); 222 | ); 223 | 224 | // QK[M, M] 225 | dim3 sm_block(m, 1, 1); 226 | row_softmax<<<1, sm_block>>>(sm_o, sm_o, m); 227 | cudaDeviceSynchronize(); 228 | DEBUG_BLOCK( 229 | CUDA_CHECK(cudaGetLastError()); 230 | printf("== naive softmax(QK) ==\n"); 231 | print_device_matrix(sm_o, m, m); 232 | ); 233 | 234 | // QK[M, M] @ V[M, N] 235 | dim3 qkv_block(m / mBlock, 1, 1); 236 | naive_pv<<<1, qkv_block>>>(sm_o, V, O, m, n, mBlock); 237 | cudaDeviceSynchronize(); 238 | DEBUG_BLOCK( 239 | CUDA_CHECK(cudaGetLastError()); 240 | printf("== naive softmax(QK)V ==\n"); 241 | print_device_matrix(O, m, n); 242 | ); 243 | 244 | cudaFree(sm_o); 245 | } 246 | 247 | // naive gemm implement with slice-k 248 | // perform C = aA@B + bC 249 | // A[M, K] x B[K, N] = C[M, N] 250 | // each thread process mblock rows of A 251 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 252 | int M, int N, int K, int mBlock) { 253 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 254 | 255 | // each thread process a range of rows 256 | idx *= mBlock; 257 | 258 | // A[mBlock, K] x B[N, K].T = C[mBlock, N] 259 | for (int i = idx; i < idx + mBlock; i++) { 260 | for (int j = 0; j < N; j++) { 261 | float sum = 0.f; 262 | for (int k = 0; k < K; k++) { 263 | sum += A[i * K + k] * B[j * K + k]; 264 | } 265 | // C[M, N] 266 | // C = aA@B + bC 267 | C[i * N + j] = a * sum + b * C[i * N + j]; 268 | } 269 | } 270 | } 271 | 272 | // perform QK[M, M] @ V[M, N] 273 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 274 | int mBlock) { 275 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 276 | 277 | // each thread process a range of rows 278 | idx *= mBlock; 279 | 280 | int K = M; 281 | // P[mBlock, M] x V[M, N] = O[mBlock, N] 282 | for (int i = idx; i < idx + mBlock; i++) { 283 | for (int j = 0; j < N; j++) { 284 | float sum = 0.f; 285 | for (int k = 0; k < K; k++) { 286 | sum += P[i * K + k] * V[k * N + j]; 287 | } 288 | // C[M, N] 289 | O[i * N + j] = sum; 290 | } 291 | } 292 | } 293 | 294 | // each thread process one row of softmax 295 | __global__ void row_softmax(float *input, float *output, int n) { 296 | // assume id will not exceed row number of input 297 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 298 | 299 | float max = -INFINITY; 300 | float sum = 0.f; 301 | 302 | // Find max 303 | for (int i = 0; i < n; i++) { 304 | if (input[idx * n + i] > max) { 305 | max = input[idx * n + i]; 306 | } 307 | } 308 | 309 | // Compute numerator and denominator 310 | for (int i = 0; i < n; i++) { 311 | output[idx * n + i] = exp(input[idx * n + i] - max); 312 | sum += output[idx * n + i]; 313 | } 314 | 315 | // Compute softmax 316 | for (int i = 0; i < n; i++) { 317 | output[idx * n + i] /= sum; 318 | } 319 | } 320 | 321 | void test_attention() { 322 | // seqlen 323 | int m = input_seq; 324 | // dim 325 | int n = dim; 326 | 327 | // Host pointer 328 | float *h_K = new float[m * n]; 329 | float *h_Q = new float[m * n]; 330 | float *h_V = new float[m * n]; 331 | float *h_O = new float[m * n]; 332 | float *h_O2 = new float[m * n]; 333 | 334 | // 初始化 K, Q, V 335 | for (int i = 0; i < m * n; ++i) { 336 | h_K[i] = static_cast(rand()) / RAND_MAX; 337 | h_Q[i] = static_cast(rand()) / RAND_MAX; 
338 | h_V[i] = static_cast(rand()) / RAND_MAX; 339 | 340 | DEBUG_BLOCK( 341 | h_K[i] = static_cast(i); 342 | h_Q[i] = static_cast(i); 343 | h_V[i] = static_cast(i); 344 | ); 345 | } 346 | 347 | DEBUG_BLOCK( 348 | printf("== K ==\n"); 349 | print_host_matrix(h_K, m, n); 350 | ); 351 | 352 | float *d_K, *d_Q, *d_V, *d_O, *d_O2; 353 | // Malloc device memory 354 | cudaMalloc((void **)&d_K, sizeof(float) * m * n); 355 | cudaMalloc((void **)&d_Q, sizeof(float) * m * n); 356 | cudaMalloc((void **)&d_V, sizeof(float) * m * n); 357 | cudaMalloc((void **)&d_O, sizeof(float) * m * n); 358 | cudaMalloc((void **)&d_O2, sizeof(float) * m * n); 359 | 360 | // Copy data from host to device 361 | cudaMemcpy(d_K, h_K, sizeof(float) * m * n, cudaMemcpyHostToDevice); 362 | cudaMemcpy(d_Q, h_Q, sizeof(float) * m * n, cudaMemcpyHostToDevice); 363 | cudaMemcpy(d_V, h_V, sizeof(float) * m * n, cudaMemcpyHostToDevice); 364 | 365 | cudaEvent_t start, stop; 366 | cudaEventCreate(&start); 367 | cudaEventCreate(&stop); 368 | cudaEventRecord(start, 0); 369 | 370 | // Run test 371 | for (int i = 0; i < 1; i++) { 372 | // Launch kernel 373 | self_attention_cuda(d_Q, d_K, d_V, d_O, m, n); 374 | 375 | CUDA_CHECK(cudaGetLastError()); 376 | } 377 | 378 | // test flash attention 2 379 | for (int i = 0; i < 1; i++) { 380 | flash_attention_v2_cuda(d_Q, d_K, d_V, d_O2, m, n); 381 | CUDA_CHECK(cudaGetLastError()); 382 | } 383 | 384 | cudaEventRecord(stop, 0); 385 | cudaEventSynchronize(stop); 386 | float milliseconds = 0; 387 | cudaEventElapsedTime(&milliseconds, start, stop); 388 | // printf("Time for kernel execution: %.3f ms \n", milliseconds / 100); 389 | cudaEventDestroy(start); 390 | cudaEventDestroy(stop); 391 | 392 | // Result back to host 393 | cudaMemcpy(h_O, d_O, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 394 | cudaMemcpy(h_O2, d_O2, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 395 | 396 | assert(all_close(h_O, h_O2, m, n) && "flash attention v2 not equal to naive"); 397 | 398 | cudaFree(d_K); 399 | cudaFree(d_Q); 400 | cudaFree(d_V); 401 | cudaFree(d_O); 402 | cudaFree(d_O2); 403 | free(h_Q); 404 | free(h_K); 405 | free(h_V); 406 | free(h_O); 407 | free(h_O2); 408 | } 409 | 410 | int main() { 411 | int epoch = 1000; 412 | DEBUG_BLOCK( epoch = 1; ); 413 | for (int i = 0; i < epoch; i++) { 414 | test_attention(); 415 | } 416 | 417 | return 0; 418 | } 419 | -------------------------------------------------------------------------------- /flash_attention_cuda/csrc/flash_attention.cu: -------------------------------------------------------------------------------- 1 | #include "attention_api.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "static_switch.h" 11 | 12 | 13 | // BLOCK_M(Br or Brow), BLOCK_N(Bc or Bcol) can be determined at compile time 14 | // just like offical implementation which use a template kernel to do that 15 | // Dim is enumberated at runtime for all supported dim 16 | template 17 | __global__ void flash_attention_v2_kernel(Ty *Q, Ty *K, Ty *V, Ty *O, 18 | int seqlen, int stride_head, Ty smScale) { 19 | // block size for K, V 20 | // group of row(seqlen) 21 | int groupSeq = (seqlen + kBc - 1) / kBc; 22 | // parallel process for V[Br, d] 23 | // group of column 24 | int groupTx = (kDim + kBc - 1) / kBc; 25 | int groupTy = (kDim + kBr - 1) / kBr; 26 | 27 | // load slice from global memory(HBM) 28 | __shared__ Ty sQ[kBr][kDim]; 29 | __shared__ Ty sK[kBc][kDim]; 30 | __shared__ Ty sV[kBc][kDim]; 31 | // tmp o 32 | 
__shared__ Ty sO[kBr][kDim]; 33 | __shared__ Ty sQK[kBr][kBc]; 34 | // e^{x - max} 35 | __shared__ Ty sSafeE[kBr][kBc]; 36 | // s stand for shared and local 37 | __shared__ Ty sDenom[kBr]; 38 | __shared__ Ty sMax[kBr]; 39 | 40 | // [0, Bc] 41 | int tx = threadIdx.x; 42 | // [0, Br] 43 | int ty = threadIdx.y; 44 | 45 | // each thread in the same blockIdx.x in the same (bs, head), 46 | // which shared memory and process a QKV 47 | int base_offset = blockIdx.x * stride_head; 48 | int row = ty + blockIdx.y * blockDim.y; 49 | 50 | // TODO: need a way to round up seqlen 51 | if (row >= seqlen) { 52 | return; 53 | } 54 | 55 | Q += base_offset; 56 | K += base_offset; 57 | V += base_offset; 58 | O += base_offset; 59 | 60 | // load q, o, max, denom from global memory to shared memory 61 | // Q[Br, dim] 62 | for (int i = 0; i < groupTx; i++) { 63 | sQ[ty][i * kBc + tx] = Q[row * kDim + i * kBc + tx]; 64 | // NOTE:: accumulator zero init here 65 | sO[ty][i * kBc + tx] = 0; 66 | } 67 | 68 | sMax[ty] = -INFINITY; 69 | sDenom[ty] = 0; 70 | 71 | // load K, V block 72 | // Q[Br][dim] @ K[0..seqlen.step(Bc), dim] 73 | // compute partial sum of O[ty][dim] each iteration 74 | for (int j = 0; j < groupSeq; j++) { 75 | if ((j * kBc + tx) < seqlen) { 76 | // load k, v from global memory to shared memory 77 | // K[seqlen, dim], V[seqlen, dim] 78 | for (int i = 0; i < groupTy; i++) { 79 | // NOTE: 80 | // each thread.x copy a row of K to K.T 81 | // row0, t0: 82 | // row1, t1: 83 | // row2, t0: 84 | // row3, t2: 85 | sK[tx][i * kBr + ty] = K[j * kBc * kDim + tx * kDim + i * kBr + ty]; 86 | sV[tx][i * kBr + ty] = V[j * kBc * kDim + tx * kDim + i * kBr + ty]; 87 | } 88 | } 89 | 90 | // wait until g2s done 91 | __syncthreads(); 92 | 93 | // compute qk 94 | Ty sum = 0.f; 95 | // result oriented: qk[y][x] from q[y] @ k[x] 96 | for (int i = 0; i < kDim; i++) { 97 | sum += sQ[ty][i] * sK[tx][i]; 98 | } 99 | // sQK[Br, Bc] 100 | sQK[ty][tx] = sum * smScale; 101 | 102 | // wait until qk done 103 | __syncthreads(); 104 | 105 | // compute local max of each row of qk 106 | Ty localMax = -INFINITY; 107 | for (int i = 0; i < kBc; i++) { 108 | localMax = max(localMax, sQK[ty][i]); 109 | } 110 | __syncthreads(); 111 | // compute the max of each row 112 | Ty newMax = max(sMax[ty], localMax); 113 | 114 | // compute safe e(e^{x - max}) of each qk element 115 | sSafeE[ty][tx] = exp(sQK[ty][tx] - newMax); 116 | __syncthreads(); 117 | 118 | // accumulate local denom of each row of qk with local max 119 | Ty localDenom = 0.f; 120 | for (int i = 0; i < kBc; i++) { 121 | localDenom += sSafeE[ty][i]; 122 | } 123 | __syncthreads(); 124 | 125 | // rescale history result 126 | Ty rescaleOld = exp(sMax[ty] - newMax); 127 | // rescale denom 128 | Ty newDenom = sDenom[ty] * rescaleOld + localDenom; 129 | 130 | // NOTE: 131 | // QK[Br, Bc] @ V[Bc, d] = O[Br, d] 132 | // tx in [0, Bc], ty in [0, Br] 133 | // slice-Bc and each O[ty, group.x] as accumulator 134 | for (int i = 0; i < groupTx; i++) { 135 | // NOTE: rescale old_o(numerator only for now) once: old_nume * rescale 136 | sO[ty][i * kBc + tx] = (sO[ty][i * kBc + tx] * rescaleOld); 137 | for (int k = 0; k < kBc; k++) { 138 | // NOTE: 139 | // accumulate numerator 140 | // new_nume = old_nume' + local_nume (Softmax(QK)@V) 141 | sO[ty][i * kBc + tx] += sSafeE[ty][k] * sV[k][i * kBc + tx]; 142 | } 143 | } 144 | 145 | // update global max and denom 146 | sMax[ty] = newMax; 147 | sDenom[ty] = newDenom; 148 | __syncthreads(); 149 | } 150 | 151 | // rescale O in the end 152 | for (int i = 0; i < 
groupTx; i++) { 153 | // copy sO[row, dim] to gO[row, dim] 154 | O[row * kDim + i * kBc + tx] = sO[ty][i * kBc + tx] / sDenom[ty]; 155 | } 156 | } 157 | 158 | template 159 | __global__ void flash_attention_v1_kernel(Ty *Q, Ty *K, Ty *V, Ty *O, Ty *gMax, 160 | Ty *gDenom, int seqlen, int stride_head, Ty smScale) { 161 | // block size for K, V 162 | // group of row(seqlen) 163 | int groupSeq = (seqlen + kBc - 1) / kBc; 164 | // parallel process for V[Br, d] 165 | // group of column 166 | int groupTx = (kDim + kBc - 1) / kBc; 167 | int groupTy = (kDim + kBr - 1) / kBr; 168 | 169 | // load slice from global memory(HBM) 170 | __shared__ Ty sQ[kBr][kDim]; 171 | __shared__ Ty sK[kBc][kDim]; 172 | __shared__ Ty sV[kBc][kDim]; 173 | __shared__ Ty sO[kBr][kDim]; 174 | __shared__ Ty sQK[kBr][kBc]; 175 | 176 | __shared__ Ty sNewO[kBr][kDim]; 177 | // e^{x - max} 178 | __shared__ Ty sSafeE[kBr][kBc]; 179 | // s stand for shared and local 180 | __shared__ Ty sDenom[kBr]; 181 | __shared__ Ty sMax[kBr]; 182 | 183 | // [0, Bc] 184 | int tx = threadIdx.x; 185 | // [0, Br] 186 | int ty = threadIdx.y; 187 | 188 | int row = ty + blockIdx.y * blockDim.y; 189 | int base_offset = blockIdx.x; 190 | 191 | if (row >= seqlen) { 192 | return; 193 | } 194 | 195 | Q += base_offset * stride_head; 196 | K += base_offset * stride_head; 197 | V += base_offset * stride_head; 198 | O += base_offset * stride_head; 199 | gMax += base_offset * seqlen; 200 | gDenom += base_offset * seqlen; 201 | 202 | for (int j = 0; j < groupSeq; j++) { 203 | if ((j * kBc + tx) < seqlen) { 204 | // load k, v from global memory to shared memory 205 | // K[seqlen, dim], V[seqlen, dim] 206 | for (int i = 0; i < groupTy; i++) { 207 | // each thread.x copy a row of K to K.T 208 | // row0, t0: 209 | // row1, t1: 210 | // row2, t0: 211 | // row3, t2: 212 | sK[tx][i * kBr + ty] = K[j * kBc * kDim + tx * kDim + i * kBr + ty]; 213 | sV[tx][i * kBr + ty] = V[j * kBc * kDim + tx * kDim + i * kBr + ty]; 214 | } 215 | } 216 | 217 | if (row < seqlen) { 218 | // load q, o, max, denom from global memory to shared memory 219 | // Q[seqlen, dim] 220 | for (int i = 0; i < groupTx; i++) { 221 | sQ[ty][i * kBc + tx] = Q[row * kDim + i * kBc + tx]; 222 | sO[ty][i * kBc + tx] = O[row * kDim + i * kBc + tx]; 223 | } 224 | 225 | // NOTE: the drawback of flash attention 1 is here that it will load O, 226 | // max, denom from global memory to shared memory many time 227 | sMax[ty] = gMax[row]; 228 | sDenom[ty] = gDenom[row]; 229 | } 230 | 231 | // wait until g2s done 232 | __syncthreads(); 233 | 234 | // compute qk 235 | Ty sum = 0.f; 236 | // result oriented: qk[y][x] from q[y] @ k[x] 237 | for (int i = 0; i < kDim; i++) { 238 | sum += sQ[ty][i] * sK[tx][i]; 239 | } 240 | // sQK[Br, Bc] 241 | sQK[ty][tx] = sum * smScale; 242 | 243 | // wait until qk done 244 | __syncthreads(); 245 | 246 | // compute local max of each row of qk 247 | Ty localMax = -INFINITY; 248 | for (int i = 0; i < kBc; i++) { 249 | localMax = max(localMax, sQK[ty][i]); 250 | } 251 | __syncthreads(); 252 | 253 | // compute safe e(e^{x - max}) of each qk element 254 | sSafeE[ty][tx] = exp(sQK[ty][tx] - localMax); 255 | __syncthreads(); 256 | 257 | // accumulate local denom of each row of qk with local max 258 | Ty localDenom = 0.f; 259 | for (int i = 0; i < kBc; i++) { 260 | localDenom += sSafeE[ty][i]; 261 | } 262 | __syncthreads(); 263 | 264 | // NOTE: this is a pure flash attention 1 implementation with many redundant 265 | // mul update global max of each row 266 | Ty newMax = max(sMax[ty], localMax); 
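    // Online-softmax update for this K/V block, written out explicitly:
    //   new_max   = max(old_max, local_max)
    //   new_denom = old_denom * e^{old_max - new_max}
    //             + local_denom * e^{local_max - new_max}
    //   new_o     = (old_o * old_denom * e^{old_max - new_max}
    //             + (e^{qk - local_max} @ V) * e^{local_max - new_max}) / new_denom
    // v2 defers the division by new_denom until after the loop; this v1 kernel
    // performs it every iteration, which is the redundant work noted above.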
267 | // rescale history result 268 | Ty rescaleOld = exp(sMax[ty] - newMax); 269 | // rescale result just computed above: sSafeE, localDenom 270 | Ty rescaleCur = exp(localMax - newMax); 271 | Ty newDenom = sDenom[ty] * rescaleOld + localDenom * rescaleCur; 272 | 273 | // clean each row of of sNewO 274 | for (int i = 0; i < groupTx; i++) { 275 | sNewO[ty][i * kBc + tx] = 0; 276 | } 277 | 278 | // NOTE: 279 | // QK[Br, Bc] @ V[Bc, d] = O[Br, d] 280 | // tx in [0, Bc], ty in [0, Br] 281 | // slice-Bc and each O[ty, group.x] as accumulator 282 | for (int k = 0; k < kBc; k++) { 283 | for (int i = 0; i < groupTx; i++) { 284 | // rescale numerator 285 | sNewO[ty][i * kBc + tx] += 286 | sSafeE[ty][k] * rescaleCur * sV[k][i * kBc + tx]; 287 | } 288 | } 289 | 290 | // NOTE: rescale output 291 | // old_nume = old_o * old_denom 292 | // new_o = (old_nume + new_nume) / new_denom 293 | for (int i = 0; i < groupTx; i++) { 294 | sNewO[ty][i * kBc + tx] = (/* new_nume */ sNewO[ty][i * kBc + tx] + 295 | /* old_o */ sO[ty][i * kBc + tx] * rescaleOld * 296 | /* old_denom */ sDenom[ty]) / 297 | newDenom; 298 | } 299 | 300 | __syncthreads(); 301 | 302 | // update global o 303 | if (row < seqlen) { 304 | for (int i = 0; i < groupTx; i++) { 305 | // copy sO[row, dim] to gO[row, dim] 306 | O[row * kDim + i * kBc + tx] = sNewO[ty][i * kBc + tx]; 307 | } 308 | } 309 | 310 | // update global max and denom 311 | gMax[row] = newMax; 312 | gDenom[row] = newDenom; 313 | __syncthreads(); 314 | } 315 | } 316 | 317 | torch::Tensor flash_attention_v1_cuda(torch::Tensor q, torch::Tensor k, 318 | torch::Tensor v) { 319 | CHECK_INPUT(q); 320 | CHECK_INPUT(k); 321 | CHECK_INPUT(v); 322 | 323 | // batch size 324 | int bs = q.size(0); 325 | // head number 326 | int head = q.size(1); 327 | // seqlen 328 | int seqlen = q.size(2); 329 | // dim 330 | int dim = q.size(3); 331 | float sm_scale = 1.f / sqrtf(static_cast(dim)); 332 | // offset 1 in head dim should skip seqlen * dim elements 333 | int stride_head = seqlen * dim; 334 | 335 | auto out = torch::zeros_like(q); 336 | 337 | // Create intermediate date for the new shape of max, denom 338 | torch::TensorOptions options = q.options(); 339 | std::vector shape = {bs, head, seqlen}; 340 | torch::Tensor gMax = torch::empty(shape, options); 341 | torch::fill(gMax, -INFINITY); 342 | torch::Tensor gDenom = torch::zeros(shape, options); 343 | 344 | const int Br = 2; 345 | const int Bc = 2; 346 | int Gc = bs * head; 347 | int Gr = (seqlen + Br - 1) / Br; 348 | 349 | // NOTE: each thread process a tile of Q 350 | dim3 grid = dim3(Gc, Gr); 351 | // NOTE: each block process a range row of Q 352 | dim3 block = dim3(Bc, Br); 353 | 354 | // NOTE: AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) 355 | // We need a way of determining at runtime what type a tensor is and then 356 | // selectively call functions with the corresponding correct type signature. 357 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 358 | q.scalar_type(), "flash_attn_v1", ([&] { 359 | FWD_HEADDIM_SWITCH(dim, [&]{ 360 | flash_attention_v1_kernel 361 | <<>>(q.data_ptr(), k.data_ptr(), 362 | v.data_ptr(), out.data_ptr(), 363 | gMax.data_ptr(), 364 | gDenom.data_ptr(), seqlen, stride_head, sm_scale); 365 | }); 366 | })); 367 | 368 | // Wait until kernel finish. 
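  // cudaDeviceSynchronize() blocks the host until the kernel has finished so that
  // the error check below catches failures from this launch; it also serializes
  // host and device, so it is a debugging convenience rather than a requirement.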
369 | cudaDeviceSynchronize(); 370 | CUDA_ERROR_CHECK(cudaGetLastError()); 371 | 372 | return out; 373 | } 374 | 375 | torch::Tensor flash_attention_v2_cuda(torch::Tensor q, torch::Tensor k, 376 | torch::Tensor v) { 377 | CHECK_INPUT(q); 378 | CHECK_INPUT(k); 379 | CHECK_INPUT(v); 380 | 381 | // batch size 382 | int bs = q.size(0); 383 | // head number 384 | int head = q.size(1); 385 | // seqlen 386 | int seqlen = q.size(2); 387 | // dim 388 | int dim = q.size(3); 389 | float sm_scale = 1.f / sqrtf(static_cast(dim)); 390 | // offset 1 in head dim should skip seqlen * dim elements 391 | int stride_head = seqlen * dim; 392 | 393 | auto out = torch::zeros_like(q); 394 | 395 | const int Br = 4; 396 | const int Bc = 4; 397 | // grid.x indicate the base offset 398 | int Gc = bs * head; 399 | // grid.y indicate the group of row 400 | int Gr = (seqlen + Br - 1) / Br; 401 | assert(dim % Bc == 0 && seqlen % Br == 0); 402 | 403 | // NOTE: each block process a range row of Q 404 | dim3 grid = dim3(Gc, Gr); 405 | // NOTE: each thread process a tile of Q 406 | dim3 block = dim3(Bc, Br); 407 | 408 | // NOTE: AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) 409 | // We need a way of determining at runtime what type a tensor is and then 410 | // selectively call functions with the corresponding correct type signature. 411 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(q.scalar_type(), "flash_attn_v2", ([&] { 412 | FWD_HEADDIM_SWITCH(dim, [&]{ 413 | flash_attention_v2_kernel<<>>( 414 | q.data_ptr(), k.data_ptr(), 415 | v.data_ptr(), out.data_ptr(), seqlen, stride_head, sm_scale); 416 | }); 417 | })); 418 | 419 | // Wait until kernel finish. 420 | cudaDeviceSynchronize(); 421 | CUDA_ERROR_CHECK(cudaGetLastError()); 422 | 423 | return out; 424 | } 425 | 426 | 427 | -------------------------------------------------------------------------------- /flash_attention_cuda/standalone_src/flash_attention_v1_standalone.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CUDA_CHECK(condition) \ 7 | do { \ 8 | cudaError_t error = condition; \ 9 | if (error != cudaSuccess) { \ 10 | printf("CUDA_CHECK error in line %d of file %s \ 11 | : %s \n", \ 12 | __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError())); \ 13 | exit(EXIT_FAILURE); \ 14 | } \ 15 | } while (0) 16 | 17 | #define DEBUG 18 | 19 | #ifdef DEBUG 20 | #define DEBUG_BLOCK(expr) \ 21 | do { \ 22 | expr \ 23 | } while (0) 24 | #else 25 | #define DEBUG_BLOCK(...) 
\ 26 | do { \ 27 | } while (0) 28 | #endif 29 | 30 | 31 | // data type to test 32 | using FP = float; 33 | // BLOCK_M(Br, Brow), BLOCK_N(Bc, Bcol) can be determined at compile time 34 | // just like offical implementation which use a template kernel to do that 35 | // Block row size 36 | const int Br = 2; 37 | // Block column size 38 | const int Bc = 2; 39 | // seqlen 40 | const int input_seq = 4; 41 | // dim 42 | const int dim = 4; 43 | 44 | 45 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 46 | int M, int N, int K, int mBlock); 47 | __global__ void row_softmax(float *input, float *output, int n); 48 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 49 | int mBlock); 50 | 51 | __global__ void flash_attention_v1_kernel(FP *Q, FP* K, FP* V, FP* O, FP* gMAX, FP* gDenom, int seqlen, FP smScale); 52 | void print_host_matrix(float *matrix, int m, int n); 53 | void print_device_matrix(float *matrix, int m, int n); 54 | 55 | void flash_attention_v1_cuda(FP *Q, FP *K, FP *V, FP *O, int m, int n) { 56 | FP *dev_max, *dev_denom, *host_max, *host_denom; 57 | // qk buffer 58 | FP *QK; 59 | 60 | FP sm_scale = 1.f / sqrtf(static_cast(n)); 61 | int BS = 1; 62 | int HEAD = 1; 63 | int SEQLEN = m; 64 | int DIM = n; 65 | 66 | host_max = new FP[SEQLEN]; 67 | host_denom = new FP[SEQLEN]; 68 | for (int i = 0; i < SEQLEN; i++) { 69 | host_max[i] = -INFINITY; 70 | host_denom[i] = 0; 71 | } 72 | 73 | CUDA_CHECK(cudaMalloc((void **)&dev_max, sizeof(FP) * SEQLEN * DIM)); 74 | CUDA_CHECK(cudaMalloc((void **)&dev_denom, sizeof(FP) * SEQLEN * DIM)); 75 | CUDA_CHECK(cudaMalloc((void **)&QK, sizeof(FP) * SEQLEN * SEQLEN)); 76 | CUDA_CHECK(cudaMemcpy(dev_max, host_max, sizeof(FP) * SEQLEN * DIM, cudaMemcpyHostToDevice)); 77 | CUDA_CHECK(cudaMemcpy(dev_denom, host_denom, sizeof(FP) * SEQLEN * DIM, cudaMemcpyHostToDevice)); 78 | 79 | 80 | int Gc = 1; 81 | int Gr = (SEQLEN + Br - 1) / Br; 82 | 83 | // NOTE: each block process a range row of Q 84 | dim3 grid = dim3(Gc, Gr); 85 | // NOTE: each thread process a tile of Q 86 | dim3 block = dim3(Bc, Br); 87 | flash_attention_v1_kernel<<>>(Q, K, V, O, dev_max, dev_denom, SEQLEN, sm_scale); 88 | 89 | printf("== V1: O ==\n"); 90 | print_device_matrix(O, SEQLEN, DIM); 91 | 92 | cudaFree(QK); 93 | cudaFree(dev_max); 94 | cudaFree(dev_denom); 95 | } 96 | 97 | __global__ void flash_attention_v1_kernel(FP *Q, FP* K, FP* V, FP* O, FP* gMAX, FP* gDenom, int seqlen, FP smScale) { 98 | // block size for K, V 99 | // group of row(seqlen) 100 | int groupSeq = (seqlen + Bc - 1) / Bc; 101 | // parallel process for V[Br, d] 102 | // group of column 103 | int groupTx = (dim + Bc - 1) / Bc; 104 | int groupTy = (dim + Br - 1) / Br; 105 | 106 | // load slice from global memory(HBM) 107 | __shared__ FP sQ[Br][dim]; 108 | __shared__ FP sK[Bc][dim]; 109 | __shared__ FP sV[Bc][dim]; 110 | __shared__ FP sO[Br][dim]; 111 | __shared__ FP sQK[Br][Bc]; 112 | 113 | __shared__ FP sNewO[Br][dim]; 114 | // e^{x - max} 115 | __shared__ FP sSafeE[Br][Bc]; 116 | // s stand for shared and local 117 | __shared__ FP sDenom[Br]; 118 | __shared__ FP sMax[Br]; 119 | 120 | // TODO: multihead 121 | 122 | // [0, Bc] 123 | int tx = threadIdx.x; 124 | // [0, Br] 125 | int ty = threadIdx.y; 126 | 127 | int row = ty + blockIdx.y * blockDim.y; 128 | for (int j = 0; j < groupSeq; j++) { 129 | if ((j * Bc + tx) < seqlen) { 130 | // load k, v from global memory to shared memory 131 | // K[seqlen, dim], V[seqlen, dim] 132 | for (int i = 0; i < groupTy; i++) { 133 | // each 
thread.x copy a row of K to K.T 134 | // row0, t0: 135 | // row1, t1: 136 | // row2, t0: 137 | // row3, t2: 138 | sK[tx][i * Br + ty] = K[j * Bc * dim + tx * dim + i * Br + ty]; 139 | sV[tx][i * Br + ty] = V[j * Bc * dim + tx * dim + i * Br + ty]; 140 | } 141 | } 142 | 143 | if (row < seqlen) { 144 | // load q, o, max, denom from global memory to shared memory 145 | // Q[seqlen, dim] 146 | for (int i = 0; i < groupTx; i++) { 147 | sQ[ty][i * Bc + tx] = Q[row * dim + i * Bc + tx]; 148 | sO[ty][i * Bc + tx] = O[row * dim + i * Bc + tx]; 149 | } 150 | 151 | // NOTE: the drawback of flash attention 1 is here that it will load O, max, denom from global memory to shared memory many time 152 | sMax[ty] = gMAX[row]; 153 | sDenom[ty] = gDenom[row]; 154 | } 155 | 156 | // wait until g2s done 157 | __syncthreads(); 158 | 159 | // compute qk 160 | FP sum = 0.f; 161 | // result oriented: qk[y][x] from q[y] @ k[x] 162 | for (int i = 0; i < dim; i++) { 163 | sum += sQ[ty][i] * sK[tx][i]; 164 | } 165 | // sQK[Br, Bc] 166 | sQK[ty][tx] = sum * smScale; 167 | 168 | // wait until qk done 169 | __syncthreads(); 170 | 171 | // compute local max of each row of qk 172 | FP localMax = -INFINITY; 173 | for (int i = 0; i < Bc; i++) { 174 | localMax = max(localMax, sQK[ty][i]); 175 | } 176 | __syncthreads(); 177 | 178 | // compute safe e(e^{x - max}) of each qk element 179 | sSafeE[ty][tx] = exp(sQK[ty][tx] - localMax); 180 | __syncthreads(); 181 | 182 | // accumulate local denom of each row of qk with local max 183 | FP localDenom = 0.f; 184 | for (int i = 0; i < Bc; i++) { 185 | localDenom += sSafeE[ty][i]; 186 | } 187 | __syncthreads(); 188 | 189 | // NOTE: this is a pure flash attention 1 implementation with many redundant mul 190 | // update global max of each row 191 | FP newMax = max(sMax[ty], localMax); 192 | // rescale history result 193 | FP rescaleOld = exp(sMax[ty] - newMax); 194 | // rescale result just computed above: sSafeE, localDenom 195 | FP rescaleCur = exp(localMax - newMax); 196 | FP newDenom = sDenom[ty] * rescaleOld + localDenom * rescaleCur; 197 | 198 | // clean each row of of sNewO 199 | for (int i = 0; i < groupTx; i++) { 200 | sNewO[ty][i * Bc + tx] = 0; 201 | } 202 | 203 | // NOTE: 204 | // QK[Br, Bc] @ V[Bc, d] = O[Br, d] 205 | // tx in [0, Bc], ty in [0, Br] 206 | // slice-Bc and each O[ty, group.x] as accumulator 207 | for (int k = 0; k < Bc; k++) { 208 | for (int i = 0; i < groupTx; i++) { 209 | // rescale numerator 210 | sNewO[ty][i * Bc + tx] += sSafeE[ty][k] * rescaleCur * sV[k][i * Bc + tx]; 211 | } 212 | } 213 | 214 | // NOTE: rescale output 215 | // old_nume = old_o * old_denom 216 | // new_o = (old_nume + new_nume) / new_denom 217 | for (int i = 0; i < groupTx; i++) { 218 | sNewO[ty][i * Bc + tx] = (/* new_nume */ sNewO[ty][i * Bc + tx] + /* old_o */sO[ty][i * Bc + tx] * rescaleOld * /* old_denom */ sDenom[ty]) / newDenom; 219 | } 220 | 221 | __syncthreads(); 222 | 223 | // update global o 224 | if (row < seqlen) { 225 | for (int i = 0; i < groupTx; i++) { 226 | // copy sO[row, dim] to gO[row, dim] 227 | O[row * dim + i * Bc + tx] = sNewO[ty][i * Bc + tx]; 228 | } 229 | } 230 | 231 | // update global max and denom 232 | gMAX[row] = newMax; 233 | gDenom[row] = newDenom; 234 | __syncthreads(); 235 | } 236 | } 237 | 238 | void self_attention_cuda(float *Q, float *K, float *V, float *O, int m, int n) { 239 | int mBlock = 2; 240 | assert(m % mBlock == 0 && "mBlock should align"); 241 | 242 | float sm_scale = 1.f / sqrtf(static_cast(n)); 243 | float *sm_o; 244 | cudaMalloc((void 
**)&sm_o, sizeof(float) * m * m); 245 | 246 | dim3 qk_block(m / mBlock, 1, 1); 247 | naive_nrow_gemm<<<1, qk_block>>>(Q, K, sm_o, sm_scale, 0, m, m, n, mBlock); 248 | cudaDeviceSynchronize(); 249 | DEBUG_BLOCK( 250 | CUDA_CHECK(cudaGetLastError()); 251 | printf("== naive QK ==\n"); 252 | print_device_matrix(sm_o, m, m); 253 | ); 254 | 255 | // QK[M, M] 256 | dim3 sm_block(m, 1, 1); 257 | row_softmax<<<1, sm_block>>>(sm_o, sm_o, m); 258 | cudaDeviceSynchronize(); 259 | DEBUG_BLOCK( 260 | CUDA_CHECK(cudaGetLastError()); 261 | printf("== naive softmax(QK) ==\n"); 262 | print_device_matrix(sm_o, m, m); 263 | ); 264 | 265 | // QK[M, M] @ V[M, N] 266 | dim3 qkv_block(m / mBlock, 1, 1); 267 | naive_pv<<<1, qkv_block>>>(sm_o, V, O, m, n, mBlock); 268 | cudaDeviceSynchronize(); 269 | DEBUG_BLOCK( 270 | CUDA_CHECK(cudaGetLastError()); 271 | printf("== naive softmax(QK)V ==\n"); 272 | print_device_matrix(O, m, n); 273 | ); 274 | 275 | cudaFree(sm_o); 276 | } 277 | 278 | // naive gemm implement with slice-k 279 | // perform C = aA@B + bC 280 | // A[M, K] x B[K, N] = C[M, N] 281 | // each thread process mblock rows of A 282 | __global__ void naive_nrow_gemm(float *A, float *B, float *C, float a, float b, 283 | int M, int N, int K, int mBlock) { 284 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 285 | 286 | // each thread process a range of rows 287 | idx *= mBlock; 288 | 289 | // A[mBlock, K] x B[N, K].T = C[mBlock, N] 290 | for (int i = idx; i < idx + mBlock; i++) { 291 | for (int j = 0; j < N; j++) { 292 | float sum = 0.f; 293 | for (int k = 0; k < K; k++) { 294 | sum += A[i * K + k] * B[j * K + k]; 295 | } 296 | // C[M, N] 297 | // C = aA@B + bC 298 | C[i * N + j] = a * sum + b * C[i * N + j]; 299 | } 300 | } 301 | } 302 | 303 | // perform QK[M, M] @ V[M, N] 304 | __global__ void naive_pv(float *P, float *V, float *O, int M, int N, 305 | int mBlock) { 306 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 307 | 308 | // each thread process a range of rows 309 | idx *= mBlock; 310 | 311 | int K = M; 312 | // P[mBlock, M] x V[M, N] = O[mBlock, N] 313 | for (int i = idx; i < idx + mBlock; i++) { 314 | for (int j = 0; j < N; j++) { 315 | float sum = 0.f; 316 | for (int k = 0; k < K; k++) { 317 | sum += P[i * K + k] * V[k * N + j]; 318 | } 319 | // C[M, N] 320 | O[i * N + j] = sum; 321 | } 322 | } 323 | } 324 | 325 | // each thread process one row of softmax 326 | __global__ void row_softmax(float *input, float *output, int n) { 327 | // assume id will not exceed row number of input 328 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 329 | 330 | float max = -INFINITY; 331 | float sum = 0.f; 332 | 333 | // Find max 334 | for (int i = 0; i < n; i++) { 335 | if (input[idx * n + i] > max) { 336 | max = input[idx * n + i]; 337 | } 338 | } 339 | 340 | // Compute numerator and denominator 341 | for (int i = 0; i < n; i++) { 342 | output[idx * n + i] = exp(input[idx * n + i] - max); 343 | sum += output[idx * n + i]; 344 | } 345 | 346 | // Compute softmax 347 | for (int i = 0; i < n; i++) { 348 | output[idx * n + i] /= sum; 349 | } 350 | } 351 | 352 | // print matrix 353 | void print_host_matrix(float *matrix, int m, int n) { 354 | for (int i = 0; i < m; i++) { 355 | for (int j = 0; j < n; j++) { 356 | printf("%f, ", matrix[i * n + j]); 357 | } 358 | printf("\n"); 359 | } 360 | } 361 | 362 | void print_device_matrix(float *dev_ptr, int m, int n) { 363 | float *host_ptr = new float[m * n]; 364 | cudaMemcpy(host_ptr, dev_ptr, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 365 | 366 | for (int i = 0; i < m; 
i++) { 367 | for (int j = 0; j < n; j++) { 368 | printf("%f, ", host_ptr[i * n + j]); 369 | } 370 | printf("\n"); 371 | } 372 | free(host_ptr); 373 | } 374 | 375 | void test_attention() { 376 | // seqlen 377 | int m = input_seq; 378 | // dim 379 | int n = dim; 380 | 381 | // Host pointer 382 | float *h_K = new float[m * n]; 383 | float *h_Q = new float[m * n]; 384 | float *h_V = new float[m * n]; 385 | float *h_O = new float[m * n]; 386 | 387 | // 初始化 K, Q, V 388 | for (int i = 0; i < m * n; ++i) { 389 | // h_K[i] = static_cast(rand()) / RAND_MAX; 390 | // h_Q[i] = static_cast(rand()) / RAND_MAX; 391 | // h_V[i] = static_cast(rand()) / RAND_MAX; 392 | h_K[i] = static_cast(i); 393 | h_Q[i] = static_cast(i); 394 | h_V[i] = static_cast(i); 395 | } 396 | 397 | printf("== K ==\n"); 398 | print_host_matrix(h_K, m, n); 399 | 400 | float *d_K, *d_Q, *d_V, *d_O; 401 | // Malloc device memory 402 | cudaMalloc((void **)&d_K, sizeof(float) * m * n); 403 | cudaMalloc((void **)&d_Q, sizeof(float) * m * n); 404 | cudaMalloc((void **)&d_V, sizeof(float) * m * n); 405 | cudaMalloc((void **)&d_O, sizeof(float) * m * n); 406 | 407 | // Copy data from host to device 408 | cudaMemcpy(d_K, h_K, sizeof(float) * m * n, cudaMemcpyHostToDevice); 409 | cudaMemcpy(d_Q, h_Q, sizeof(float) * m * n, cudaMemcpyHostToDevice); 410 | cudaMemcpy(d_V, h_V, sizeof(float) * m * n, cudaMemcpyHostToDevice); 411 | 412 | cudaEvent_t start, stop; 413 | cudaEventCreate(&start); 414 | cudaEventCreate(&stop); 415 | cudaEventRecord(start, 0); 416 | 417 | // Run test 418 | for (int i = 0; i < 1; i++) { 419 | // Launch kernel 420 | self_attention_cuda(d_Q, d_K, d_V, d_O, m, n); 421 | 422 | CUDA_CHECK(cudaGetLastError()); 423 | } 424 | 425 | // test flash attention 1 426 | cudaMemset(d_O, 0, sizeof(float) * m * n); 427 | for (int i = 0; i < 1; i++) { 428 | flash_attention_v1_cuda(d_Q, d_K, d_V, d_O, m, n); 429 | CUDA_CHECK(cudaGetLastError()); 430 | } 431 | 432 | cudaEventRecord(stop, 0); 433 | cudaEventSynchronize(stop); 434 | float milliseconds = 0; 435 | cudaEventElapsedTime(&milliseconds, start, stop); 436 | printf("Time for kernel execution: %.3f ms \n", milliseconds / 100); 437 | cudaEventDestroy(start); 438 | cudaEventDestroy(stop); 439 | 440 | // Result back to host 441 | cudaMemcpy(h_O, d_O, sizeof(float) * m * n, cudaMemcpyDeviceToHost); 442 | 443 | cudaFree(d_K); 444 | cudaFree(d_Q); 445 | cudaFree(d_V); 446 | cudaFree(d_O); 447 | free(h_Q); 448 | free(h_K); 449 | free(h_V); 450 | free(h_O); 451 | } 452 | 453 | int main() { 454 | test_attention(); 455 | 456 | return 0; 457 | } 458 | -------------------------------------------------------------------------------- /cutlass_cute_tutorial_zh.md: -------------------------------------------------------------------------------- 1 | # 用cutlass cute实现flash attention 2 | 3 | flash attention自顶向下(虽然我学cutlass是自底向上学的但是感觉快速上手应该自顶向下学)。因为有了cutlass cute用户就可以方便的实现一些功能了, 即一些cuda编程的范式: 4 | 5 | - cuda程序范式: global mem -> share mem -> reg -> compute 6 | * block tiling: 7 | + aka 复用smem, gmem -> smem的拷贝 8 | * thread tiling: 9 | + aka 复用reg, smem -> reg的拷贝 10 | * 合并访存, 向量访存: 11 | + aka 向量指令, LDSM, ldmatrix指令 12 | * warp divergent线程束分化 13 | + aka warp负载均衡, 同理流水线气泡问题 14 | * bank conflict冲突消解: swizzle 15 | + aka 利用内存的多路通道 16 | * double buffering 17 | + aka 加载和计算的流水线 18 | * ... 
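
Before moving on, a bare-CUDA sketch of the first two bullets above may help. The kernel below is not taken from this repository; it is a minimal shared-memory tiled GEMM, assuming `M`, `N`, `K` are multiples of `TILE` and a `(TILE, TILE)` thread block per output tile, that shows the `global mem -> share mem -> reg -> compute` flow the list refers to. The remaining bullets (vectorized/async loads, swizzle, double buffering) are exactly what cute's `Copy_Atom` / `TiledMMA` abstractions layer on top of this pattern.

```cpp
#include <cuda_runtime.h>

constexpr int TILE = 16;

// C[M, N] = A[M, K] @ B[K, N], row-major, launched as
//   dim3 grid(N / TILE, M / TILE), block(TILE, TILE);
__global__ void tiled_gemm(const float *A, const float *B, float *C,
                           int M, int N, int K) {
  // block tiling: the thread block streams TILE-wide slices of A and B
  // through shared memory and reuses each loaded element TILE times
  __shared__ float sA[TILE][TILE];
  __shared__ float sB[TILE][TILE];

  int row = blockIdx.y * TILE + threadIdx.y;
  int col = blockIdx.x * TILE + threadIdx.x;

  // the accumulator lives in a register (a 1x1 thread tile for simplicity)
  float acc = 0.f;

  for (int k0 = 0; k0 < K; k0 += TILE) {
    // gmem -> smem: each thread copies one element of each operand tile
    sA[threadIdx.y][threadIdx.x] = A[row * K + k0 + threadIdx.x];
    sB[threadIdx.y][threadIdx.x] = B[(k0 + threadIdx.y) * N + col];
    __syncthreads();

    // smem -> reg -> compute
    for (int kk = 0; kk < TILE; ++kk) {
      acc += sA[threadIdx.y][kk] * sB[kk][threadIdx.x];
    }
    __syncthreads();
  }

  C[row * N + col] = acc;
}
```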
19 | 20 | 需要自底向上学的朋友推荐看[reed哥的系列教程](https://www.zhihu.com/people/reed-84-49) 21 | 22 | 23 | ## Acknowledge 24 | 25 | - 直接抄的flash attention的代码,但是从0写一遍抄一遍的 26 | - 排雷了不少坑 27 | - 简化了大量fa的工程考虑,只保留核心代码 28 | - 纯cuda,不考虑pybind版本可以看[standalone文件夹](https://github.com/66RING/tiny-flash-attention/tree/main/flash_attention_cutlass/standalone_src) 29 | - 太久没填坑可以直接makefile开学 30 | 31 | ## flash attention速通 32 | 33 | TODO: 简单描述一下flash attention的本质: flash attention three easy pieces 34 | 35 | - online safe softmax 36 | - 两个gemm的融合 37 | - rescale的数学原理 38 | 39 | 40 | ## 自顶向下cute flash attention 41 | 42 | 在不考虑使用cutlass的情况下, 纯cuda应该怎么写高性能算子: 43 | 44 | 1. 多维block tiling: 45 | - 把数据从global memory拷贝到shared memory 46 | - 复用smem中的数据, 减少访问gmem的此时 47 | 2. 多维thread tiling 48 | - 把数据从shared memory拷贝到global memory 49 | - 复用寄存器中的数据 50 | 3. 进一步优化 51 | 4. 使用向量指令异步加载 52 | - LDSM 53 | - ldmatrix 54 | 5. 合并访存 55 | 6. bank conflict冲突消解 56 | 7. 传算交叠流水线: 一边gmem -> smem拷贝一边做reg的gemm计算 57 | 58 | 而cutlass cute则把原本需要手写的thread协同工作的代码抽象封装好了, 如需要协同做拷贝时可以`make_tiled_copy`创建一个拷贝对象, 需要协同计算时可以用`TiledMMA`创建mma(matrix multiply accumulate)对象来做计算。 59 | 60 | **只需要看懂mma布局就知道thread间如何协同的**, 后面[基础设施](#基础设施)章节会介绍 61 | 62 | 63 | ### Terms 名词解释 64 | 65 | - 命名习惯: `tQgQ` 66 | * 看到cute的变量名可能一头雾水, 所以有必要解释一下 67 | * 如`auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0))`, `t`(to)表示是给什么用的, 这里只是抽象了一层还是Q本身所以直接用tQ。`g`表示该变量的位置在global memory中 68 | * 如`tSrQ`, `tSrK`表示是给attention **S**core计算使用的, 寄存器(reg)中的Q, K 69 | * 如`tOrVt`表示是给最终output用的, 寄存器中的转置过了的V 70 | - MNK矩阵乘法表述法 71 | * 两个矩阵相乘需要至少一个维度相同, K就表示这个相同的维度是多少 72 | * `A[M, K] @ B[N, K]` 73 | - MMA(matrix multiply accumulate) 74 | * 简单的说就是用于表示thread tiling的规模, 即一个thread block中用多少个thread怎么计算, cute会抽象成一个个mma对象 75 | - MMA描述法: 描述底层执行`D = AB + C`要使用的指令, 用户可以根据需要指定 76 | * 描述方法: DABC + MNK 77 | - DABC: 描述了寄存器类型, 如`SM75_16x8x8_F32F16F16F32_TN`中`F32F16F16F32`就是DABC描述。表示DABC寄存器分别是`F32`, `F16`, `F16`, `F32` 78 | - MNK: 描述了矩阵乘法的规模, 如`SM75_16x8x8_F32F16F16F32_TN`中`16x8x8`就表示`D[M, N] = A[M, K] * B[N, K] + C[M, N]` 79 | - Tiled_MMA: 描述多个MMA_Atom如何协作来完成一个大任务 80 | * AtomLayoutMNK: Tile内在MNK方向上重复几次Atom, **通过多线程重复** 81 | * ValueLayoutMNK: Atom内在MNK方向上重复几次计算, **单线程内重复计算** 82 | - BlockM 83 | * Q的分块计算的粒度 84 | - BlockN 85 | * KV的分块计算的粒度 86 | 87 | 88 | ### 基础设施 89 | 90 | - 查看MMA布局 91 | 92 | 使用这个[mma布局打印脚本](https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15)可以打印, 使用方法如下: 修改不同mma指令`SM80_16x8x16_F32F16F16F32_TN`来测试。 93 | 94 | ```cpp 95 | { 96 | auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{}, 97 | Layout>{}, // AtomLayoutMNK 98 | Layout>{} // ValLayoutMNK 99 | ); 100 | print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma); 101 | } 102 | ``` 103 | 104 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/mma.webp) 105 | 106 | 图片含义:T0, T1...表示thread,T0内V0, V1表示thread T0所负责的数据 107 | 108 | - 打印tensor 109 | 110 | 直接使用cute提供的`print_tensor`, `print_layout`可以在命令行打印出tensor数据, 方便调试。e.g. 
111 | 112 | ```cpp 113 | // Convert a C pointer into cutlass Tensor 114 | // with info like shape (M, K) and stride (K, 1) 115 | const int M = 4; 116 | const int K = 8; 117 | 118 | Tensor A = make_tensor(c_host_ptr, make_shape(M, K), make_stride(K, 1)); 119 | cute::print_tensor(A); 120 | cute::print_layout(A.layout()); 121 | 122 | /* 123 | ptr[32b](0x7ffe79dcbbe0) o (4,8):(8,1): 124 | 0 1 2 3 4 5 6 7 125 | 8 9 10 11 12 13 14 15 126 | 16 17 18 19 20 21 22 23 127 | 24 25 26 27 28 29 30 31 128 | (4,8):(8,1) 129 | 0 1 2 3 4 5 6 7 130 | +----+----+----+----+----+----+----+----+ 131 | 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 132 | +----+----+----+----+----+----+----+----+ 133 | 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 134 | +----+----+----+----+----+----+----+----+ 135 | 2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 136 | +----+----+----+----+----+----+----+----+ 137 | 3 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 138 | +----+----+----+----+----+----+----+----+ 139 | */ 140 | ``` 141 | 142 | 使用`local_tile`打印一个tile(一个tensor切片) 143 | 144 | ```cpp 145 | cute::print_tensor(A); 146 | auto A00 = local_tile(A, make_tile(2, 2), make_coord(0, 0)); 147 | auto A01 = local_tile(A, make_tile(2, 2), make_coord(0, 1)); 148 | auto A10 = local_tile(A, make_tile(2, 2), make_coord(1, 0)); 149 | cute::print_tensor(A00); 150 | cute::print_tensor(A01); 151 | cute::print_tensor(A10); 152 | 153 | /* 154 | 155 | cute::print_tensor(A); 156 | ptr[32b](0x7ffc3fe94680) o (4,8):(1,4): 157 | 0 4 8 12 16 20 24 28 158 | 1 5 9 13 17 21 25 29 159 | 2 6 10 14 18 22 26 30 160 | 3 7 11 15 19 23 27 31 161 | 162 | cute::print_tensor(A00); 163 | ptr[32b](0x7ffc3fe94680) o (2,2):(1,4): 164 | 0 4 165 | 1 5 166 | 167 | cute::print_tensor(A01); 168 | ptr[32b](0x7ffc3fe946a0) o (2,2):(1,4): 169 | 8 12 170 | 9 13 171 | 172 | cute::print_tensor(A10); 173 | ptr[32b](0x7ffc3fe94688) o (2,2):(1,4): 174 | 2 6 175 | 3 7 176 | 177 | */ 178 | 179 | ``` 180 | 181 | 182 | ### attention计算的线程模型 183 | 184 | 单线程的attention计算belike: `q[seqlen, headdim] @ k[seqlen, headdim].T @ v[seqlen, headdim]` 185 | 186 | 而多线性的attention计算只需要从q的维度切分(想象成自回归场景下, 一次计算一个token的attention, 这里是并行的计算多个"单"query的attention),每个thread负责BlockM个token的single head attention计算。即 187 | 188 | 如果输入的形状为`[bs, head, seqlen, headdim]`则总线程数为`bs x head x seqlen/BlockM`, 每个thread计算`[BlockM, headdim]`的query attention计算。在bs x head维度和seqlen维度都并行。 189 | 190 | 对应到每个独立的thread block上也是同理, 开辟`bs x head x seqlen/BlockM`个独立的线程块进行多个token的并行计算。 191 | 192 | ```cpp 193 | dim3 grid(ceil_div(params.seqlen, BlockM), params.bs * params.head, 1); 194 | ``` 195 | 196 | TODO: 示意图 197 | 198 | ### 二维block tiling 199 | 200 | flash attention 2的计算流程如下图所示, Q按inner loop顺序分别和K, V分开进行计算得到partial sum, 最后将partial sum累加得到和Q形状一样的输出。伪码描述为(先不用考虑online softmax和rescale的原理) 201 | 202 | ```python 203 | flash_attention_2(): 204 | # outter loop 205 | parallel do q[NUM_BLOCK_M]: 206 | # inner loop 207 | for i in range(NUM_BLOCK_N): 208 | qk = q @ k[i].T 209 | score = online_softmax(qk) 210 | out += score @ v[i] 211 | rescale(out) 212 | ``` 213 | 214 | 你可能发现outter loop和inner loop和流传甚广的经典的flash attention那张三角形的图不一样。这是因为那张图的flash attention 1时期的实现。 215 | 216 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/flash_attention2.png) 217 | 218 | 利用cute的api可以快速制造q, k, v分块: 219 | 220 | 1. 用`make_tensor()`把裸指针封装成tensor方便后续操作 221 | 2. 使用`local_tile(tensor, tile, coord)`从tensor中取出一组/一个分块 222 | 3. 
创建`Copy_Atom`拷贝对象实现global memory到shared memory的数据拷贝, 简单易用的多维block tiling 223 | 224 | 首先使用`make_tensor`API可以把传入的裸指针转换成更方便使用的Tensor。这里把完整`seqlen x dim`的QKV对象创建了出来,方便后面使用cute的API做`q_slice[i++]`之类的操作。不用担心`make_tensor`会产生额外的开销, 因为它不会。 225 | 226 | ```cpp 227 | // dim3 grid(ceil_div(params.seqlen, BlockM), params.bs * params.head, 1); 228 | 229 | const int m_block = blockIdx.x; 230 | const int bs_head_offset = blockIdx.y * params.seqlen * params.dim; 231 | 232 | Tensor Q = make_tensor( 233 | make_gmem_ptr(reinterpret_cast(params.q_ptr) + bs_head_offset), 234 | make_shape(params.seqlen, params.dim), 235 | make_stride(params.dim, Int<1>{})); 236 | Tensor K = make_tensor( 237 | make_gmem_ptr(reinterpret_cast(params.k_ptr) + bs_head_offset), 238 | make_shape(params.seqlen, params.dim), 239 | make_stride(params.dim, Int<1>{})); 240 | Tensor V = make_tensor( 241 | make_gmem_ptr(reinterpret_cast(params.v_ptr) + bs_head_offset), 242 | make_shape(params.seqlen, params.dim), 243 | make_stride(params.dim, Int<1>{})); 244 | ``` 245 | 246 | 根据block id加载thread block对应的qkv分块。`local_tile(tensor, tile, coord)`可以把tensor抽象成由多个tile组成的数组(可以多多维), 然后使用coord去索引取出需要的部分。这里取出了当前thread block负责的Q分块,并取出第一个kv分块做后续"传算交叠流水线"的prefill. 247 | 248 | 因为这里Q的shape是`seqlen, kHeadDim`, 所以拆分成多个`[kBlockM, kHeadDim]`的块后可索引的coord为`[seqlen/kBlockM, kHeadDim/kHeadDim]`。取出`[m_block, _]`, 相当于python中的`[m_block, :]`这样的索引方式, 其中`m_block`索引维度的会被squeeze, 而`_`索引的维度会保留。所以最终的shape为`(kBlockM, kHeadDim, num_tile_n=1)` 249 | 250 | ```cpp 251 | // 加载Q, K, V分块 252 | // (kBlockM, kHeadDim, num_tile_n) 253 | Tensor gQ = local_tile(Q, make_tile(Int{}, Int{}), make_coord(m_block, _)); 254 | 255 | // (kBlockN, kHeadDim, num_tile_n) 256 | // NOTE: loading流水线, 初次加载所需K, V 257 | Tensor gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(0, _)); 258 | Tensor gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(0, _)); 259 | ``` 260 | 261 | **将数据从global memory拷贝到shared memory来做多维的block tiling**: 定义从global memory到share memory拷贝的对象, 这样可以减少用户直接使用gpu指令。具体拷贝对象怎么构造后续再说, 简单的说就是使用一个config来配置用什么方法拷贝(异步的, 向量的)。 262 | 263 | ```cpp 264 | // Construct SMEM tensors. 265 | Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); 266 | Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); 267 | Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); 268 | // Tensor for V Transpose; used in GEMM-II. 269 | Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); 270 | Tensor sVtNoSwizzle = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVtNoSwizzle{}); 271 | 272 | // NOTE: 定义gmem -> smem拷贝的src, dst 273 | Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0)); 274 | Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); 275 | Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 276 | Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); 277 | Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 278 | Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); 279 | ``` 280 | 281 | 其中, `gmem_thr_copy_QKV.partition_S()`创建拷贝的源地址对象, `gmem_thr_copy_QKV.partition_D()`创建拷贝的目标地址对象。因为gQ我们在创建分块时第二个维度用满了, 所以`make_coord(m_block, _)`提取出来也只有一个元素, 直接用`0`索引掉。 282 | 283 | ``` 284 | // tQgQ: tQ: 用于(t)表示/计算Q. gQ: 是global memory上的数据 285 | // tQsQ: tQ: 用于(t)表示/计算Q. 
sQ: 是shared memory上的数据 286 | ``` 287 | 288 | 然后使用API即可实现一个多维数据的拷贝。 289 | 290 | ```cpp 291 | // NOTE: gmem_tiled_copy_QKV为cute抽象出来的拷贝对象Copy_Atom, 表示用一组thread来做拷贝 292 | cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ); 293 | // 开始执行异步拷贝 294 | cute::cp_async_fence(); 295 | ``` 296 | 297 | 具体`gmem_thr_copy_QKV`拷贝对象的构造方法后面再说, 只需要传入一个异步拷贝的参数和规模布局即可用上向量指令做异步拷贝。 298 | 299 | > 这是不是比手写gpu指令的block tiling各种拷贝简单多了: 300 | 301 | 302 | ### 二维thread tiling 303 | 304 | 本章节开始进入inner loop部分 305 | 306 | ```python 307 | flash_attention_2(): 308 | # outter loop 309 | parallel do q[NUM_BLOCK_M]: 310 | # inner loop 311 | for i in range(NUM_BLOCK_N): 312 | qk = q @ k[i].T 313 | score = online_softmax(qk) 314 | out += score @ v[i] 315 | rescale(out) 316 | ``` 317 | 318 | 整体流程如下 319 | 320 | 1. pipeline prefill: load(q), load(k[0]) 321 | 2. pipeline start 322 | 3. async_load(next(v)) && compute q @ k.T 323 | 4. softmax(qk) 324 | 5. async_load(next(k)) && compute qk @ v 325 | 6. pipeline finish 326 | 7. rescale 327 | 328 | 其中做gemm计算时都会从smem拷贝多维的数据到寄存器中做一个thread tiling。thread tiling可以复用已经拷贝到寄存器的数据,减少smem到reg拷贝的次数。如下图所示, 当gemm计算第0行时, BX0和A0X计算完成后, BX1可以直接利用已经在寄存器的A0X而不用再次做smem到reg的加载。 329 | 330 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/thread_tiling.png) 331 | 332 | 从gemm的角度出发看多维thread tiling的实现。使用`cute::copy`把smem中的数据`tCsA`拷贝到寄存器中`tCrA`后直接使用`cute::gemm`做多维thread tiling的gemm计算。具体thread tiling的布局通过可以通过[打印mma](#基础设施)查看。 333 | 334 | ```cpp 335 | template 339 | inline __device__ void gemm_smem(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, 340 | Tensor4 const& tCsB, TiledMma tiled_mma, 341 | TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 342 | ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { 343 | // NOTE: 构造smem -> reg拷贝的目的地址寄存器对象 344 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 345 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 346 | 347 | // NOTE: s -> reg 348 | cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); 349 | cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); 350 | #pragma unroll 351 | for (int i = 0; i < size<2>(tCrA); ++i) { 352 | if (i < size<2>(tCrA) - 1) { 353 | cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); 354 | cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); 355 | } 356 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 357 | } 358 | } 359 | ``` 360 | 361 | for循环前先做一次`cute::copy`是为了构造传算交叠(communication compute overlap)的流水线。即做smem->reg拷贝的同时做gemm。 362 | 363 | 回到cutlass flash attention的代码。使用cute提供的API构造gemm需要的寄存器对象。TODO: 具体`SmemCopyAtom`拷贝对象的构造方法后面再说, 只需要传入一个异步拷贝的参数和规模布局即可。 364 | 365 | 使用`partition_fragment_A`, `partition_fragment_B`, `partition_fragment_C`创建寄存器对象, 准备做thread tiling: 把数据从smem拷贝到reg, 并利用reg中的数据做矩阵乘法。 366 | 367 | ```cpp 368 | // NOTE: 定义smem -> reg拷贝的dst 369 | // partition_fragment与partition类似, 只是返回的是寄存器表示 370 | Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) 371 | Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) 372 | Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) 373 | // 创建输出的累加器accumulator output 374 | Tensor rAccOut = partition_fragment_C(tiled_mma, Shape, Int>{}); 375 | 376 | // NOTE: 准备拷贝Q, K, V到smem的copy对象 377 | 378 | // 创建smem -> reg的拷贝对象 379 | auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); 380 | // 
根据thread id找到当前thread负责的部分 381 | auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); 382 | // 用partition_S创建smem -> reg的源地址对象 383 | Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); 384 | ... 385 | ``` 386 | 387 | inner loop部分代码如下。其中, 创建`auto rAccScore = partition_fragment_C()`来**融合两个gemm**: `score = q@k.T`的gemm和`out = score @ v`的gemm。 388 | 389 | 需要注意**融合两个gemm的坑点**, 因为要融合两个gemm, gemm-I的输出`score = q@k.T`要作为第二个gemm-II的输入`out = score @ v`, 所以**gemm-I的输出C layout需要和gemm-II的输入A layout一致**才能直接使用。通过打印mma指令发现`SM80_16x8x16_F32F16F16F32_TN`就符合这种要求。 390 | 391 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/mma.webp) 392 | 393 | [ColfaxResearch的实现](https://github.com/ColfaxResearch/cutlass-kernels/blob/c796d779c9991213252e9f0a07e5516c8d829e3f/src/fmha/fmha_forward.cu#L114)似乎不用考虑这点, 用`rs_op_selector`和`ss_op_selector`两个API就把MMA配置好了。如果有人知道是怎么回事pls let me know. 394 | 395 | 396 | ```cpp 397 | /* 398 | flash_attention_2(): 399 | # outter loop 400 | parallel do q[NUM_BLOCK_M]: 401 | # inner loop 402 | for i in range(NUM_BLOCK_N): 403 | qk = q @ k[i].T 404 | score = online_softmax(qk) 405 | out += score @ v[i] 406 | rescale(out) 407 | */ 408 | for (int nbi = n_block_min; nbi < n_block_max; nbi++) { 409 | auto rAccScore = partition_fragment_C(tiled_mma, make_shape(Int{}, Int{})); 410 | 411 | clear(rAccScore); 412 | 413 | // 等待Q, K的gmem -> smem拷贝完成, 即Q, K就绪 414 | // wait<0>表示等待还剩0个未完成 415 | flash::cp_async_wait<0>(); 416 | __syncthreads(); 417 | 418 | // gemm的同时异步加载V 419 | gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 420 | tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 421 | // 异步加载V到smem 422 | flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV); 423 | // 发起异步拷贝 424 | cute::cp_async_fence(); 425 | 426 | // O = Q@K.T 427 | // NOTE: 加载smem中的数据到reg再做gemm, **加载期间执行retile** 428 | flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, 429 | smem_thr_copy_Q, smem_thr_copy_K 430 | ); 431 | 432 | Tensor scores = make_tensor(rAccScore.data(), flash::convert_layout_acc_rowcol(rAccScore.layout())); 433 | 434 | // NOTE: 2. mask within N BLOCKs 435 | if (Is_causal == true && nbi * kBlockN >= seqlen_start) { 436 | flash::mask_within_nblock(scores, m_block, nbi); 437 | } 438 | 439 | // NOTE: 等待V加载完成, 为下个K加载准备初始状态 440 | flash::cp_async_wait<0>(); 441 | __syncthreads(); 442 | 443 | // advance K 444 | if (nbi != n_block_max - 1) { 445 | gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 446 | tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 447 | flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK); 448 | cute::cp_async_fence(); 449 | } 450 | 451 | // 计算softmax 452 | // NOTE: rAccOut记录softmax后所有的分子 453 | nbi == 0 ? 
flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale) : 454 | flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale); 455 | 456 | // 实际执行QK @ V 457 | // (score AKA rAccScore): QK[M, N] @ V[N, dim] 458 | // NOTE: DABC: F32F16F16F32, convert D type(F32) to A type(F16) 459 | // TODO: convert_type目前写死 460 | Tensor rP = flash::convert_type_f32_to_f16(rAccScore); 461 | // NOTE: Convert from layout C to layout A 462 | Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); 463 | 464 | flash::gemm_A_in_regs(rAccOut, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); 465 | } 466 | ``` 467 | 468 | 伪码和代码的对应情况如下: 469 | 470 | ```python 471 | # inner loop 472 | for nbi in range(NUM_BLOCK_N): 473 | # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 474 | qk = q @ k[nbi].T # flash::gemm_smem() 475 | score = online_softmax(qk) # softmax_rescale_o() 476 | # v[nbi]: gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 477 | out += score @ v[nbi] # gemm_A_in_regs() 478 | ``` 479 | 480 | ### 传算交叠流水线 481 | 482 | - 异步拷贝 483 | 484 | 创建gmem到smem的拷贝对象时使用`SM80_CP_ASYNC_CACHEGLOBAL`指令来创建异步拷贝的Copy atom对象。 485 | 486 | ```cpp 487 | using Gmem_copy_struct = std::conditional_t< 488 | Has_cp_async, 489 | SM80_CP_ASYNC_CACHEGLOBAL, 490 | DefaultCopy 491 | >; 492 | using GmemTiledCopyQKV = decltype( 493 | make_tiled_copy(Copy_Atom{}, 494 | GmemLayoutAtom{}, 495 | Layout>{})); // Val layout, 8 vals per read 496 | ``` 497 | 498 | 499 | - 流水线 500 | 501 | 伪码描述如下, 计算q@k时可以加载v, 计算qk@v时加载下一次迭代需要的k。目前只是用double buffering的方式预取1个kv. 如果每次预取多个kv还需要考虑smem大小对性能的影响。 502 | 503 | ```python 504 | # inner loop 505 | async_load(k[0]) # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 506 | for nbi in range(NUM_BLOCK_N): 507 | # 加载v的同时计算q@k 508 | async_load(v[nbi]) # v[nbi]: gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 509 | qk = q @ k[nbi].T # flash::gemm_smem() 510 | score = online_softmax(qk) # softmax_rescale_o() 511 | 512 | # 计算qk @ v的同时加载下一次迭代需要的k 513 | async_load(k[nbi]) # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 514 | out += score @ v[nbi] # gemm_A_in_regs() 515 | ``` 516 | 517 | 在cutlass cute中使用也很简单, 构造好异步拷贝对象后发起异步拷贝即可。 518 | 519 | ```cpp 520 | // gemm的同时异步加载V 521 | gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 522 | tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 523 | // 异步加载V到smem 524 | flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV); 525 | // 发起异步拷贝 526 | cute::cp_async_fence(); 527 | 528 | // NOTE: 拷贝的同时执行gemm 529 | // O = Q@K.T 530 | // NOTE: 加载smem中的数据到reg再做gemm, **加载期间执行retile** 531 | flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, 532 | smem_thr_copy_Q, smem_thr_copy_K 533 | ); 534 | ``` 535 | 536 | 537 | ### 其他细节 538 | 539 | - causal模式的提前返回 540 | * block间早退 541 | * **block内mask**: thread在mma中的定位 542 | - 结果拷贝回global memory返回 543 | * 同样利用smem, 先从reg拷贝到smem再从smem拷贝到gmem 544 | * 这样可以用更大的位宽 545 | - online safe softmax 546 | - pybind和模板展开 547 | * 官方实现用了很多模板,本质就是1. 枚举所有可能的分块策略 2. 每个config写一个文件加速编译 3. 
每个模板写个文件微调最佳config 548 | * python中接入cpp代码可以看这个[仓库](https://github.com/66RING/pytorch-cuda-binding-tutorial) 549 | 550 | 后面再展开补充,感兴趣的朋友可以先看源码注释。 551 | 552 | ### 其他优化 553 | 554 | - bank conflict重复避免 555 | * swizzle 556 | * cutlass cute封装好了用swizzle解决bank conflict, 在创建拷贝对象时使用即可 557 | - 转置优化 558 | * 拷贝时直接拷贝到转换后的目标地址, 从而不必开辟新的空间 559 | * 创建拷贝对象时, 配置布局时把dst的布局转置掉即可 560 | - [高性能的reduce实现](https://developer.nvidia.com/blog/faster-parallel-reductions-kepler/) 561 | * 优化线程束分化问题(warp divergent) 562 | 563 | TODO: 细节展开 564 | 565 | ### 稍微一点自底向上 566 | 567 | > 深入的自底向上可以看[reed哥的系列教程](https://www.zhihu.com/people/reed-84-49) 568 | 569 | TODO: 挑选几个重要的 570 | 571 | 572 | ### 主要坑点 573 | 574 | - 两个gemm的融合的layout问题: gemm-I, gemm-II 575 | * 输入输出的布局比较讲究: gemm-I的输出C layout要和gemm-II的输入A layout一致 576 | 577 | 578 | 579 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | -------------------------------------------------------------------------------- /cutlass_cute_tutorial_en.md: -------------------------------------------------------------------------------- 1 | 2 | # Reproduce Flash Attention with Cutlass Cute 3 | 4 | Flash attention can be learned from the top-down (although I learned Cutlass from the bottom-up, but I feel that it should be learned from the top-down for quick mastery). With the help of Cutlass Cute, users can conveniently implement some functionalities, that is, some paradigms of CUDA programming: 5 | 6 | - CUDA Programming Paradigms: global mem -> share mem -> reg -> compute 7 | * Block Tiling: 8 | + aka reusing shared memory (smem), copying from global memory (gmem) to shared memory (smem) 9 | * Thread Tiling: 10 | + aka reusing registers (reg) and shared memory (smem), copying from shared memory (smem) to registers (reg) 11 | * Merging Memory Accesses, Vector Memory Accesses: 12 | + aka vector instructions, LDSM (Load Shared Memory), ldmatrix instructions 13 | * Warp Divergence: 14 | + aka warp load balancing, similarly to pipeline bubble issues 15 | * Bank Conflict Resolution: Swizzle 16 | + aka utilizing the multi-channel nature of memory 17 | * Double Buffering 18 | + aka the pipeline of loading and computing 19 | * ... 20 | 21 | For those who need to learn from the bottom-up, I recommend reading the [offical cute tutorial](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/00_quickstart.md) 22 | 23 | ## Acknowledge 24 | 25 | - Base on offical flash attention, but rewrote it from scratch, eliminating many hidden tricks. 26 | - I simplified a lot of the engineering considerations of the original code, keeping only the core parts. 27 | - purely in CUDA, without considering the Pybind version: You can find it in the [standalone folder](https://link.zhihu.com/?target=https%3A//github.com/66RING/tiny-flash-attention/tree/main/flash_attention_cutlass/standalone_src). 28 | - It's been a while since I've worked on it, so you can directly use the Makefile to start learning. 29 | 30 | 31 | ## flash attention three easy pieces 32 | 33 | TODO: Briefly describe the essence of flash attention: flash attention three easy pieces 34 | 35 | - Online safe softmax 36 | - Fusion of two GEMMs 37 | - Mathematical principles of rescaling 38 | 39 | 40 | ## Top-down Cute Flash Attention 41 | 42 | When considering writing high-performance operators in pure CUDA without using Cutlass, here's how to approach it: 43 | 44 | 1. Multi-dimensional block tiling: 45 | - Copy data from global memory to shared memory. 46 | - Reuse data in shared memory to reduce global memory accesses. 47 | 2. 
Multi-dimensional thread tiling: 48 | - Copy data from shared memory to registers. 49 | - Reuse data in registers. 50 | 3. Further optimizations: 51 | 4. Use vector instruction asynchronous loading: 52 | - LDSM 53 | - ldmatrix 54 | 5. Merge memory accesses. 55 | 6. Resolve bank conflicts. 56 | 7. Communication-compute overlap pipelining: copy data from global memory to shared memory while performing register-based GEMM calculations. 57 | 58 | However, Cutlass Cute abstracts and encapsulates the code that would otherwise have to be handwritten for thread cooperation. For example, when threads need to cooperate on a copy, you can use make_tiled_copy to create a copy object, and when they need to cooperate on a computation, you can use `TiledMMA` to create MMA (matrix multiply accumulate) objects for the calculation. 59 | 60 | **Understanding the MMA layout is sufficient to understand how threads cooperate.** The following [Tools](#tools) section will introduce this. 61 | 62 | 63 | ### Term Explanations 64 | 65 | - **Naming Convention**: tQgQ 66 | * The cute variable names can be puzzling at first, so they are worth explaining. 67 | * For example, in `auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0))`, `t` (to) indicates what the data is used for; here it is just a thin abstraction over Q itself, so it is named `tQ` directly. `g` indicates that the variable lives in global memory. 68 | * For instance, `tSrQ`, `tSrK` are Q and K held in registers (r) and used to compute the Score (S). 69 | * For example, `tOrVt` is the transposed V held in registers (r) and used to compute the final Output (O). 70 | - MNK Matrix Multiplication Notation 71 | * The two matrices must share at least one dimension; K denotes this common dimension. 72 | * `A[M, K] @ B[N, K]` 73 | - MMA (Matrix Multiply Accumulate) 74 | * Simply put, it represents the scale of thread tiling, i.e., how many threads in a thread block take part and how they compute. Cute abstracts this as individual MMA objects. 75 | - MMA Description: Describes the instructions used for the underlying execution of D = AB + C; users can specify them as needed. 76 | * Description format: DABC + MNK 77 | * DABC: Describes the register types. For example, in `SM75_16x8x8_F32F16F16F32_TN`, F32F16F16F32 is the DABC description, indicating that the D, A, B, C registers are F32, F16, F16, F32. 78 | * MNK: Describes the scale of the matrix multiplication. For example, in `SM75_16x8x8_F32F16F16F32_TN`, 16x8x8 indicates `D[M, N] = A[M, K] * B[N, K] + C[M, N]`. 79 | - Tiled_MMA: Describes how multiple MMA_Atoms cooperate to complete a larger task. 80 | * AtomLayoutMNK: Repeats the Atom along the M, N, K directions within a tile; the repetition is carried out by multiple threads. 81 | * ValueLayoutMNK: Repeats the computation along the M, N, K directions within an Atom; the repetition is carried out within a single thread. 82 | - BlockM 83 | * Granularity of the blocked computation over Q. 84 | - BlockN 85 | * Granularity of the blocked computation over K and V. 86 | 87 | 88 | ### Tools 89 | 90 | - Print MMA Layout 91 | 92 | You can use this [MMA layout printing script](https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15) to print the layout. Usage is as follows: modify different MMA instructions like `SM80_16x8x16_F32F16F16F32_TN` for testing. 
93 | 94 | ```cpp 95 | { 96 | auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{}, 97 | Layout>{}, // AtomLayoutMNK 98 | Layout>{} // ValLayoutMNK 99 | ); 100 | print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma); 101 | } 102 | ``` 103 | 104 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/mma.webp) 105 | 106 | Meaning of the image: T0, T1... represents threads, V0, V1 within T0 represent the data that thread T0 is responsible for. 107 | 108 | - Printing Tensors 109 | 110 | You can directly use print_tensor and print_layout provided by Cute to print tensor data in the command line for debugging. For example: 111 | 112 | ```cpp 113 | // Convert a C pointer into cutlass Tensor 114 | // with info like shape (M, K) and stride (K, 1) 115 | const int M = 4; 116 | const int K = 8; 117 | 118 | Tensor A = make_tensor(c_host_ptr, make_shape(M, K), make_stride(K, 1)); 119 | cute::print_tensor(A); 120 | cute::print_layout(A.layout()); 121 | 122 | /* 123 | ptr[32b](0x7ffe79dcbbe0) o (4,8):(8,1): 124 | 0 1 2 3 4 5 6 7 125 | 8 9 10 11 12 13 14 15 126 | 16 17 18 19 20 21 22 23 127 | 24 25 26 27 28 29 30 31 128 | (4,8):(8,1) 129 | 0 1 2 3 4 5 6 7 130 | +----+----+----+----+----+----+----+----+ 131 | 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 132 | +----+----+----+----+----+----+----+----+ 133 | 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 134 | +----+----+----+----+----+----+----+----+ 135 | 2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 136 | +----+----+----+----+----+----+----+----+ 137 | 3 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 138 | +----+----+----+----+----+----+----+----+ 139 | */ 140 | ``` 141 | 142 | Use local_tile to print a tile (a slice of a tensor). 143 | 144 | ```cpp 145 | cute::print_tensor(A); 146 | auto A00 = local_tile(A, make_tile(2, 2), make_coord(0, 0)); 147 | auto A01 = local_tile(A, make_tile(2, 2), make_coord(0, 1)); 148 | auto A10 = local_tile(A, make_tile(2, 2), make_coord(1, 0)); 149 | cute::print_tensor(A00); 150 | cute::print_tensor(A01); 151 | cute::print_tensor(A10); 152 | 153 | /* 154 | 155 | cute::print_tensor(A); 156 | ptr[32b](0x7ffc3fe94680) o (4,8):(1,4): 157 | 0 4 8 12 16 20 24 28 158 | 1 5 9 13 17 21 25 29 159 | 2 6 10 14 18 22 26 30 160 | 3 7 11 15 19 23 27 31 161 | 162 | cute::print_tensor(A00); 163 | ptr[32b](0x7ffc3fe94680) o (2,2):(1,4): 164 | 0 4 165 | 1 5 166 | 167 | cute::print_tensor(A01); 168 | ptr[32b](0x7ffc3fe946a0) o (2,2):(1,4): 169 | 8 12 170 | 9 13 171 | 172 | cute::print_tensor(A10); 173 | ptr[32b](0x7ffc3fe94688) o (2,2):(1,4): 174 | 2 6 175 | 3 7 176 | 177 | */ 178 | 179 | ``` 180 | 181 | 182 | ### Thread Model for Flash Attention 183 | 184 | The single-threaded attention computation belike: `q[seqlen, headdim] @ k[seqlen, headdim].T @ v[seqlen, headdim]`. 185 | 186 | While multi-linear attention computation only requires slicing from the dimension of q (imagine in autoregressive scenarios, computing attention for one token at a time, here it's parallel computing for multiple "single" queries' attention), each thread is responsible for calculating single-head attention for BlockM tokens. That is, 187 | 188 | If the input shape is `[bs, head, seqlen, headdim]`, the total number of threads is `bs x head x seqlen/BlockM`, and each thread computes `[BlockM, headdim]` query attention calculation. This is parallel in both the bs x head dimension and the seqlen dimension. 
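
To make this concrete, here is a small sketch of the launch arithmetic (the sizes below are made up for illustration; the real values come from the kernel parameters):

```cpp
// Hypothetical sizes, for illustration only.
constexpr int kBlockM = 64;                                  // rows of Q owned by one thread block
const int bs = 2, head = 8, seqlen = 1024, headdim = 64;

// x: which kBlockM-chunk of the sequence, y: which (batch, head) pair
// (this mirrors the grid launch shown below).
const int num_m_blocks = (seqlen + kBlockM - 1) / kBlockM;   // ceil_div(seqlen, kBlockM) = 16
const int num_bs_head  = bs * head;                          // 16
// => 16 x 16 = 256 independent thread blocks, each producing a
//    [kBlockM, headdim] = [64, 64] slice of the attention output.
```
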
189 | 190 | Similarly, for each independent thread block, `bs x head x seqlen/BlockM` independent thread blocks are allocated to perform parallel computation for multiple tokens. 191 | 192 | ```cpp 193 | dim3 grid(ceil_div(params.seqlen, BlockM), params.bs * params.head, 1); 194 | ``` 195 | 196 | TODO: graph 197 | 198 | 199 | ### 2D Block Tiling 200 | 201 | The computation process of Flash Attention 2 is illustrated in the following diagram. Q is calculated separately with K and V in inner loop order to obtain partial sums. Finally, the partial sums are accumulated to get an output of the same shape as Q. The pseudocode description (without considering the principles of online softmax and rescale) is as follows. 202 | 203 | ```python 204 | flash_attention_2(): 205 | # outter loop 206 | parallel do q[NUM_BLOCK_M]: 207 | # inner loop 208 | for i in range(NUM_BLOCK_N): 209 | qk = q @ k[i].T 210 | score = online_softmax(qk) 211 | out += score @ v[i] 212 | rescale(out) 213 | ``` 214 | 215 | You may notice that the outer loop and inner loop are different from the widely circulated classic Flash Attention triangle diagram. This is because that diagram is from the Flash Attention 1 implementation. 216 | 217 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/flash_attention2.png) 218 | 219 | Using Cute's API, we can quickly create blocks for q, k, v: 220 | 221 | - Use `make_tensor()` to wrap raw pointers into tensors for easier subsequent operations. 222 | - Use `local_tile(tensor, tile, coord)` to extract a group/one block from the tensor. 223 | - Create a `Copy_Atom` copy object to implement data copying from global memory to shared memory, which provides simple and easy-to-use multi-dimensional block tiling. 224 | 225 | First, the `make_tensor` API is used to convert the passed raw pointer into a more convenient Tensor. Here, a complete `seqlen x dim` QKV object is created, making it convenient to use Cute's API for operations like `q_slice[i++]`. Don't worry about additional overhead from `make_tensor` because it doesn't create any. 226 | 227 | ```cpp 228 | // dim3 grid(ceil_div(params.seqlen, BlockM), params.bs * params.head, 1); 229 | 230 | const int m_block = blockIdx.x; 231 | const int bs_head_offset = blockIdx.y * params.seqlen * params.dim; 232 | 233 | Tensor Q = make_tensor( 234 | make_gmem_ptr(reinterpret_cast(params.q_ptr) + bs_head_offset), 235 | make_shape(params.seqlen, params.dim), 236 | make_stride(params.dim, Int<1>{})); 237 | Tensor K = make_tensor( 238 | make_gmem_ptr(reinterpret_cast(params.k_ptr) + bs_head_offset), 239 | make_shape(params.seqlen, params.dim), 240 | make_stride(params.dim, Int<1>{})); 241 | Tensor V = make_tensor( 242 | make_gmem_ptr(reinterpret_cast(params.v_ptr) + bs_head_offset), 243 | make_shape(params.seqlen, params.dim), 244 | make_stride(params.dim, Int<1>{})); 245 | ``` 246 | 247 | Load the QKV block corresponding to the thread block according to the block ID. `local_tile(tensor, tile, coord)` abstracts the tensor into an array composed of multiple tiles (in multiple dimensions), and then uses the coord to index and extract the required portion. Here, the Q block responsible for the current thread block is extracted, and the first KV block is extracted for subsequent "compute overlap pipelining" prefill. 
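
As a quick sanity check, in the spirit of the [Tools](#tools) section, the resulting tiles can be printed from thread 0 (the concrete numbers in the comments assume kBlockM = kBlockN = kHeadDim = 64, so they are illustrative only):

```cpp
if (cute::thread0()) {
    printf("gQ: "); cute::print(gQ.layout()); printf("\n");  // shape (kBlockM, kHeadDim, num_tile_n=1), e.g. (_64,_64,1)
    printf("gK: "); cute::print(gK.layout()); printf("\n");  // shape (kBlockN, kHeadDim, num_tile_n=1), e.g. (_64,_64,1)
}
```
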
248 | 249 | Since the shape of Q here is `seqlen`, `kHeadDim`, splitting it into multiple `[kBlockM, kHeadDim]` blocks allows indexing with `coord` as `[seqlen/kBlockM, kHeadDim/kHeadDim]`. Extracting `[m_block, _]` is equivalent to indexing like `[m_block, :]` in Python. Here, the dimension indexed by `m_block` will be squeezed, while the dimension indexed by _ will be retained. So, the final shape is `(kBlockM, kHeadDim, num_tile_n=1)`. 250 | 251 | ```cpp 252 | // load q, k, v block 253 | // (kBlockM, kHeadDim, num_tile_n) 254 | Tensor gQ = local_tile(Q, make_tile(Int{}, Int{}), make_coord(m_block, _)); 255 | 256 | // (kBlockN, kHeadDim, num_tile_n) 257 | // NOTE: compute commu overlap pipeline load first q, k 258 | Tensor gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(0, _)); 259 | Tensor gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(0, _)); 260 | ``` 261 | 262 | **Copying data from global memory to shared memory for multi-dimensional block tiling**: Define an object for copying from global memory to shared memory, which reduces the need for users to directly use GPU instructions. The construction of the copy object will be discussed later, but in simple terms, it's configured using a config to specify the method of copying (asynchronous, vectorized). 263 | 264 | ```cpp 265 | // Construct SMEM tensors. 266 | Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); 267 | Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); 268 | Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); 269 | // Tensor for V Transpose; used in GEMM-II. 270 | Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); 271 | Tensor sVtNoSwizzle = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVtNoSwizzle{}); 272 | 273 | // NOTE: define object of gmem -> smem copy, src, dst 274 | Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0)); 275 | Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); 276 | Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 277 | Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); 278 | Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 279 | Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); 280 | ``` 281 | 282 | In this process, `gmem_thr_copy_QKV.partition_S()` creates the source address object for copying, while `gmem_thr_copy_QKV.partition_D()` creates the destination address object. Since we've fully utilized the second dimension when creating the block for gQ, the extraction with `make_coord(m_block, _)` results in only one element, so we directly use `0` to index it. 283 | 284 | ``` 285 | // tQgQ: tQ: used for (t) calculating Q. gQ: data in global memory 286 | // tQsQ: tQ: used for (t) calculating Q. sQ: data in shared memory 287 | ``` 288 | 289 | Then, using the API, a multi-dimensional data copy can be achieved. 290 | 291 | ```cpp 292 | // NOTE: gmem_tiled_copy_QKV为cute抽象出来的拷贝对象Copy_Atom, 表示用一组thread来做拷贝 293 | cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ); 294 | // start async copy 295 | cute::cp_async_fence(); 296 | ``` 297 | 298 | The construction method for the `gmem_thr_copy_QKV` copy object will be discussed later. You only need to pass in an asynchronous copy parameter and the scale layout to use vector instructions for asynchronous copying. 
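
As a preview, here is a minimal sketch of how such a tiled copy object can be put together (the copy atom, thread layout, and value layout below are assumptions chosen for illustration; the actual definitions live in `kernel_traits.h` and appear again in the pipeline section):

```cpp
using Element = cute::half_t;
// Async gmem -> smem copy atom; one cp.async instruction moves 16 bytes (8 half_t values).
using GmemCopyAtom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;

// Tiled copy = copy atom x thread layout x value layout (both layouts assumed here).
auto gmem_tiled_copy_QKV = make_tiled_copy(
    GmemCopyAtom{},
    Layout<Shape<_16, _8>, Stride<_8, _1>>{},  // 16 x 8 = 128 cooperating threads
    Layout<Shape<_1, _8>>{});                  // each thread copies 8 contiguous values

// Each thread takes its slice; the partition_S / partition_D calls above then carve out
// that thread's portion of the gmem source and the smem destination.
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
```
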
299 | 300 | > much simpler than manually writing GPU instructions for block tiling and various copies 301 | 302 | 303 | ### 2D Thread tiling 304 | 305 | We are now entering the inner loop part of this section. 306 | 307 | ```python 308 | flash_attention_2(): 309 | # outter loop 310 | parallel do q[NUM_BLOCK_M]: 311 | # inner loop 312 | for i in range(NUM_BLOCK_N): 313 | qk = q @ k[i].T 314 | score = online_softmax(qk) 315 | out += score @ v[i] 316 | rescale(out) 317 | ``` 318 | 319 | The overall process is as follows: 320 | 321 | 1. pipeline prefill: load(q), load(k[0]) 322 | 2. pipeline start 323 | 3. async_load(next(v)) && compute q @ k.T 324 | 4. softmax(qk) 325 | 5. async_load(next(k)) && compute qk @ v 326 | 6. pipeline finish 327 | 7. rescale 328 | 329 | During the gemm calculation, multi-dimensional data is copied from shared memory to registers for thread tiling. Thread tiling allows reusing data already copied to registers, reducing the number of copies from shared memory to registers. As shown in the diagram below, when calculating the first row of the gemm, after BX0 and A0X calculations are completed, BX1 can directly use A0X already in registers without loading it again from shared memory to registers. 330 | 331 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/thread_tiling.png) 332 | 333 | Looking at the implementation of multi-dimensional thread tiling from the perspective of gemm, we use `cute::copy` to copy the data tCsA from shared memory to registers tCrA, and then directly use `cute::gemm` to perform gemm calculation with multi-dimensional thread tiling. The specific layout of thread tiling can be viewed through printing [mma](#tools). 334 | 335 | ```cpp 336 | template 340 | inline __device__ void gemm_smem(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, 341 | Tensor4 const& tCsB, TiledMma tiled_mma, 342 | TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 343 | ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { 344 | // NOTE: construct dst object of smem -> reg copy 345 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 346 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 347 | 348 | // NOTE: s -> reg 349 | cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); 350 | cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); 351 | #pragma unroll 352 | for (int i = 0; i < size<2>(tCrA); ++i) { 353 | if (i < size<2>(tCrA) - 1) { 354 | cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); 355 | cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); 356 | } 357 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 358 | } 359 | } 360 | ``` 361 | 362 | Before the for loop, we perform a `cute::copy` to construct a communication-compute overlap pipeline. This means doing smem->reg copy while performing gemm. 363 | 364 | Returning to the Cutlass Flash Attention code, we use the API provided by Cute to construct the register objects needed for gemm. TODO: The specific construction method for the SmemCopyAtom copy object will be discussed later, but you only need to pass in an asynchronous copy parameter and the scale layout. 
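
For orientation, here is a sketch of what the relevant pieces of `Kernel_traits` might look like (the specific atoms and layouts below are assumptions, not the repository's exact definitions): the smem -> reg copy is built from an `ldmatrix`-style atom plus a layout, so the loaded fragments already match the MMA operand layout.

```cpp
using Element = cute::half_t;

// Tiled MMA: the SM80 16x8x16 half-precision tensor core atom, repeated across 4 warps along M
// (AtomLayoutMNK) and expanded per thread along N (ValLayoutMNK) -- values assumed for illustration.
using TiledMma = decltype(make_tiled_mma(
    SM80_16x8x16_F32F16F16F32_TN{},
    Layout<Shape<_4, _1, _1>>{},    // AtomLayoutMNK
    Layout<Shape<_1, _2, _1>>{}));  // ValLayoutMNK

// smem -> reg copies use ldmatrix; the transposed variant feeds V in GEMM-II.
using SmemCopyAtom           = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;
```

With traits of this shape, `make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma)` in the snippet below yields per-thread copies whose register destinations line up with the MMA's A operand.
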
365 | 366 | Use `partition_fragment_A, partition_fragment_B, partition_fragment_C` to create register objects, preparing for thread tiling: copying data from shared memory to registers, and performing matrix multiplication using data in registers. 367 | 368 | ```cpp 369 | // NOTE: construct mma object in register 370 | // partition_fragment can create a object in register 371 | Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) 372 | Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) 373 | Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) 374 | // construct output accumulator 375 | Tensor rAccOut = partition_fragment_C(tiled_mma, Shape, Int>{}); 376 | 377 | // construct copy object of smem -> reg 378 | auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); 379 | // select thread work by thread id 380 | auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); 381 | // use partition_S to construct src object of Copy_Atom 382 | Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); 383 | ... 384 | ``` 385 | 386 | The inner loop code is as follows. Here, `auto rAccScore = partition_fragment_C()` is created to **fuse two gemms**: the gemm for `score = q@k.T` and the gemm for `out = score @ v`. 387 | 388 | 389 | It's important to note the **pitfalls of fusing two gemms**. Because we need to fuse two gemms, the output of gemm-I, `score = q@k.T`, needs to be used as the input of gemm-II, `out = score @ v`, **so the C layout of gemm-I's output needs to match the A layout of gemm-II's input in order to be directly used**. It's found through printing MMA instructions that `SM80_16x8x16_F32F16F16F32_TN` meets this requirement. 390 | 391 | ![](https://raw.githubusercontent.com/66RING/66RING/master/.github/images/Notes/universe/ml/cutlass_flash_attention_top_down/mma.webp) 392 | 393 | [ColfaxResearch's implementation](https://github.com/ColfaxResearch/cutlass-kernels/blob/c796d779c9991213252e9f0a07e5516c8d829e3f/src/fmha/fmha_forward.cu#L114) seems to handle this without considering this point, using `rs_op_selector` and `ss_op_selector` APIs to configure MMA. If someone knows how it works, please let me know. 394 | 395 | ```cpp 396 | /* 397 | flash_attention_2(): 398 | # outter loop 399 | parallel do q[NUM_BLOCK_M]: 400 | # inner loop 401 | for i in range(NUM_BLOCK_N): 402 | qk = q @ k[i].T 403 | score = online_softmax(qk) 404 | out += score @ v[i] 405 | rescale(out) 406 | */ 407 | for (int nbi = n_block_min; nbi < n_block_max; nbi++) { 408 | auto rAccScore = partition_fragment_C(tiled_mma, make_shape(Int{}, Int{})); 409 | 410 | clear(rAccScore); 411 | 412 | // 等待Q, K的gmem -> smem拷贝完成, 即Q, K就绪 413 | // wait<0>表示等待还剩0个未完成 414 | flash::cp_async_wait<0>(); 415 | __syncthreads(); 416 | 417 | // gemm的同时异步加载V 418 | gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 419 | tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 420 | // 异步加载V到smem 421 | flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV); 422 | // 发起异步拷贝 423 | cute::cp_async_fence(); 424 | 425 | // O = Q@K.T 426 | // NOTE: 加载smem中的数据到reg再做gemm, **加载期间执行retile** 427 | flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, 428 | smem_thr_copy_Q, smem_thr_copy_K 429 | ); 430 | 431 | Tensor scores = make_tensor(rAccScore.data(), flash::convert_layout_acc_rowcol(rAccScore.layout())); 432 | 433 | // NOTE: 2. 
mask within N BLOCKs 434 | if (Is_causal == true && nbi * kBlockN >= seqlen_start) { 435 | flash::mask_within_nblock(scores, m_block, nbi); 436 | } 437 | 438 | // NOTE: 等待V加载完成, 为下个K加载准备初始状态 439 | flash::cp_async_wait<0>(); 440 | __syncthreads(); 441 | 442 | // advance K 443 | if (nbi != n_block_max - 1) { 444 | gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 445 | tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0)); 446 | flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK); 447 | cute::cp_async_fence(); 448 | } 449 | 450 | // 计算softmax 451 | // NOTE: rAccOut记录softmax后所有的分子 452 | nbi == 0 ? flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale) : 453 | flash::softmax_rescale_o(scores, scores_max, scores_sum, rAccOut, params.softmax_scale); 454 | 455 | // 实际执行QK @ V 456 | // (score AKA rAccScore): QK[M, N] @ V[N, dim] 457 | // NOTE: DABC: F32F16F16F32, convert D type(F32) to A type(F16) 458 | // TODO: convert_type目前写死 459 | Tensor rP = flash::convert_type_f32_to_f16(rAccScore); 460 | // NOTE: Convert from layout C to layout A 461 | Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); 462 | 463 | flash::gemm_A_in_regs(rAccOut, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); 464 | } 465 | ``` 466 | 467 | The correspondence between pseudocode and code is as follows: 468 | 469 | ```python 470 | # inner loop 471 | for nbi in range(NUM_BLOCK_N): 472 | # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 473 | qk = q @ k[nbi].T # flash::gemm_smem() 474 | score = online_softmax(qk) # softmax_rescale_o() 475 | # v[nbi]: gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 476 | out += score @ v[nbi] # gemm_A_in_regs() 477 | ``` 478 | 479 | ### Communication Compute Overlap Pipeline 480 | 481 | - Asynchronous Copy 482 | 483 | When creating the copy object from global memory to shared memory, use the `SM80_CP_ASYNC_CACHEGLOBAL` instruction to create an asynchronous Copy_Atom object. 484 | 485 | ```cpp 486 | using Gmem_copy_struct = std::conditional_t< 487 | Has_cp_async, 488 | SM80_CP_ASYNC_CACHEGLOBAL, 489 | DefaultCopy 490 | >; 491 | using GmemTiledCopyQKV = decltype( 492 | make_tiled_copy(Copy_Atom{}, 493 | GmemLayoutAtom{}, 494 | Layout>{})); // Val layout, 8 vals per read 495 | ``` 496 | 497 | 498 | - Pipeline 499 | 500 | The pseudocode is as follows: when computing q@k, load v, and when computing qk@v, load the next iteration's required k. Currently, only double buffering is used to prefetch 1 set of kv. If prefetching multiple sets of kv each time, it's necessary to consider the impact of shared memory size on performance. 501 | 502 | ```python 503 | # inner loop 504 | async_load(k[0]) # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 505 | for nbi in range(NUM_BLOCK_N): 506 | # 加载v的同时计算q@k 507 | async_load(v[nbi]) # v[nbi]: gV = local_tile(V, make_tile(Int{}, Int{}), make_coord(nbi, _)); 508 | qk = q @ k[nbi].T # flash::gemm_smem() 509 | score = online_softmax(qk) # softmax_rescale_o() 510 | 511 | # 计算qk @ v的同时加载下一次迭代需要的k 512 | async_load(k[nbi]) # k[nbi]: gK = local_tile(K, make_tile(Int{}, Int{}), make_coord(nbi + 1, _)); 513 | out += score @ v[nbi] # gemm_A_in_regs() 514 | ``` 515 | 516 | Using this in Cutlass Cute is also straightforward. Once the asynchronous copy object is constructed, initiate the asynchronous copy. 
517 | 518 | ```cpp 519 | // Asynchronously load V while the gemm runs 520 | gV = local_tile(V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(nbi, _)); 521 | tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0)); 522 | // Asynchronously load V into smem 523 | flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV); 524 | // Commit the async copy 525 | cute::cp_async_fence(); 526 | 527 | // NOTE: run the gemm while the copy is in flight 528 | // O = Q@K.T 529 | // NOTE: load data from smem into reg, then gemm; **the retile happens during the load** 530 | flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, 531 | smem_thr_copy_Q, smem_thr_copy_K 532 | ); 533 | ``` 534 | 535 | 536 | ### Other Details 537 | 538 | - Early Return in Causal Mode 539 | * Inter-block early exit 540 | * **Intra-block Masking**: locating the thread's position within the MMA (Matrix Multiply and Accumulate) 541 | - Copying Results Back to Global Memory 542 | * Utilize shared memory (smem): first copy from registers to smem, then from smem to global memory (gmem) 543 | * This allows the use of wider bit widths 544 | - Online Safe Softmax 545 | - Pybind and Template Expansion 546 | * The official implementation uses many templates, essentially: 547 | 1. Enumerating all possible block partitioning strategies 548 | 2. Writing a file for each configuration to accelerate compilation 549 | 3. Writing a file for each template to fine-tune the best configuration 550 | - To integrate C++ code into Python, you can refer to this [repository](https://github.com/66RING/pytorch-cuda-binding-tutorial) 551 | 552 | Further details will be added later. Interested readers can first look into the source code comments. 553 | 554 | ### Other Optimizations 555 | 556 | - Bank Conflict Avoidance 557 | * Swizzling 558 | * Cutlass Cute already encapsulates swizzle to resolve bank conflicts; just use it when creating the copy objects. 559 | - Transpose Optimization 560 | * Copy directly to the transposed destination address, avoiding the need to allocate extra space 561 | * When creating the copy objects, simply configure the destination (dst) layout as the transposed layout 562 | - [High-performance Reduce Implementation](https://developer.nvidia.com/blog/faster-parallel-reductions-kepler/) 563 | * Optimizing warp divergence 564 | 565 | TODO: Expand on details 566 | 567 | 568 | ### A Little Bit Bottom-Up 569 | 570 | For an in-depth bottom-up understanding, refer to the [official cute tutorial](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/00_quickstart.md) 571 | 572 | TODO: Pick a few important points 573 | 574 | 575 | ### Major Pitfalls 576 | 577 | - Fusion of the two gemms: gemm-I, gemm-II 578 | * The input and output layouts are critical: gemm-I's output C layout must match gemm-II's input A layout 579 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | --------------------------------------------------------------------------------