├── media └── images │ └── dmha.png ├── performance ├── H20 │ ├── seq_bandwidth.png │ ├── batch_bandwidth.png │ ├── seq_throughput.png │ └── batch_throughput.png └── RTX3090 │ ├── seq_throughput.png │ └── batch_throughput.png ├── tools └── performance │ ├── python │ ├── performance.sh │ └── performance.py │ └── cpp │ ├── performance.sh │ └── performance.py ├── .gitignore ├── format.sh ├── decoding_attn ├── __init__.py └── decoding_attn_interface.py ├── csrc ├── kernel │ └── decoding_attn │ │ ├── decoding_fwd_hd64_hdv64.cu │ │ ├── decoding_fwd_hd96_hdv96.cu │ │ ├── decoding_fwd_hd576_hdv512.cu │ │ ├── decoding_fwd_hd256_hdv256.cu │ │ ├── decoding_fwd_hd128_hdv128.cu │ │ ├── decoding_fwd_launch_template.h │ │ ├── kernel_traits.h │ │ ├── decoding.h │ │ ├── block_info.h │ │ └── decoding_fwd_kernel.h ├── ops │ ├── decoding_attn.h │ └── decoding_attn.cpp ├── common │ ├── logging.h │ ├── cuda_timer.h │ ├── util.h │ ├── tensor.h │ ├── common.h │ └── tester.h └── torch │ └── decoding_torch.cpp ├── install_python.sh ├── .clang-format ├── run_python.sh ├── LICENSE ├── CMakeLists.txt ├── README.md ├── tests └── test_decoding_attn.py ├── setup.py ├── run_cpp.sh └── benchmarks ├── cpp └── benchmark_decoding_attn.cpp └── python └── benchmark_decoding_attn.py /media/images/dmha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/media/images/dmha.png -------------------------------------------------------------------------------- /performance/H20/seq_bandwidth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/H20/seq_bandwidth.png -------------------------------------------------------------------------------- /performance/H20/batch_bandwidth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/H20/batch_bandwidth.png -------------------------------------------------------------------------------- /performance/H20/seq_throughput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/H20/seq_throughput.png -------------------------------------------------------------------------------- /performance/H20/batch_throughput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/H20/batch_throughput.png -------------------------------------------------------------------------------- /performance/RTX3090/seq_throughput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/RTX3090/seq_throughput.png -------------------------------------------------------------------------------- /performance/RTX3090/batch_throughput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bruce-Lee-LY/decoding_attention/HEAD/performance/RTX3090/batch_throughput.png -------------------------------------------------------------------------------- /tools/performance/python/performance.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. 
All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: performance python script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 12 | 13 | python3 $WORK_PATH/performance.py -f $WORK_PATH/../../../log/benchmark_decoding_attn.log 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .VSCodeCounter/ 3 | .idea/ 4 | 5 | __pycache__/ 6 | .pytest_cache/ 7 | go/ 8 | 9 | *~/ 10 | build/ 11 | install/ 12 | release/ 13 | output/ 14 | bin/ 15 | log/ 16 | model/ 17 | ncu/ 18 | nsys/ 19 | roofline/ 20 | ptx/ 21 | sass/ 22 | tmp/ 23 | temp/ 24 | 25 | dist/ 26 | *.egg-info/ 27 | *.eggs/ 28 | 29 | *.o 30 | *.so 31 | *.so.* 32 | *.out 33 | *.log 34 | *.bak 35 | *.pkz 36 | 37 | setting.h 38 | .config* 39 | -------------------------------------------------------------------------------- /tools/performance/cpp/performance.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: performance cpp script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 12 | 13 | python3 $WORK_PATH/performance.py -p $WORK_PATH/../../../log/decoding_seq/ 14 | python3 $WORK_PATH/performance.py -p $WORK_PATH/../../../log/decoding_batch/ 15 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 21:08:30 on Sun, Aug 27, 2023 4 | # 5 | # Description: format script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 12 | 13 | find . \( -name '*.c' -or -name '*.cpp' -or -name '*.cc' -or -name '*.cxx' -or -name '*.cu' -or -name '*.h' -or -name '*.hpp' -or -name '*.cuh' -or -name '*.inl' \) -exec clang-format -style=file -i {} \; 14 | -------------------------------------------------------------------------------- /decoding_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: decoding attn init 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | __version__ = "0.1.0" 16 | 17 | from decoding_attn.decoding_attn_interface import ( 18 | decoding_attn_fwd, 19 | ) 20 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_hd64_hdv64.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd hdim64 and hdimv64 6 | 7 | #include "decoding_attn/decoding_fwd_launch_template.h" 8 | 9 | template <> 10 | void run_dmha_fwd_<half, 64, 64>(const DecodingParams &params) { 11 | dmha_fwd<half, 64, 64, 256, 4>(params); 12 | } 13 | 14 | template <> 15 | void run_dmha_fwd_<__nv_bfloat16, 64, 64>(const DecodingParams &params) { 16 | dmha_fwd<__nv_bfloat16, 64, 64, 256, 4>(params); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_hd96_hdv96.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd hdim96 and hdimv96 6 | 7 | #include "decoding_attn/decoding_fwd_launch_template.h" 8 | 9 | template <> 10 | void run_dmha_fwd_<half, 96, 96>(const DecodingParams &params) { 11 | dmha_fwd<half, 96, 96, 256, 4>(params); 12 | } 13 | 14 | template <> 15 | void run_dmha_fwd_<__nv_bfloat16, 96, 96>(const DecodingParams &params) { 16 | dmha_fwd<__nv_bfloat16, 96, 96, 256, 4>(params); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_hd576_hdv512.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd hdim576 and hdimv512 6 | 7 | #include "decoding_attn/decoding_fwd_launch_template.h" 8 | 9 | template <> 10 | void run_dmha_fwd_<half, 576, 512>(const DecodingParams &params) { 11 | dmha_fwd<half, 576, 512, 256, 8>(params); 12 | } 13 | 14 | template <> 15 | void run_dmha_fwd_<__nv_bfloat16, 576, 512>(const DecodingParams &params) { 16 | dmha_fwd<__nv_bfloat16, 576, 512, 256, 8>(params); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_hd256_hdv256.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd hdim256 and hdimv256 6 | 7 | #include "decoding_attn/decoding_fwd_launch_template.h" 8 | 9 | template <> 10 | void run_dmha_fwd_<half, 256, 256>(const DecodingParams &params) { 11 | dmha_fwd<half, 256, 256, 256, 16>(params); 12 | } 13 | 14 | template <> 15 | void run_dmha_fwd_<__nv_bfloat16, 256, 256>(const DecodingParams &params) { 16 | dmha_fwd<__nv_bfloat16, 256, 256, 256, 16>(params); 17 | } 18 | -------------------------------------------------------------------------------- /install_python.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved.
2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: install python script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | echo "========== install enter ==========" 12 | 13 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 14 | 15 | echo_cmd() { 16 | echo $1 17 | $1 18 | } 19 | 20 | echo "========== install decoding_attention ==========" 21 | 22 | echo_cmd "rm -rf build dist decoding_attn.egg-info" 23 | echo_cmd "python3 setup.py install --user --prefix=" 24 | 25 | echo "========== install exit ==========" 26 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | 3 | IndentWidth: 4 4 | 5 | TabWidth: 4 6 | 7 | UseTab: Never 8 | 9 | ObjCBlockIndentWidth: 4 10 | 11 | IndentCaseLabels: true 12 | 13 | IndentWrappedFunctionNames: true 14 | 15 | ColumnLimit: 120 16 | 17 | AccessModifierOffset: -4 18 | 19 | AllowShortFunctionsOnASingleLine: Empty 20 | 21 | AllowShortIfStatementsOnASingleLine: false 22 | 23 | AllowShortLoopsOnASingleLine: false 24 | 25 | AllowShortBlocksOnASingleLine: false 26 | 27 | AllowShortCaseLabelsOnASingleLine: false 28 | 29 | KeepEmptyLinesAtTheStartOfBlocks: false 30 | 31 | MaxEmptyLinesToKeep: 1 32 | 33 | DerivePointerAlignment: false 34 | 35 | PointerAlignment: Right 36 | 37 | Standard: Cpp11 38 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_hd128_hdv128.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd hdim128 and hdimv128 6 | 7 | #include "decoding_attn/decoding_fwd_launch_template.h" 8 | 9 | template <> 10 | void run_dmha_fwd_<half, 128, 128>(const DecodingParams &params) { 11 | if (params.b <= 4) { 12 | dmha_fwd<half, 128, 128, 256, 8>(params); 13 | } else { 14 | dmha_fwd<half, 128, 128, 128, 16>(params); 15 | } 16 | } 17 | 18 | template <> 19 | void run_dmha_fwd_<__nv_bfloat16, 128, 128>(const DecodingParams &params) { 20 | if (params.b <= 4) { 21 | dmha_fwd<__nv_bfloat16, 128, 128, 256, 8>(params); 22 | } else { 23 | dmha_fwd<__nv_bfloat16, 128, 128, 128, 16>(params); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /csrc/ops/decoding_attn.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved.
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding attn 6 | 7 | #include 8 | 9 | #include "cuda_runtime_api.h" 10 | 11 | /** 12 | * @brief decoding attn api: sopprt MHA/MQA/GQA/MLA 13 | * 14 | * @param q [total_q * head_q * dim] 15 | * @param k [total_k * head_k * dim], kv_c_and_k_pe_cache for MLA 16 | * @param v [total_k * head_k * dim_v], nullptr for MLA 17 | * @param o [total_q * head_q * dim_v] 18 | * @param cu_seq_k [batch + 1] 19 | * @param max_seq_k 20 | * @param batch 21 | * @param head_q 22 | * @param head_k 23 | * @param dim 24 | * @param dim_v 25 | * @param is_alibi 26 | * @param is_bf16 27 | * @param stream 28 | */ 29 | void decoding_attn(void *q, void *k, void *v, void *o, int *cu_seq_k, size_t max_seq_k, size_t batch, size_t head_q, 30 | size_t head_k, size_t dim, size_t dim_v, bool is_alibi, bool is_bf16, cudaStream_t stream); 31 | -------------------------------------------------------------------------------- /run_python.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: run python script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 12 | 13 | export CUDA_VISIBLE_DEVICES=0 14 | # export CUDA_LAUNCH_BLOCKING=1 15 | 16 | rm -rf log && mkdir -p log 17 | 18 | # test 19 | python3 $WORK_PATH/tests/test_decoding_attn.py 20 | # pytest $WORK_PATH/tests/test_decoding_attn.py 21 | 22 | # FP16 23 | # python3 $WORK_PATH/benchmarks/python/benchmark_decoding_attn.py --head_q 32 --head_k 32 --dim 128 --dim_v 128 --warmup_iterations 1 --profiling_iterations 10 > log/benchmark_decoding_attn.log 2>&1 24 | 25 | # BF16 26 | # python3 $WORK_PATH/benchmarks/python/benchmark_decoding_attn.py --head_q 32 --head_k 32 --dim 128 --dim_v 128 --is_bf16 --warmup_iterations 1 --profiling_iterations 10 > log/benchmark_decoding_attn.log 2>&1 27 | 28 | # MLA 29 | # python3 $WORK_PATH/benchmarks/python/benchmark_decoding_attn.py --head_q 128 --head_k 1 --dim 576 --dim_v 512 --warmup_iterations 1 --profiling_iterations 10 > log/benchmark_decoding_attn.log 2>&1 30 | -------------------------------------------------------------------------------- /csrc/common/logging.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 20:42:28 on Sun, Feb 12, 2023 4 | // 5 | // Description: logging 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | inline char *curr_time() { 16 | time_t raw_time = time(nullptr); 17 | struct tm *time_info = localtime(&raw_time); 18 | static char now_time[64]; 19 | now_time[strftime(now_time, sizeof(now_time), "%Y-%m-%d %H:%M:%S", time_info)] = '\0'; 20 | 21 | return now_time; 22 | } 23 | 24 | inline int get_pid() { 25 | static int pid = getpid(); 26 | 27 | return pid; 28 | } 29 | 30 | inline long int get_tid() { 31 | thread_local long int tid = syscall(SYS_gettid); 32 | 33 | return tid; 34 | } 35 | 36 | #define DA_LOG_TAG "DA" 37 | #define DA_LOG_FILE(x) (strrchr(x, '/') ? (strrchr(x, '/') + 1) : x) 38 | #define DLOG(format, ...) 
\ 39 | do { \ 40 | fprintf(stderr, "[%s %s %d:%ld %s:%d %s] " format "\n", DA_LOG_TAG, curr_time(), get_pid(), get_tid(), \ 41 | DA_LOG_FILE(__FILE__), __LINE__, __FUNCTION__, ##__VA_ARGS__); \ 42 | } while (0) 43 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_launch_template.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd launch template 6 | 7 | #pragma once 8 | 9 | #include "decoding_attn/decoding_fwd_kernel.h" 10 | 11 | template 12 | void dmha_fwd(const DecodingParams ¶ms) { 13 | constexpr size_t warp_size = 32; 14 | constexpr size_t static_smem_size = ThreadsPerBlock / warp_size * sizeof(float); 15 | const size_t dynamic_smem_size = std::max(params.max_seq_k * sizeof(float), params.d_v * sizeof(float)); 16 | const size_t total_smem_size = static_smem_size + dynamic_smem_size; 17 | dim3 block(ThreadsPerBlock); 18 | dim3 grid(params.b, params.h); 19 | 20 | DA_BOOL_SWITCH(params.is_alibi, IsAlibi, [&] { 21 | auto kernel = 22 | &dmha_fwd_kernel, IsAlibi>; 23 | if (total_smem_size >= 48 * 1024) { 24 | DA_CHECK_CUDART_ERROR( 25 | cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, total_smem_size)); 26 | } 27 | 28 | kernel<<>>(params); 29 | DA_CHECK_CUDART_ERROR(cudaPeekAtLastError()); 30 | }); 31 | } 32 | -------------------------------------------------------------------------------- /csrc/common/cuda_timer.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 20:42:28 on Sun, Feb 12, 2023 4 | // 5 | // Description: cuda timer 6 | 7 | #pragma once 8 | 9 | #include "common.h" 10 | 11 | class CudaTimer { 12 | public: 13 | CudaTimer(cudaStream_t stream = nullptr) : m_stream(stream) { 14 | DA_CHECK_CUDART_ERROR(cudaEventCreate(&m_start)); 15 | DA_CHECK(m_start); 16 | DA_CHECK_CUDART_ERROR(cudaEventCreate(&m_end)); 17 | DA_CHECK(m_end); 18 | } 19 | 20 | ~CudaTimer() { 21 | if (m_start) { 22 | DA_CHECK_CUDART_ERROR(cudaEventDestroy(m_start)); 23 | m_start = nullptr; 24 | } 25 | 26 | if (m_end) { 27 | DA_CHECK_CUDART_ERROR(cudaEventDestroy(m_end)); 28 | m_end = nullptr; 29 | } 30 | } 31 | 32 | void start() { 33 | DA_CHECK_CUDART_ERROR(cudaEventRecord(m_start, m_stream)); 34 | } 35 | 36 | float end() { 37 | DA_CHECK_CUDART_ERROR(cudaEventRecord(m_end, m_stream)); 38 | DA_CHECK_CUDART_ERROR(cudaEventSynchronize(m_end)); 39 | DA_CHECK_CUDART_ERROR(cudaEventElapsedTime(&m_elapsed_time, m_start, m_end)); 40 | 41 | return m_elapsed_time; 42 | } 43 | 44 | private: 45 | const cudaStream_t m_stream = nullptr; 46 | 47 | cudaEvent_t m_start = nullptr; 48 | cudaEvent_t m_end = nullptr; 49 | float m_elapsed_time = 0.0; 50 | 51 | DA_DISALLOW_COPY_AND_ASSIGN(CudaTimer); 52 | }; 53 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/kernel_traits.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
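// Worked example (added for illustration; the concrete numbers below are not part of the original source): with
// T = half, HeadDim = 128, HeadDimV = 128, ThreadsPerBlock = 256 and ThreadsPerGroup = 8 (one of the hdim128 launch
// configurations), the traits defined below evaluate to warps_per_block = 256 / 32 = 8, groups_per_warp = 32 / 8 = 4,
// groups_per_block = 32, thread_copy_elem_nums = 16 / sizeof(half) = 8, thread_qk_nums = 128 / 8 = 16,
// thread_copy_qk_iters = 2, and likewise thread_vo_nums = 16 and thread_copy_vo_iters = 2, i.e. each thread covers
// its 16-element slice of a head with two 16-byte vectorized copies.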
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: kernel traits 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | template 12 | struct DecodingKernelTraits { 13 | static constexpr size_t head_dim = HeadDim; 14 | static constexpr size_t head_dim_v = HeadDimV; 15 | static constexpr size_t threads_per_block = ThreadsPerBlock; 16 | static constexpr size_t threads_per_group = ThreadsPerGroup; 17 | 18 | static constexpr size_t warp_size = 32; 19 | static constexpr size_t warps_per_block = threads_per_block / warp_size; 20 | 21 | static constexpr size_t groups_per_warp = warp_size / threads_per_group; 22 | static constexpr size_t groups_per_block = groups_per_warp * warps_per_block; 23 | 24 | static constexpr size_t thread_copy_bytes = 16; 25 | static constexpr size_t thread_copy_elem_nums = thread_copy_bytes / sizeof(T); 26 | 27 | static constexpr size_t thread_qk_nums = head_dim / threads_per_group; 28 | static constexpr size_t thread_copy_qk_iters = thread_qk_nums / thread_copy_elem_nums; 29 | 30 | static constexpr size_t thread_vo_nums = head_dim_v / threads_per_group; 31 | static constexpr size_t thread_copy_vo_iters = thread_vo_nums / thread_copy_elem_nums; 32 | 33 | static constexpr unsigned int shfl_mask = 0xffffffff; 34 | }; 35 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding 6 | 7 | #pragma once 8 | 9 | #include "common.h" 10 | 11 | struct DecodingParams { 12 | // The QKV matrices. 13 | void *__restrict__ q_ptr; 14 | void *__restrict__ k_ptr; 15 | void *__restrict__ v_ptr; 16 | 17 | // The stride between rows of the Q, K and V matrices. 18 | size_t q_row_stride; 19 | size_t k_row_stride; 20 | size_t v_row_stride; 21 | size_t q_head_stride; 22 | size_t k_head_stride; 23 | size_t v_head_stride; 24 | 25 | // The number of heads. 26 | int h, h_k; 27 | // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be 28 | // different from nheads (query). 29 | int h_h_k_ratio; // precompute h / h_k, 30 | 31 | // The O matrix (output). 32 | void *__restrict__ o_ptr; 33 | 34 | // The stride between rows of O. 35 | size_t o_row_stride; 36 | size_t o_head_stride; 37 | 38 | // The dimensions. 39 | int b, max_seq_k, d, d_v; 40 | 41 | // The scaling factors for the kernel. 42 | float scale_softmax; 43 | 44 | // array of length b+1 holding starting offset of each sequence. 45 | int *__restrict__ cu_seq_k; 46 | 47 | bool is_alibi; 48 | bool is_bf16; 49 | 50 | cudaStream_t stream; 51 | }; 52 | 53 | template 54 | void run_dmha_fwd_(const DecodingParams ¶ms); 55 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/block_info.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: block info 6 | 7 | #pragma once 8 | 9 | #include "cuda_runtime_api.h" 10 | 11 | struct DecodingBlockInfo { 12 | template 13 | __device__ DecodingBlockInfo(const Params ¶ms, const int bidb, const int bidh) 14 | : b(bidb), 15 | h(bidh), 16 | h_k(h / params.h_h_k_ratio), 17 | sum_seq_q(b), 18 | sum_seq_k(params.cu_seq_k[b]), 19 | actual_seq_k(params.cu_seq_k[b + 1] - sum_seq_k), 20 | row_shift(actual_seq_k - actual_seq_q), 21 | h_slope(1.0 / exp2f(8.0 * (h + 1) / params.h)) {} 22 | 23 | inline __device__ size_t q_offset(const size_t row_stride, const size_t head_stride, const size_t dim_idx) const { 24 | return static_cast(sum_seq_q) * row_stride + h * head_stride + dim_idx; 25 | } 26 | 27 | inline __device__ size_t k_offset(const size_t seq_k, const size_t row_stride, const size_t head_stride, 28 | const size_t dim_idx) const { 29 | return static_cast(sum_seq_k + seq_k) * row_stride + h_k * head_stride + dim_idx; 30 | } 31 | 32 | const int b; 33 | const int h; 34 | const int h_k; 35 | const int sum_seq_q; 36 | const int sum_seq_k; 37 | const int actual_seq_k; 38 | const int row_shift; 39 | const float h_slope; 40 | 41 | const int actual_seq_q = 1; 42 | }; 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Bruce-Lee-LY 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /decoding_attn/decoding_attn_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 
2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: decoding attn interface 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | import torch 16 | 17 | import decoding_attn_cuda 18 | 19 | 20 | def maybe_contiguous(x): 21 | return x.contiguous() if x is not None and x.stride(-1) != 1 else x 22 | 23 | 24 | def decoding_attn_fwd( 25 | q: torch.tensor, 26 | k: torch.tensor, 27 | v: torch.tensor, 28 | cu_seq_k: torch.tensor, 29 | max_seq_k: int, 30 | dim_v: int, 31 | is_alibi: bool 32 | ) -> torch.tensor: 33 | """ 34 | Arguments: 35 | q: [total_q, head_q, dim], torch.float16 / torch.bfloat16, where total_q = total number of 36 | query tokens in the batch. 37 | k: [total_k, head_k, dim], torch.float16 / torch.bfloat16, where total_k = total number of 38 | key tokens in the batch, MLA: kv_c_and_k_pe_cache. 39 | v: [total_k, head_k, dim_v], torch.float16 / torch.bfloat16, where total_k = total number of 40 | value tokens in the batch, MLA: None. 41 | cu_seq_k: [batch + 1], dtype torch.int32. The cumulative sequence lengths of the sequences 42 | in the batch, used to index into k / v. 43 | max_seq_k: Maximum key sequence length in the batch. 44 | dim_v: Head dimension of v. 45 | is_alibi: Whether to apply alibi. 46 | Return: 47 | o: [total_q, head_q, dim_v], torch.float16 / torch.bfloat16, where total_q = total number of 48 | output tokens in the batch. 49 | """ 50 | q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 51 | o = decoding_attn_cuda.fwd( 52 | q, k, v, None, cu_seq_k, max_seq_k, dim_v, is_alibi) 53 | return o 54 | -------------------------------------------------------------------------------- /csrc/ops/decoding_attn.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding attn 6 | 7 | #include "decoding_attn.h" 8 | 9 | #include 10 | 11 | #include "decoding_attn/decoding.h" 12 | 13 | void set_dmha_fwd_params(DecodingParams ¶ms, void *q, void *k, void *v, void *o, int *cu_seq_k, size_t max_seq_k, 14 | size_t batch, size_t head_q, size_t head_k, size_t dim, size_t dim_v, bool is_alibi, 15 | bool is_bf16, cudaStream_t stream) { 16 | DA_CHECK(q); 17 | DA_CHECK(k); 18 | DA_CHECK(o); 19 | DA_CHECK(cu_seq_k); 20 | DA_CHECK_EQ(head_q % head_k, 0); 21 | DA_CHECK_LE(dim, 576); 22 | DA_CHECK_LE(dim_v, 512); 23 | 24 | // Reset the parameters 25 | memset(¶ms, 0, sizeof(params)); 26 | 27 | // Set the pointers and strides. 28 | params.q_ptr = q; 29 | params.k_ptr = k; 30 | params.v_ptr = v ? v : k; 31 | 32 | params.q_row_stride = head_q * dim; 33 | params.k_row_stride = head_k * dim; 34 | params.v_row_stride = head_k * dim; 35 | params.q_head_stride = dim; 36 | params.k_head_stride = dim; 37 | params.v_head_stride = dim_v; 38 | 39 | params.h = head_q; 40 | params.h_k = head_k; 41 | params.h_h_k_ratio = params.h / params.h_k; 42 | 43 | params.o_ptr = o; 44 | 45 | params.o_row_stride = head_q * dim_v; 46 | params.o_head_stride = dim_v; 47 | 48 | // Set the dimensions. 
49 | params.b = batch; 50 | params.max_seq_k = max_seq_k; 51 | params.d = dim; 52 | params.d_v = dim_v; 53 | 54 | params.scale_softmax = 1.0 / std::sqrt(dim); 55 | 56 | params.cu_seq_k = cu_seq_k; 57 | 58 | params.is_alibi = is_alibi; 59 | params.is_bf16 = is_bf16; 60 | 61 | params.stream = stream; 62 | } 63 | 64 | void run_dmha_fwd(const DecodingParams ¶ms) { 65 | DA_FP16_SWITCH(!params.is_bf16, [&] { 66 | DA_HEADDIM_SWITCH(params.d, params.d_v, [&] { run_dmha_fwd_(params); }); 67 | }); 68 | } 69 | 70 | void decoding_attn(void *q, void *k, void *v, void *o, int *cu_seq_k, size_t max_seq_k, size_t batch, size_t head_q, 71 | size_t head_k, size_t dim, size_t dim_v, bool is_alibi, bool is_bf16, cudaStream_t stream) { 72 | DecodingParams params; 73 | set_dmha_fwd_params(params, q, k, v, o, cu_seq_k, max_seq_k, batch, head_q, head_k, dim, dim_v, is_alibi, is_bf16, 74 | stream); 75 | run_dmha_fwd(params); 76 | } 77 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: cmake for cpp decoding attention 6 | 7 | cmake_minimum_required (VERSION 3.16) 8 | 9 | project (decoding_attention LANGUAGES C CXX CUDA) 10 | 11 | if (POLICY CMP0146) 12 | cmake_policy (SET CMP0146 OLD) 13 | endif () 14 | 15 | find_program (CCACHE_FOUND ccache) 16 | if (CCACHE_FOUND) 17 | set_property (GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) 18 | set_property (GLOBAL PROPERTY RULE_LAUNCH_LINK ccache) 19 | endif (CCACHE_FOUND) 20 | 21 | set (CMAKE_VERBOSE_MAKEFILE ${DA_VERBOSE_MAKEFILE}) 22 | 23 | set (CMAKE_C_FLAGS "-std=c17") 24 | set (CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -O0 -g2 -ggdb") 25 | set (CMAKE_C_FLAGS_RELEASE "$ENV{CFLAGS} -O3") 26 | 27 | set (CMAKE_CXX_FLAGS "-std=c++17") 28 | set (CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb") 29 | set (CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3") 30 | 31 | set (CMAKE_EXE_LINKER_FLAGS "-Wl,--as-needed") 32 | 33 | add_compile_options ( 34 | -Wall 35 | -Werror 36 | -Wextra 37 | -Wswitch-default 38 | # -Wfloat-equal 39 | -Wshadow 40 | -Wcast-qual 41 | ) 42 | 43 | # Nvidia GPU 44 | find_package (CUDA REQUIRED) 45 | # unset (CUDA_USE_STATIC_CUDA_RUNTIME CACHE) 46 | # option (CUDA_USE_STATIC_CUDA_RUNTIME OFF) 47 | 48 | set (CUDA_VERBOSE_BUILD ${DA_VERBOSE_MAKEFILE}) 49 | set (CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++17 -Xcompiler -fopenmp --expt-relaxed-constexpr") 50 | if (${CMAKE_BUILD_TYPE} MATCHES "Debug") 51 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -arch=sm_${CMAKE_CUDA_ARCHITECTURES} -g -lineinfo -Xptxas=-v -O0") 52 | else () 53 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_${CMAKE_CUDA_ARCHITECTURES},code=sm_${CMAKE_CUDA_ARCHITECTURES} --use_fast_math -O3") 54 | endif () 55 | 56 | set (SYS_CUDART_PATH "/usr/local/cuda") 57 | set (SYS_CUDA_DRIVER_PATH "/usr/lib/x86_64-linux-gnu") 58 | 59 | find_package(gflags REQUIRED) 60 | find_package(OpenMP REQUIRED) 61 | 62 | include_directories ( 63 | ${PROJECT_SOURCE_DIR}/csrc/common 64 | ${PROJECT_SOURCE_DIR}/csrc/kernel 65 | ${PROJECT_SOURCE_DIR}/csrc/ops 66 | ${SYS_CUDART_PATH}/include 67 | ${GFLAGS_INCLUDE_DIR} 68 | ) 69 | 70 | link_directories ( 71 | ${SYS_CUDART_PATH}/lib64 72 | ${SYS_CUDA_DRIVER_PATH} 73 | ) 74 | 75 | file (GLOB DA_SRCS 76 | ${PROJECT_SOURCE_DIR}/csrc/kernel/decoding_attn/*.cu 77 | ${PROJECT_SOURCE_DIR}/csrc/ops/*.cpp 78 | 
${PROJECT_SOURCE_DIR}/benchmarks/cpp/*.cpp 79 | ) 80 | 81 | cuda_add_executable (benchmark_decoding_attn ${DA_SRCS}) 82 | target_link_libraries (benchmark_decoding_attn OpenMP::OpenMP_CXX ${GFLAGS_LIBRARIES}) 83 | 84 | install (TARGETS benchmark_decoding_attn RUNTIME DESTINATION bin) 85 | -------------------------------------------------------------------------------- /csrc/common/util.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 2 | // Author: Bruce-Lee-LY 3 | // Date: 20:42:28 on Sun, Feb 12, 2023 4 | // 5 | // Description: util function 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "cuda_runtime_api.h" 13 | 14 | inline __device__ __host__ size_t div_ceil(size_t a, size_t b) { 15 | return (a % b != 0) ? (a / b + 1) : (a / b); 16 | } 17 | 18 | // Beginning of GPU Architecture definitions 19 | inline int convert_SM_to_cores(int major, int minor) { 20 | // Defines for GPU Architecture types (using the SM version to determine the # of cores per SM 21 | typedef struct { 22 | int SM; // 0xMm (hexidecimal notation), M = SM Major version, and m = SM minor version 23 | int cores; 24 | } sSMtoCores; 25 | 26 | sSMtoCores nGpuArchCoresPerSM[] = {{0x30, 192}, {0x32, 192}, {0x35, 192}, {0x37, 192}, {0x50, 128}, 27 | {0x52, 128}, {0x53, 128}, {0x60, 64}, {0x61, 128}, {0x62, 128}, 28 | {0x70, 64}, {0x72, 64}, {0x75, 64}, {0x80, 64}, {0x86, 128}, 29 | {0x87, 128}, {0x89, 128}, {0x90, 128}, {-1, -1}}; 30 | 31 | int index = 0; 32 | 33 | while (nGpuArchCoresPerSM[index].SM != -1) { 34 | if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { 35 | return nGpuArchCoresPerSM[index].cores; 36 | } 37 | 38 | index++; 39 | } 40 | 41 | // If we don't find the values, we default use the previous one to run properly 42 | DLOG("MapSMtoCores for SM %d.%d is undefined. Default to use %d cores/SM", major, minor, 43 | nGpuArchCoresPerSM[index - 1].cores); 44 | 45 | return nGpuArchCoresPerSM[index - 1].cores; 46 | } 47 | 48 | inline const char *convert_SM_to_arch_name(int major, int minor) { 49 | // Defines for GPU Architecture types (using the SM version to determine the GPU Arch name) 50 | typedef struct { 51 | int SM; // 0xMm (hexidecimal notation), M = SM Major version, and m = SM minor version 52 | const char *name; 53 | } sSMtoArchName; 54 | 55 | sSMtoArchName nGpuArchNameSM[] = {{0x30, "Kepler"}, {0x32, "Kepler"}, {0x35, "Kepler"}, {0x37, "Kepler"}, 56 | {0x50, "Maxwell"}, {0x52, "Maxwell"}, {0x53, "Maxwell"}, {0x60, "Pascal"}, 57 | {0x61, "Pascal"}, {0x62, "Pascal"}, {0x70, "Volta"}, {0x72, "Xavier"}, 58 | {0x75, "Turing"}, {0x80, "Ampere"}, {0x86, "Ampere"}, {0x87, "Ampere"}, 59 | {0x89, "Ada"}, {0x90, "Hopper"}, {-1, "Graphics Device"}}; 60 | 61 | int index = 0; 62 | 63 | while (nGpuArchNameSM[index].SM != -1) { 64 | if (nGpuArchNameSM[index].SM == ((major << 4) + minor)) { 65 | return nGpuArchNameSM[index].name; 66 | } 67 | 68 | index++; 69 | } 70 | 71 | // If we don't find the values, we default use the previous one to run properly 72 | DLOG("MapSMtoArchName for SM %d.%d is undefined. Default to use %s", major, minor, nGpuArchNameSM[index - 1].name); 73 | 74 | return nGpuArchNameSM[index - 1].name; 75 | } 76 | -------------------------------------------------------------------------------- /csrc/torch/decoding_torch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding torch 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "decoding_attn.h" 13 | 14 | #define DA_CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") 15 | #define DA_CHECK_SHAPE(x, ...) \ 16 | TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 17 | #define DA_CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | 19 | at::Tensor dmha_fwd(const at::Tensor &q, const at::Tensor &k, std::optional &v_, 20 | c10::optional &o_, const at::Tensor &cu_seq_k, const int max_seq_k, const int dim_v, 21 | const bool is_alibi) { 22 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 23 | 24 | auto q_dtype = q.dtype(); 25 | TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, 26 | "Decoding-Attention only support FP16 and BF16 data type"); 27 | bool is_bf16 = q_dtype == torch::kBFloat16; 28 | TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); 29 | TORCH_CHECK(cu_seq_k.dtype() == torch::kInt32, "cu_seq_k must have dtype int32"); 30 | 31 | DA_CHECK_DEVICE(q); 32 | DA_CHECK_DEVICE(k); 33 | DA_CHECK_DEVICE(cu_seq_k); 34 | 35 | TORCH_CHECK(q.stride(-1) == 1, "query must have contiguous last dimension"); 36 | TORCH_CHECK(k.stride(-1) == 1, "key must have contiguous last dimension"); 37 | DA_CHECK_CONTIGUOUS(cu_seq_k); 38 | 39 | const int batch = cu_seq_k.numel() - 1; 40 | const int total_q = q.size(0); 41 | const int head_q = q.size(1); 42 | const int dim = q.size(2); 43 | const int total_k = k.size(0); 44 | const int head_k = k.size(1); 45 | TORCH_CHECK(batch > 0, "batch size must be positive"); 46 | TORCH_CHECK(dim <= 576, "dim should be less than 256"); 47 | TORCH_CHECK(dim_v <= 512, "dim_v should be less than 256"); 48 | TORCH_CHECK(head_q % head_k == 0, "number of heads in key/value must divide number of heads in query"); 49 | 50 | DA_CHECK_SHAPE(q, total_q, head_q, dim); 51 | DA_CHECK_SHAPE(k, total_k, head_k, dim); 52 | DA_CHECK_SHAPE(cu_seq_k, batch + 1); 53 | 54 | void *v_ptr = nullptr; 55 | if (v_.has_value()) { 56 | at::Tensor v = v_.value(); 57 | TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); 58 | DA_CHECK_DEVICE(v); 59 | TORCH_CHECK(v.stride(-1) == 1, "value must have contiguous last dimension"); 60 | DA_CHECK_SHAPE(v, total_k, head_k, dim_v); 61 | v_ptr = v.data_ptr(); 62 | } 63 | 64 | at::Tensor o; 65 | if (o_.has_value()) { 66 | o = o_.value(); 67 | TORCH_CHECK(o.dtype() == q_dtype, "Output must have the same dtype as inputs"); 68 | DA_CHECK_DEVICE(o); 69 | TORCH_CHECK(o.stride(-1) == 1, "Output tensor must have contiguous last dimension"); 70 | DA_CHECK_SHAPE(o, total_q, head_q, dim_v); 71 | } else { 72 | auto opts = q.options(); 73 | o = torch::empty({total_q, head_q, dim_v}, opts); 74 | } 75 | 76 | decoding_attn(q.data_ptr(), k.data_ptr(), v_ptr, o.data_ptr(), reinterpret_cast(cu_seq_k.data_ptr()), 77 | max_seq_k, batch, head_q, head_k, dim, dim_v, is_alibi, is_bf16, stream); 78 | 79 | return o; 80 | } 81 | 82 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 83 | m.doc() = "Decoding-Attention"; 84 | m.def("fwd", &dmha_fwd, "Forward pass"); 85 | } 86 | -------------------------------------------------------------------------------- /csrc/common/tensor.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:08:30 on Sun, Aug 27, 2023 4 | // 5 | // Description: tensor 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "common.h" 12 | 13 | template 14 | class Tensor { 15 | public: 16 | Tensor(const std::vector &shape, const std::string &name = "Tensor", float min = -1.0, float max = 1.0) 17 | : m_shape(shape), m_name(name), m_min(min), m_max(max) { 18 | DA_CHECK_GT(shape.size(), 0); 19 | for (size_t i = 0; i < shape.size(); ++i) { 20 | DA_CHECK_GT(shape[i], 0); 21 | } 22 | 23 | m_elem_num = std::accumulate(m_shape.begin(), m_shape.end(), 1, std::multiplies()); 24 | DA_CHECK_GT(m_elem_num, 0); 25 | 26 | m_host_ptr = new T[m_elem_num]; 27 | DA_CHECK(m_host_ptr); 28 | DA_CHECK_CUDART_ERROR(cudaMalloc((void **)&m_dev_ptr, m_elem_num * sizeof(T))); 29 | DA_CHECK(m_dev_ptr); 30 | 31 | std::random_device rd; 32 | std::default_random_engine engine{rd()}; 33 | std::uniform_real_distribution uniform(m_min, m_max); 34 | for (size_t i = 0; i < m_elem_num; ++i) { 35 | m_host_ptr[i] = static_cast(uniform(engine)); 36 | } 37 | 38 | DA_CHECK_CUDART_ERROR(cudaMemcpy(m_dev_ptr, m_host_ptr, m_elem_num * sizeof(T), cudaMemcpyHostToDevice)); 39 | 40 | DLOG("%s: %zu, cpu: %p, gpu: %p", m_name.c_str(), m_elem_num, m_host_ptr, m_dev_ptr); 41 | } 42 | 43 | ~Tensor() { 44 | if (m_host_ptr) { 45 | delete[] m_host_ptr; 46 | m_host_ptr = nullptr; 47 | } 48 | 49 | if (m_dev_ptr) { 50 | DA_CHECK_CUDART_ERROR(cudaFree((void *)m_dev_ptr)); 51 | m_dev_ptr = nullptr; 52 | } 53 | } 54 | 55 | std::vector getShape() const { 56 | return m_shape; 57 | } 58 | 59 | size_t getElemNum() const { 60 | return m_elem_num; 61 | } 62 | 63 | T *getHostPtr() const { 64 | return m_host_ptr; 65 | } 66 | 67 | T *getDevPtr() const { 68 | return m_dev_ptr; 69 | } 70 | 71 | void tearUp(Tensor *base) { 72 | DA_CHECK(base); 73 | DA_CHECK_EQ(m_elem_num, base->getElemNum()); 74 | 75 | DA_CHECK_CUDART_ERROR( 76 | cudaMemcpy(m_dev_ptr, base->getDevPtr(), m_elem_num * sizeof(T), cudaMemcpyDeviceToDevice)); 77 | } 78 | 79 | void moveToHost() { 80 | DA_CHECK_CUDART_ERROR(cudaMemcpy(m_host_ptr, m_dev_ptr, m_elem_num * sizeof(T), cudaMemcpyDeviceToHost)); 81 | } 82 | 83 | void moveToDevice() { 84 | DA_CHECK_CUDART_ERROR(cudaMemcpy(m_dev_ptr, m_host_ptr, m_elem_num * sizeof(T), cudaMemcpyHostToDevice)); 85 | } 86 | 87 | void memSetHost() { 88 | memset(m_host_ptr, 0, m_elem_num * sizeof(T)); 89 | } 90 | 91 | void memSetDevice() { 92 | DA_CHECK_CUDART_ERROR(cudaMemset(m_dev_ptr, 0, m_elem_num * sizeof(T))); 93 | } 94 | 95 | void checkValue(Tensor *base) { 96 | DA_CHECK(base); 97 | DA_CHECK_EQ(m_elem_num, base->getElemNum()); 98 | 99 | m_max_diff = 0.0; 100 | m_avg_diff = 0.0; 101 | double diff = 0.0; 102 | for (size_t i = 0; i < m_elem_num; ++i) { 103 | diff = static_cast( 104 | std::abs(static_cast(m_host_ptr[i]) - static_cast(base->getHostPtr()[i]))); 105 | m_max_diff = std::max(m_max_diff, diff); 106 | m_avg_diff += diff; 107 | } 108 | 109 | m_avg_diff /= static_cast(m_elem_num); 110 | 111 | DLOG("Max diff: %f, avg diff: %f", m_max_diff, m_avg_diff); 112 | } 113 | 114 | private: 115 | const std::vector m_shape; 116 | const std::string m_name = "Tensor"; 117 | // the threshold of the random tensor will affect the difference of the dmha results 118 | const float m_min = -1.0; 119 | const float m_max = 1.0; 120 | 121 | size_t m_elem_num = 0; 122 | T *m_host_ptr = nullptr; 123 | T *m_dev_ptr = nullptr; 124 | 125 | double m_max_diff = 0.0; 126 | double m_avg_diff = 0.0; 127 | 128 | DA_DISALLOW_COPY_AND_ASSIGN(Tensor); 
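// Usage sketch (comment added for illustration only; the names and shapes below are hypothetical and not taken
// from a call site in this repo):
//   Tensor<half> o({total_q, head_q, dim_v}, "O");          // allocates host + device buffers, fills U(min, max)
//   Tensor<half> o_ref({total_q, head_q, dim_v}, "O ref");
//   ... launch the kernel into o.getDevPtr() and a reference into o_ref.getDevPtr() ...
//   o.moveToHost(); o_ref.moveToHost();                      // copy results back for comparison
//   o.checkValue(&o_ref);                                    // logs max / avg element-wise diff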
129 | }; 130 | -------------------------------------------------------------------------------- /tools/performance/python/performance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: performance python line chart 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | import os 16 | import optparse 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def process_log(file): 22 | seq_data = [] 23 | batch_data = [] 24 | with open(file) as fp: 25 | line = fp.readline() 26 | while line and 'Benchmark Batch' not in line: 27 | if 'TFLOPS' in line: 28 | seq_data.append(line) 29 | 30 | line = fp.readline() 31 | 32 | while line: 33 | if 'TFLOPS' in line: 34 | batch_data.append(line) 35 | 36 | line = fp.readline() 37 | 38 | return seq_data, batch_data 39 | 40 | 41 | def draw_line_chart(methods, dims, data, figure_name, y_step, x_label, y_label, title): 42 | fig = plt.figure(figsize=(32, 24), dpi=100) 43 | 44 | dims_str = list(map(str, dims)) 45 | 46 | colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] 47 | linestyles = ['-', '--', '-.', ':'] 48 | 49 | for i in range(len(methods)): 50 | plt.plot(dims_str, data[i], color=colors[i % len(colors)], 51 | linestyle=linestyles[(i // len(colors)) % len(linestyles)], marker='o', markersize=6) 52 | 53 | # plt.xticks(dims) 54 | plt.ylim(bottom=0) 55 | plt.yticks( 56 | range(0, round(np.max(np.max(data, axis=0)) + 0.5) + y_step, y_step)) 57 | plt.tick_params(labelsize=25) 58 | 59 | # plt.hlines(y=100, xmin=dims_str[0], xmax=dims_str[-1], colors='r', linestyles='-.') 60 | plt.grid(True, linestyle='-.') 61 | 62 | plt.xlabel(x_label, fontdict={'size': '30'}) 63 | plt.ylabel(y_label, fontdict={'size': '30'}) 64 | plt.title(title, fontdict={'size': '30'}) 65 | plt.legend(methods, loc='best', prop={'size': '30'}) 66 | 67 | plt.savefig(figure_name, dpi=fig.dpi) 68 | # plt.show() 69 | 70 | 71 | def analyze_data(data, dim_idx, data_path, data_type, x_label): 72 | methods = [] 73 | dims = [] 74 | for it in data: 75 | iterms = it.split(' ') 76 | 77 | method = iterms[0] 78 | if method not in methods: 79 | methods.append(method) 80 | 81 | params = iterms[1].split('-') 82 | dim = int(params[dim_idx]) 83 | if dim not in dims: 84 | dims.append(dim) 85 | 86 | dims.sort() 87 | 88 | throughputs = np.zeros((len(methods), len(dims)), np.float64) 89 | bandwidths = np.zeros((len(methods), len(dims)), np.float64) 90 | for it in data: 91 | iterms = it.split(' ') 92 | method = iterms[0] 93 | params = iterms[1].split('-') 94 | dim = int(params[dim_idx]) 95 | 96 | throughputs[methods.index( 97 | method)][dims.index(dim)] = float(iterms[7]) 98 | 99 | bandwidths[methods.index( 100 | method)][dims.index(dim)] = float(iterms[10]) 101 | 102 | draw_line_chart(methods, dims, throughputs, data_path + data_type + 103 | '_throughput.png', 20, x_label, 'Throughput / TFLOPS', 'Decoding Attention Throughput') 104 | 105 | draw_line_chart(methods, dims, bandwidths, data_path + data_type + 106 | '_bandwidth.png', 100, x_label, 'Bandwidth / GB/s', 'Decoding Attention Bandwidth') 107 | 108 | 109 | def main(): 110 | usage = "python3 performance.py -f log/benchmark_decoding_attn.log" 111 | parser = optparse.OptionParser(usage) 112 | 
parser.add_option('-f', '--file', dest='file', 113 | type='string', help='file name', default='log/benchmark_decoding_attn.log') 114 | 115 | options, args = parser.parse_args() 116 | file = options.file 117 | 118 | seq_data, batch_data = process_log(file) 119 | data_path = os.path.dirname(file) 120 | analyze_data(seq_data, 1, data_path, '/seq', 'Seq Len') 121 | analyze_data(batch_data, 0, data_path, '/batch', 'Batch Size') 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decoding Attention 2 | Decoding Attention is specially optimized for Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA) and Multi-Head Latent Attention (MLA) using CUDA core for the decoding stage of LLM inference. It mainly refers to OpenPPL and Flash Attention, which can solve the problem of low tensor core utilization of Flash Attention in the decoding stage of LLM inference and support more types of attention and kv cache quantization optimization. The calculation expression is as follows, where the precision of tensor Q, K, V and O is FP16 or BF16. In some LLM inference decoding scenarios, the performance of Decoding Attention is better than Flash Decoding (Flash Attention) and FlashInfer. In addition, Decoding Attention also supports variable length and ALiBi inference scenarios. 3 | ``` 4 | O = Softmax(Q * K^T) * V 5 | ``` 6 | 7 | ![dmha](./media/images/dmha.png) 8 | 9 | # Support 10 | - Variable Length: Variable kv length inference 11 | - ALiBi: Attention with linear biases inference 12 | 13 | # Environment 14 | - OS: Linux 15 | - Cmake Version: >= 3.16 16 | - GCC Version: >= 5.0 17 | - CUDA Version: >= 11.4 18 | - Others: gflags, ccache, pytest 19 | ``` 20 | sudo apt-get install libgflags-dev ccache 21 | pip install pytest 22 | ``` 23 | 24 | # Clone 25 | ``` 26 | git clone https://github.com/Bruce-Lee-LY/decoding_attention.git 27 | ``` 28 | 29 | # CPP API 30 | ## Build 31 | ### NVIDIA A100 32 | ``` 33 | cd decoding_attention 34 | ./build_cpp.sh -a 80 -t Release -b OFF 35 | ./build_cpp.sh -a 80 -t Debug -b OFF 36 | ``` 37 | 38 | ### RTX3080Ti / RTX3090 / RTX A6000 39 | ``` 40 | cd decoding_attention 41 | ./build_cpp.sh -a 86 -t Release -b OFF 42 | ./build_cpp.sh -a 86 -t Debug -b OFF 43 | ``` 44 | 45 | ### L20 / L40S 46 | ``` 47 | cd decoding_attention 48 | ./build_cpp.sh -a 89 -t Release -b OFF 49 | ./build_cpp.sh -a 89 -t Debug -b OFF 50 | ``` 51 | 52 | ### H20 / H800 53 | ``` 54 | cd decoding_attention 55 | ./build_cpp.sh -a 90 -t Release -b OFF 56 | ./build_cpp.sh -a 90 -t Debug -b OFF 57 | ``` 58 | 59 | ## Test 60 | ``` 61 | ./run_cpp.sh 62 | ``` 63 | 64 | ## Benchmark 65 | ``` 66 | ./run_cpp.sh 67 | ``` 68 | 69 | ## Performance 70 | Process the cpp result in the log and plot it as a line chart. 71 | 72 | ``` 73 | cd tools/performance/cpp 74 | ./performance.sh 75 | ``` 76 | 77 | # Python API 78 | ## Install 79 | ``` 80 | cd decoding_attention 81 | ./install_python.sh 82 | ``` 83 | 84 | ## Test 85 | ``` 86 | ./run_python.sh 87 | ``` 88 | 89 | ## Benchmark 90 | ``` 91 | ./run_python.sh 92 | ``` 93 | 94 | ## Performance 95 | Process the python result in the log and plot it as a line chart. 
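For reference, a minimal forward call of the Python API looks roughly like the sketch below. The sizes are illustrative only; the packed `[total_tokens, head, dim]` layout and the `[batch + 1]` int32 `cu_seq_k` tensor follow `decoding_attn/decoding_attn_interface.py` and `tests/test_decoding_attn.py`.

```
import torch
from torch.nn import functional as F
from decoding_attn import decoding_attn_fwd

batch, seq_k, head_q, head_k, dim = 2, 128, 32, 8, 128  # illustrative sizes, seq_q = 1 per request
q = torch.randn(batch, head_q, dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch * seq_k, head_k, dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch * seq_k, head_k, dim, device="cuda", dtype=torch.float16)
seq_lens = torch.full((batch,), seq_k, dtype=torch.int32, device="cuda")
cu_seq_k = F.pad(seq_lens.cumsum(dim=0, dtype=torch.int32), (1, 0))  # [batch + 1]
o = decoding_attn_fwd(q, k, v, cu_seq_k, seq_k, dim, False)  # [batch, head_q, dim] since seq_q = 1
```

The commands that run the plotting script on the python benchmark log are: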
96 | 97 | ``` 98 | cd tools/performance/python 99 | ./performance.sh 100 | ``` 101 | 102 | ### MHA Running on RTX3090 103 | - CUDA Version: 12.1 104 | - Head Num: 32 105 | - Head Dim: 128 106 | - Data Type: FP16 107 | 108 | #### Seq Len 109 | The performance of Decoding Attention is better when the sequence length is below 1536, while the performance of Flash Decoding (Flash Attention) and FlashInfer is better when the sequence length is above 1536. 110 | - Batch Size: 1 111 | - Seq Q: 1 112 | - Seq K: Seq Len 113 | 114 | ![seq_throughput](./performance/RTX3090/seq_throughput.png) 115 | 116 | #### Batch Size 117 | Regardless of bacth size, Decoding Attention has better performance than Flash Decoding (Flash Attention) and FlashInfer. 118 | - Batch Size: Batch Size 119 | - Seq Q: 1 120 | - Seq K: 128 121 | 122 | ![batch_throughput](./performance/RTX3090/batch_throughput.png) 123 | 124 | ### MLA Running on H20 125 | - CUDA Version: 12.4 126 | - Head Num: 128 127 | - Head Num K: 1 128 | - Head Dim: 576 129 | - Head Dim V: 512 130 | - Data Type: FP16 131 | 132 | #### Seq Len 133 | - Batch Size: 1 134 | - Seq Q: 1 135 | - Seq K: Seq Len 136 | 137 | ![seq_throughput](./performance/H20/seq_throughput.png) 138 | ![seq_bandwidth](./performance/H20/seq_bandwidth.png) 139 | 140 | #### Batch Size 141 | - Batch Size: Batch Size 142 | - Seq Q: 1 143 | - Seq K: 4096 144 | 145 | ![batch_throughput](./performance/H20/batch_throughput.png) 146 | ![batch_bandwidth](./performance/H20/batch_bandwidth.png) 147 | 148 | # Reference 149 | - [ppl.llm.kernel.cuda](https://github.com/OpenPPL/ppl.llm.kernel.cuda) 150 | - [flash-attention](https://github.com/Dao-AILab/flash-attention): v2.6.3 151 | - [flashinfer](https://github.com/Bruce-Lee-LY/flashinfer): v0.1.6 152 | - [FlashMLA](https://github.com/deepseek-ai/FlashMLA) 153 | 154 | # TODO 155 | - Kernel Optimization 156 | - KV Cache Quantization: FP8、Int8、Int4 157 | -------------------------------------------------------------------------------- /csrc/common/common.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
2 | // Author: Bruce-Lee-LY 3 | // Date: 20:42:28 on Sun, Feb 12, 2023 4 | // 5 | // Description: common macro 6 | 7 | #pragma once 8 | 9 | #include "cuda_bf16.h" 10 | #include "cuda_fp16.h" 11 | #include "logging.h" 12 | #include "util.h" 13 | 14 | #define DA_LIKELY(x) __builtin_expect(!!(x), 1) 15 | #define DA_UNLIKELY(x) __builtin_expect(!!(x), 0) 16 | 17 | #define DA_CHECK(x) \ 18 | do { \ 19 | if (DA_UNLIKELY(!(x))) { \ 20 | DLOG("Check failed: %s", #x); \ 21 | exit(EXIT_FAILURE); \ 22 | } \ 23 | } while (0) 24 | 25 | #define DA_CHECK_EQ(x, y) DA_CHECK((x) == (y)) 26 | #define DA_CHECK_NE(x, y) DA_CHECK((x) != (y)) 27 | #define DA_CHECK_LE(x, y) DA_CHECK((x) <= (y)) 28 | #define DA_CHECK_LT(x, y) DA_CHECK((x) < (y)) 29 | #define DA_CHECK_GE(x, y) DA_CHECK((x) >= (y)) 30 | #define DA_CHECK_GT(x, y) DA_CHECK((x) > (y)) 31 | 32 | #define DA_DISALLOW_COPY_AND_ASSIGN(TypeName) \ 33 | TypeName(const TypeName &) = delete; \ 34 | void operator=(const TypeName &) = delete 35 | 36 | #define DA_CHECK_CUDART_ERROR(_expr_) \ 37 | do { \ 38 | cudaError_t _ret_ = _expr_; \ 39 | if (DA_UNLIKELY(_ret_ != cudaSuccess)) { \ 40 | const char *_err_str_ = cudaGetErrorName(_ret_); \ 41 | int _rt_version_ = 0; \ 42 | cudaRuntimeGetVersion(&_rt_version_); \ 43 | int _driver_version_ = 0; \ 44 | cudaDriverGetVersion(&_driver_version_); \ 45 | DLOG("CUDA Runtime API error = %04d \"%s\", runtime version: %d, driver version: %d", \ 46 | static_cast(_ret_), _err_str_, _rt_version_, _driver_version_); \ 47 | exit(EXIT_FAILURE); \ 48 | } \ 49 | } while (0) 50 | 51 | #define DA_BOOL_SWITCH(_cond_, _var_, ...) \ 52 | [&] { \ 53 | if (_cond_) { \ 54 | constexpr static bool _var_ = true; \ 55 | return __VA_ARGS__(); \ 56 | } else { \ 57 | constexpr static bool _var_ = false; \ 58 | return __VA_ARGS__(); \ 59 | } \ 60 | }() 61 | 62 | #define DA_FP16_SWITCH(_cond_, ...) \ 63 | [&] { \ 64 | if (_cond_) { \ 65 | using elem_type = half; \ 66 | return __VA_ARGS__(); \ 67 | } else { \ 68 | using elem_type = __nv_bfloat16; \ 69 | return __VA_ARGS__(); \ 70 | } \ 71 | }() 72 | 73 | #define DA_HEADDIM_SWITCH(d, d_v, ...) \ 74 | [&] { \ 75 | if (d == 64 && d_v == 64) { \ 76 | constexpr static size_t head_dim = 64; \ 77 | constexpr static size_t head_dim_v = 64; \ 78 | return __VA_ARGS__(); \ 79 | } else if (d == 96 && d_v == 96) { \ 80 | constexpr static size_t head_dim = 96; \ 81 | constexpr static size_t head_dim_v = 96; \ 82 | return __VA_ARGS__(); \ 83 | } else if (d == 128 && d_v == 128) { \ 84 | constexpr static size_t head_dim = 128; \ 85 | constexpr static size_t head_dim_v = 128; \ 86 | return __VA_ARGS__(); \ 87 | } else if (d == 256 && d_v == 256) { \ 88 | constexpr static size_t head_dim = 256; \ 89 | constexpr static size_t head_dim_v = 256; \ 90 | return __VA_ARGS__(); \ 91 | } else if (d == 576 && d_v == 512) { \ 92 | constexpr static size_t head_dim = 576; \ 93 | constexpr static size_t head_dim_v = 512; \ 94 | return __VA_ARGS__(); \ 95 | } \ 96 | }() 97 | -------------------------------------------------------------------------------- /tests/test_decoding_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 
2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: test decoding attn python api 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | import math 16 | import pytest 17 | import torch 18 | from torch.nn import functional as F 19 | 20 | from decoding_attn import decoding_attn_fwd 21 | 22 | 23 | def get_cu_seq(seqs: torch.tensor) -> torch.tensor: 24 | """ 25 | Arguments: 26 | seqs: [batch], dtype torch.int32, sequence length of each batch. 27 | Return: 28 | cu_seq: [batch + 1], dtype torch.int32. The cumulative sequence lengths of the sequences 29 | in the batch. 30 | """ 31 | return F.pad(seqs.cumsum(dim=0, dtype=torch.int32), (1, 0)) 32 | 33 | 34 | # sopprt MHA/MQA/GQA/MLA 35 | def torch_attn(q: torch.tensor, k: torch.tensor, v: torch.tensor = None, dim_v: int = None) -> torch.tensor: 36 | """ 37 | Arguments: 38 | q: [batch, seq_q, head_q, dim] 39 | k: [batch, seq_k, head_k, dim], kv_c_and_k_pe_cache for MLA. 40 | v: [batch, seq_k, head_k, dim], None for MLA. 41 | dim_v: dim of v, not None for MLA. 42 | Return: 43 | o: [batch, seq_q, head_q, dim_v] 44 | """ 45 | v = v if v is not None else k[..., :dim_v] 46 | head_q = q.shape[2] 47 | dim = q.shape[3] 48 | head_k = k.shape[2] 49 | head_ratio = head_q // head_k 50 | qt = q.transpose(1, 2) 51 | kt = k.transpose(1, 2).repeat_interleave(head_ratio, dim=1) 52 | vt = v.transpose(1, 2).repeat_interleave(head_ratio, dim=1) 53 | s = torch.matmul(qt, kt.transpose(-2, -1)) / math.sqrt(dim) 54 | p = F.softmax(s, dim=-1) 55 | o = torch.matmul(p, vt) 56 | 57 | return o.transpose(1, 2) 58 | 59 | 60 | # test MHA/GQA 61 | @pytest.mark.parametrize("head_q", [32, 64]) 62 | @pytest.mark.parametrize("head_k", [8, 32]) 63 | @pytest.mark.parametrize("dim", [64, 96, 128, 256]) 64 | @pytest.mark.parametrize("batch", [1, 2, 16]) 65 | @pytest.mark.parametrize("seq_k", [1, 128, 512]) 66 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) 67 | def test_decoding_attn_fwd(head_q, head_k, dim, batch, seq_k, dtype): 68 | seq_q = 1 69 | total_q = batch * seq_q 70 | total_k = batch * seq_k 71 | is_alibi = False 72 | 73 | q = torch.randn(total_q, head_q, dim, 74 | device=torch.device('cuda'), dtype=dtype) 75 | k = torch.randn(total_k, head_k, dim, 76 | device=torch.device('cuda'), dtype=dtype) 77 | v = torch.randn(total_k, head_k, dim, 78 | device=torch.device('cuda'), dtype=dtype) 79 | 80 | cu_seq_k = get_cu_seq(torch.full( 81 | (batch,), seq_k, dtype=torch.int32, device=torch.device('cuda'))) 82 | 83 | q4 = q.reshape(batch, seq_q, head_q, dim) 84 | k4 = k.reshape(batch, seq_k, head_k, dim) 85 | v4 = v.reshape(batch, seq_k, head_k, dim) 86 | 87 | attn = torch_attn(q4, k4, v=v4) 88 | output = attn.reshape(total_q, head_q, dim) 89 | print(f"Attn-CPU output: {output}") 90 | 91 | da_output = decoding_attn_fwd(q, k, v, cu_seq_k, seq_k, dim, is_alibi) 92 | print(f"Decoding-Attention output: {da_output}") 93 | 94 | assert (output - da_output).abs().mean().item() <= 5e-3 95 | 96 | 97 | @pytest.mark.parametrize("head_q", [16, 32, 64, 128]) 98 | @pytest.mark.parametrize("head_k", [1]) 99 | @pytest.mark.parametrize("dim", [576]) 100 | @pytest.mark.parametrize("dim_v", [512]) 101 | @pytest.mark.parametrize("batch", [1, 2, 16]) 102 | @pytest.mark.parametrize("seq_k", [1, 128, 512]) 103 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) 104 
| def test_decoding_mla_fwd(head_q, head_k, dim, dim_v, batch, seq_k, dtype): 105 | seq_q = 1 106 | total_q = batch * seq_q 107 | total_k = batch * seq_k 108 | is_alibi = False 109 | 110 | q = torch.randn(total_q, head_q, dim, 111 | device=torch.device('cuda'), dtype=dtype) 112 | k = torch.randn(total_k, head_k, dim, 113 | device=torch.device('cuda'), dtype=dtype) 114 | 115 | cu_seq_k = get_cu_seq(torch.full( 116 | (batch,), seq_k, dtype=torch.int32, device=torch.device('cuda'))) 117 | 118 | q4 = q.reshape(batch, seq_q, head_q, dim) 119 | k4 = k.reshape(batch, seq_k, head_k, dim) 120 | 121 | attn = torch_attn(q4, k4, dim_v=dim_v) 122 | output = attn.reshape(total_q, head_q, dim_v) 123 | print(f"MLA-CPU output: {output}") 124 | 125 | da_output = decoding_attn_fwd(q, k, None, cu_seq_k, seq_k, dim_v, is_alibi) 126 | print(f"Decoding-Attention output: {da_output}") 127 | 128 | assert (output - da_output).abs().mean().item() <= 5e-3 129 | 130 | 131 | def main(): 132 | test_decoding_attn_fwd(32, 32, 128, 2, 128, torch.float16) 133 | test_decoding_mla_fwd(128, 1, 576, 512, 2, 128, torch.float16) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: setup decoding attention 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | import os 11 | import sys 12 | from pathlib import Path 13 | from setuptools import setup, find_packages 14 | 15 | import torch 16 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 17 | 18 | 19 | with open("README.md", "r", encoding="utf-8") as fh: 20 | long_description = fh.read() 21 | 22 | 23 | def append_nvcc_threads(nvcc_extra_args): 24 | nvcc_threads = os.getenv("NVCC_THREADS") or "4" 25 | return nvcc_extra_args + ["--threads", nvcc_threads] 26 | 27 | 28 | class NinjaBuildExtension(BuildExtension): 29 | def __init__(self, *args, **kwargs) -> None: 30 | # do not override env MAX_JOBS if already exists 31 | if not os.environ.get("MAX_JOBS"): 32 | import psutil 33 | 34 | # calculate the maximum allowed NUM_JOBS based on cores 35 | max_num_jobs_cores = max(1, os.cpu_count() // 2) 36 | 37 | # calculate the maximum allowed NUM_JOBS based on free memory 38 | free_memory_gb = psutil.virtual_memory().available / \ 39 | (1024 ** 3) # free memory in GB 40 | # each JOB peak memory cost is ~8-9GB when threads = 4 41 | max_num_jobs_memory = int(free_memory_gb / 9) 42 | 43 | # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation 44 | max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) 45 | os.environ["MAX_JOBS"] = str(max_jobs) 46 | 47 | super().__init__(*args, **kwargs) 48 | 49 | 50 | print("python version: {}".format(sys.version)) 51 | print("torch version: {}".format(torch.__version__)) 52 | 53 | # ninja build does not work unless include_dirs are abs path 54 | this_dir = os.path.dirname(os.path.abspath(__file__)) 55 | 56 | # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h 57 | # See https://github.com/pytorch/pytorch/pull/70650 58 | generator_flag = [] 59 | torch_dir = torch.__path__[0] 60 | if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): 61 | generator_flag = ["-DOLD_GENERATOR_PATH"] 62 | 63 | cc_flag 
= [] 64 | cc_flag.append("-gencode") 65 | cc_flag.append("arch=compute_80,code=sm_80") 66 | cc_flag.append("-gencode") 67 | cc_flag.append("arch=compute_86,code=sm_86") 68 | cc_flag.append("-gencode") 69 | cc_flag.append("arch=compute_89,code=sm_89") 70 | cc_flag.append("-gencode") 71 | cc_flag.append("arch=compute_90,code=sm_90") 72 | 73 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 74 | # torch._C._GLIBCXX_USE_CXX11_ABI 75 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 76 | torch._C._GLIBCXX_USE_CXX11_ABI = False 77 | 78 | ext_modules = [] 79 | ext_modules.append( 80 | CUDAExtension( 81 | name="decoding_attn_cuda", 82 | sources=[ 83 | "csrc/torch/decoding_torch.cpp", 84 | "csrc/ops/decoding_attn.cpp", 85 | "csrc/kernel/decoding_attn/decoding_fwd_hd64_hdv64.cu", 86 | "csrc/kernel/decoding_attn/decoding_fwd_hd96_hdv96.cu", 87 | "csrc/kernel/decoding_attn/decoding_fwd_hd128_hdv128.cu", 88 | "csrc/kernel/decoding_attn/decoding_fwd_hd256_hdv256.cu", 89 | "csrc/kernel/decoding_attn/decoding_fwd_hd576_hdv512.cu", 90 | ], 91 | extra_compile_args={ 92 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 93 | "nvcc": append_nvcc_threads( 94 | [ 95 | "-O3", 96 | "-std=c++17", 97 | "-U__CUDA_NO_HALF_OPERATORS__", 98 | "-U__CUDA_NO_HALF_CONVERSIONS__", 99 | "-U__CUDA_NO_HALF2_OPERATORS__", 100 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 101 | "--expt-relaxed-constexpr", 102 | "--expt-extended-lambda", 103 | "--use_fast_math", 104 | # "--ptxas-options=-v", 105 | # "--ptxas-options=-O2", 106 | # "-lineinfo", 107 | ] 108 | + generator_flag 109 | + cc_flag 110 | ), 111 | }, 112 | include_dirs=[ 113 | Path(this_dir) / "csrc" / "common", 114 | Path(this_dir) / "csrc" / "kernel", 115 | Path(this_dir) / "csrc" / "ops", 116 | ], 117 | ) 118 | ) 119 | 120 | 121 | setup( 122 | name="decoding_attn", 123 | version="0.1.0", 124 | packages=find_packages( 125 | exclude=( 126 | "csrc", 127 | "decodng_attn", 128 | ) 129 | ), 130 | author="Bruce-Lee-LY", 131 | description="Decoding Attention", 132 | long_description=long_description, 133 | long_description_content_type="text/markdown", 134 | url="https://github.com/Bruce-Lee-LY/decoding_attention", 135 | classifiers=[ 136 | "Programming Language :: Python :: 3", 137 | "License :: OSI Approved :: BSD License", 138 | "Operating System :: Unix", 139 | ], 140 | ext_modules=ext_modules, 141 | cmdclass={"build_ext": NinjaBuildExtension}, 142 | python_requires=">=3.8", 143 | install_requires=[ 144 | "torch", 145 | ], 146 | setup_requires=[ 147 | "psutil", 148 | "ninja", 149 | ], 150 | ) 151 | -------------------------------------------------------------------------------- /run_cpp.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 
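tests/test_decoding_attn.py above is the most direct reference for the Python API; the minimal sketch below restates the same call pattern on its own, assuming the package is installed and a CUDA device is available (the shapes are illustrative, taken from the tested GQA configurations):

import torch
from torch.nn import functional as F
from decoding_attn import decoding_attn_fwd

# Illustrative GQA shapes (head_q must be a multiple of head_k).
batch, seq_k, head_q, head_k, dim = 2, 128, 32, 8, 128
dtype, dev = torch.float16, torch.device("cuda")

# Cumulative kv sequence lengths, e.g. seqs = [128, 128] -> cu_seq_k = [0, 128, 256].
seqs = torch.full((batch,), seq_k, dtype=torch.int32, device=dev)
cu_seq_k = F.pad(seqs.cumsum(dim=0, dtype=torch.int32), (1, 0))

# Packed layouts: q is [total_q, head_q, dim] with seq_q == 1, k/v are [total_k, head_k, dim].
q = torch.randn(batch, head_q, dim, device=dev, dtype=dtype)
k = torch.randn(batch * seq_k, head_k, dim, device=dev, dtype=dtype)
v = torch.randn(batch * seq_k, head_k, dim, device=dev, dtype=dtype)

# Positional arguments follow the tests: (q, k, v, cu_seq_k, max_seq_k, dim_v, is_alibi).
o = decoding_attn_fwd(q, k, v, cu_seq_k, seq_k, dim, False)  # [total_q, head_q, dim_v]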
2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: run cpp script 6 | 7 | #!/bin/bash 8 | 9 | set -euo pipefail 10 | 11 | WORK_PATH=$(cd $(dirname $0) && pwd) && cd $WORK_PATH 12 | 13 | export CUDA_VISIBLE_DEVICES=0 14 | # export CUDA_LAUNCH_BLOCKING=1 15 | 16 | rm -rf log ncu && mkdir -p log ncu 17 | 18 | # $1: b, $2: sq, $3: sk, $4: hq, $5: hk, $6: d, $7: dv, $8: is_causal, $9: log_path 19 | evaluate_da() { 20 | echo "Evaluating ${1} * ${2} * ${3} * ${4} * ${5} * ${6} * ${7} * ${8} * ${9}" 21 | $WORK_PATH/output/bin/benchmark_decoding_attn -b=$1 -sq=$2 -sk=$3 -hq=$4 -hk=$5 -d=$6 -dv=$7 -is_causal=$8 -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=10 -sleep_duration=100 -enable_check=false > log/${9}/da_${1}_${2}_${3}_${4}_${5}_${6}_${7}.log 2>&1 22 | sleep 3 23 | } 24 | 25 | # $1: b, $2: sq, $3: sk, $4: hq, $5: hk, $6: d, $7: dv, $8: is_causal, $9: log_path 26 | ncu_da() { 27 | echo "NCU ${1} * ${2} * ${3} * ${4} * ${5} * ${6} * ${7} * ${8} * ${9}" 28 | sudo ncu --set full --target-processes all --force-overwrite -o ncu/${8}/da_${1}_${2}_${3}_${4}_${5}_${6}_${7} $WORK_PATH/output/bin/benchmark_decoding_attn -b=$1 -sq=$2 -sk=$3 -hq=$4 -hk=$5 -d=$6 -dv=$7 -is_causal=$8 -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/${9}/ncu_da_${1}_${2}_${3}_${4}_${5}_${6}_${7}.log 2>&1 29 | sleep 3 30 | } 31 | 32 | # $1: hq, $2: hk, $3: d, $4: dv 33 | benchmark_da_decoding_seq() { 34 | echo "Evaluating Decoding Seq" 35 | b=1 36 | sq=1 37 | seq_k=(1 8 16 32 64 128 256 512 1024 2048 3072 4096 5120 6144 7168 8192) 38 | hq=$1 39 | hk=$2 40 | d=$3 41 | dv=$4 42 | ic=false 43 | lp=decoding_seq 44 | 45 | mkdir -p log/$lp ncu/$lp 46 | 47 | for sk in ${seq_k[@]}; 48 | do 49 | evaluate_da $b $sq $sk $hq $hk $d $dv $ic $lp 50 | # ncu_da $b $sq $sk $hq $hk $d $dv $ic $lp 51 | done 52 | } 53 | 54 | # $1: hq, $2: hk, $3: d, $4: dv 55 | benchmark_da_decoding_batch() { 56 | echo "Evaluating Decoding Batch" 57 | batch=(1 2 4 8 16 32 64 128 256 512 768 1024 1536 2048) 58 | sq=1 59 | sk=128 60 | hq=$1 61 | hk=$2 62 | d=$3 63 | dv=$4 64 | ic=false 65 | lp=decoding_batch 66 | 67 | mkdir -p log/$lp ncu/$lp 68 | 69 | for b in ${batch[@]}; 70 | do 71 | evaluate_da $b $sq $sk $hq $hk $d $dv $ic $lp 72 | # ncu_da $b $sq $sk $hq $hk $d $dv $ic $lp 73 | done 74 | } 75 | 76 | benchmark_da() { 77 | benchmark_da_decoding_seq 32 32 128 128 78 | benchmark_da_decoding_batch 32 32 128 128 79 | } 80 | 81 | benchmark_mla() { 82 | benchmark_da_decoding_seq 128 1 576 512 83 | benchmark_da_decoding_batch 128 1 576 512 84 | } 85 | 86 | # FP16 87 | # nohup $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=false -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=10 -sleep_duration=100 -enable_check=true > log/da_2_1_256_32_32_128_128.log 2>&1 & 88 | # sudo ncu --set full --target-processes all --force-overwrite -o ncu/da_2_1_256_32_32_128_128 $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=false -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/ncu_da_2_1_256_32_32_128_128.log 2>&1 89 | 90 | # BF16 91 | # nohup $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=false -is_alibi=false -is_bf16=true -warmup_iterations=1 -profiling_iterations=10 
-sleep_duration=100 -enable_check=true > log/da_2_1_256_32_32_128_128.log 2>&1 & 92 | # sudo ncu --set full --target-processes all --force-overwrite -o ncu/da_2_1_256_32_32_128_128 $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=false -is_alibi=true -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/ncu_da_2_1_256_32_32_128_128.log 2>&1 93 | 94 | # GQA/MQA 95 | # nohup $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=64 -hk=8 -d=128 -dv=128 -is_causal=false -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=10 -sleep_duration=100 -enable_check=true > log/da_2_1_256_64_8_128_128.log 2>&1 & 96 | # sudo ncu --set full --target-processes all --force-overwrite -o ncu/da_2_1_256_64_8_128_128 $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=256 -sk=256 -hq=64 -hk=8 -d=128 -dv=128 -is_causal=true -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/ncu_da_2_1_256_64_8_128_128.log 2>&1 97 | 98 | # Alibi 99 | # nohup $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=false -is_alibi=true -is_bf16=false -warmup_iterations=1 -profiling_iterations=10 -sleep_duration=100 -enable_check=true > log/da_2_1_256_32_32_128_128.log 2>&1 & 100 | # sudo ncu --set full --target-processes all --force-overwrite -o ncu/da_2_1_256_32_32_128_128 $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=256 -sk=256 -hq=32 -hk=32 -d=128 -dv=128 -is_causal=true -is_alibi=true -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/ncu_da_2_1_256_32_32_128_128.log 2>&1 101 | 102 | # MLA 103 | nohup $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=128 -hk=1 -d=576 -dv=512 -is_causal=false -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=10 -sleep_duration=100 -enable_check=true > log/da_2_1_256_128_1_576_512.log 2>&1 & 104 | # sudo ncu --set full --target-processes all --force-overwrite -o ncu/da_2_1_256_128_1_576_512 $WORK_PATH/output/bin/benchmark_decoding_attn -b=2 -sq=1 -sk=256 -hq=128 -hk=1 -d=576 -dv=512 -is_causal=false -is_alibi=false -is_bf16=false -warmup_iterations=1 -profiling_iterations=1 -sleep_duration=100 -enable_check=false > log/ncu_da_2_1_256_128_1_576_512.log 2>&1 105 | 106 | # benchmark_da 107 | # benchmark_mla 108 | -------------------------------------------------------------------------------- /tools/performance/cpp/performance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 
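run_cpp.sh above drives the compiled benchmark through positional shell functions; a rough equivalent in Python, assuming the binary was built to output/bin/benchmark_decoding_attn (the flag names and values are copied from the script; the wrapper itself is hypothetical):

import subprocess

# Same flags evaluate_da() passes; this mirrors the FP16 example near the end of run_cpp.sh.
cmd = [
    "./output/bin/benchmark_decoding_attn",
    "-b=2", "-sq=1", "-sk=256", "-hq=32", "-hk=32", "-d=128", "-dv=128",
    "-is_causal=false", "-is_alibi=false", "-is_bf16=false",
    "-warmup_iterations=1", "-profiling_iterations=10",
    "-sleep_duration=100", "-enable_check=true",
]
with open("log/da_2_1_256_32_32_128_128.log", "w") as log:
    subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=True)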
2 | # Author: Bruce-Lee-LY 3 | # Date: 20:42:28 on Sun, Feb 12, 2023 4 | # 5 | # Description: performance cpp line chart 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | import os 16 | import optparse 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def process_type(path): 22 | dim_idx = 0 23 | data_type = '' 24 | x_label = '' 25 | 26 | if 'decoding_seq' in path: 27 | dim_idx = 3 28 | data_type = 'decoding_seq' 29 | x_label = 'Seq Len' 30 | elif 'decoding_batch' in path: 31 | dim_idx = 1 32 | data_type = 'decoding_batch' 33 | x_label = 'Batch Size' 34 | else: 35 | raise TypeError('Unsupported data type') 36 | 37 | return dim_idx, data_type, x_label 38 | 39 | 40 | def get_methods(log_file): 41 | methods = [] 42 | 43 | with open(log_file) as fp: 44 | line = fp.readline() 45 | while line: 46 | if 'exit' in line and 'Naive' not in line: 47 | iterms = line.split(' ') 48 | methods.append(iterms[6]) 49 | 50 | line = fp.readline() 51 | 52 | return methods 53 | 54 | 55 | def get_dims(log_files, dim_idx): 56 | dims = [] 57 | 58 | for log_file in log_files: 59 | dims.append(int((log_file.split('.')[0]).split('_')[dim_idx])) 60 | 61 | dims.sort() 62 | 63 | return dims 64 | 65 | 66 | def read_data(methods, dims, data_path, log_files, dim_idx): 67 | throughputs = np.zeros((len(methods), len(dims)), np.float64) 68 | throughputs_performance = np.zeros((len(methods), len(dims)), np.float64) 69 | bandwidths = np.zeros((len(methods), len(dims)), np.float64) 70 | bandwidths_performance = np.zeros((len(methods), len(dims)), np.float64) 71 | 72 | for log_file in log_files: 73 | dim = int((log_file.split('.')[0]).split('_')[dim_idx]) 74 | with open(data_path + log_file) as fp: 75 | line = fp.readline() 76 | while line: 77 | if 'exit' in line and 'Naive' not in line: 78 | iterms = line.split(' ') 79 | method = iterms[6] 80 | throughputs[methods.index( 81 | method)][dims.index(dim)] = float(iterms[14]) 82 | throughputs_performance[methods.index(method)][dims.index(dim)] = float( 83 | iterms[16].replace('(', '').replace(')', '').replace('%', '').replace(',', '')) 84 | bandwidths[methods.index( 85 | method)][dims.index(dim)] = float(iterms[18]) 86 | bandwidths_performance[methods.index(method)][dims.index(dim)] = float( 87 | iterms[20].replace('(', '').replace(')', '').replace('%', '').replace(',', '')) 88 | line = fp.readline() 89 | 90 | return throughputs, throughputs_performance, bandwidths, bandwidths_performance 91 | 92 | 93 | def draw_line_chart(methods, dims, data, figure_name, y_step, x_label, y_label, title): 94 | fig = plt.figure(figsize=(32, 24), dpi=100) 95 | 96 | dims_str = list(map(str, dims)) 97 | 98 | colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] 99 | linestyles = ['-', '--', '-.', ':'] 100 | 101 | for i in range(len(methods)): 102 | plt.plot(dims_str, data[i], color=colors[i % len(colors)], 103 | linestyle=linestyles[(i // len(colors)) % len(linestyles)], marker='o', markersize=6) 104 | 105 | # plt.xticks(dims) 106 | plt.ylim(bottom=0) 107 | plt.yticks( 108 | range(0, round(np.max(np.max(data, axis=0)) + 0.5) + y_step, y_step)) 109 | plt.tick_params(labelsize=25) 110 | 111 | # plt.hlines(y=100, xmin=dims_str[0], xmax=dims_str[-1], colors='r', linestyles='-.') 112 | plt.grid(True, linestyle='-.') 113 | 114 | plt.xlabel(x_label, fontdict={'size': '30'}) 115 | plt.ylabel(y_label, fontdict={'size': 
'30'}) 116 | plt.title(title, fontdict={'size': '30'}) 117 | plt.legend(methods, loc='best', prop={'size': '30'}) 118 | 119 | plt.savefig(figure_name, dpi=fig.dpi) 120 | # plt.show() 121 | 122 | 123 | def analyze_data(data_path, dim_idx, data_type, x_label): 124 | log_files = [] 125 | for file_name in os.listdir(data_path): 126 | if '.log' not in file_name: 127 | continue 128 | 129 | log_files.append(file_name) 130 | 131 | methods = get_methods(data_path + log_files[0]) 132 | dims = get_dims(log_files, dim_idx) 133 | throughputs, throughputs_performance, bandwidths, bandwidths_performance = read_data( 134 | methods, dims, data_path, log_files, dim_idx) 135 | draw_line_chart(methods, dims, throughputs, data_path + data_type + 136 | '_throughput.png', 1, x_label, 'Throughput / TFLOPS', 'Decoding Attention Throughput') 137 | draw_line_chart(methods, dims, throughputs_performance, data_path + data_type + '_throughput_performance.png', 20, x_label, 138 | 'Performance Compared with Decoding Attention / %', 'Decoding Attention Throughput Performance') 139 | draw_line_chart(methods, dims, bandwidths, data_path + data_type + 140 | '_bandwidth.png', 2, x_label, 'Bandwidth / GB/s', 'Decoding Attention Bandwidth') 141 | draw_line_chart(methods, dims, bandwidths_performance, data_path + data_type + '_bandwidth_performance.png', 20, x_label, 142 | 'Performance Compared with Decoding Attention / %', 'Decoding Attention Bandwidth Performance') 143 | 144 | 145 | def main(): 146 | usage = "python3 performance.py -p log/" 147 | parser = optparse.OptionParser(usage) 148 | parser.add_option('-p', '--path', dest='path', 149 | type='string', help='data path', default='log/') 150 | 151 | options, args = parser.parse_args() 152 | path = options.path 153 | 154 | dim_idx, data_type, x_label = process_type(path) 155 | analyze_data(path, dim_idx, data_type, x_label) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /benchmarks/cpp/benchmark_decoding_attn.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
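The performance.py script above keys its x-axis off the log file names produced by run_cpp.sh (da_{b}_{sq}_{sk}_{hq}_{hk}_{d}_{dv}.log); a short sketch of the same extraction done by process_type()/get_dims(), with an illustrative file name:

# After stripping ".log" and splitting on "_", the fields are
# ['da', b, sq, sk, hq, hk, d, dv]; index 1 is the batch size, index 3 is the kv seq len.
name = "da_1_1_4096_32_32_128_128.log"
fields = name.split(".")[0].split("_")
batch, seq_k = int(fields[1]), int(fields[3])   # dim_idx = 1 (decoding_batch) or 3 (decoding_seq)
assert (batch, seq_k) == (1, 4096)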
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:08:30 on Sun, Aug 27, 2023 4 | // 5 | // Description: benchmark decoding attn using cpp api 6 | 7 | #include "decoding_attn.h" 8 | #include "gflags/gflags.h" 9 | #include "omp.h" 10 | #include "tester.h" 11 | 12 | DEFINE_uint32(b, 2, "batch size"); 13 | DEFINE_uint32(sq, 1, "q seq len"); 14 | DEFINE_uint32(sk, 256, "kv seq len"); 15 | DEFINE_uint32(hq, 32, "q head num"); 16 | DEFINE_uint32(hk, 32, "kv head num"); 17 | DEFINE_uint32(d, 128, "head dim"); 18 | DEFINE_uint32(dv, 128, "head dim v"); 19 | DEFINE_bool(is_causal, false, "causal mask"); 20 | DEFINE_bool(is_alibi, false, "enable alibi"); 21 | DEFINE_bool(is_bf16, false, "data type of q, k, v and o"); 22 | DEFINE_uint32(warmup_iterations, 1, "warmup iteration numbers and average the result"); 23 | DEFINE_uint32(profiling_iterations, 10, "profiling iteration numbers and average the result"); 24 | DEFINE_uint32(sleep_duration, 100, "sleep_milliseconds between profiling"); 25 | DEFINE_bool(enable_check, false, "check the GPU result against the CPU result"); 26 | DEFINE_uint32(cpu_procs, omp_get_num_procs(), "processor num used of CPU"); 27 | DEFINE_uint32(gpu_rank, 0, "the used GPU rank"); 28 | 29 | template 30 | void test_decoding_attn(size_t batch = 2, size_t seq_q = 1, size_t seq_k = 256, size_t head_q = 32, size_t head_k = 32, 31 | size_t dim = 128, size_t dim_v = 128, bool is_causal = false, bool is_alibi = false, 32 | cudaStream_t stream = nullptr, size_t warmup_iterations = 1, size_t profiling_iterations = 10, 33 | size_t sleep_duration = 100, bool enable_check = false) { 34 | Tester tester(batch, seq_q, seq_k, head_q, head_k, dim, dim_v, is_causal, is_alibi, stream, warmup_iterations, 35 | profiling_iterations, sleep_duration, enable_check); 36 | tester.evaluate(decoding_attn, "Decoding-Attention"); 37 | } 38 | 39 | int main(int argc, char *argv[]) { 40 | GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); 41 | 42 | omp_set_num_threads(FLAGS_cpu_procs); 43 | DA_CHECK_CUDART_ERROR(cudaSetDevice(FLAGS_gpu_rank)); 44 | 45 | cudaDeviceProp dev_prop; 46 | DA_CHECK_CUDART_ERROR(cudaGetDeviceProperties(&dev_prop, FLAGS_gpu_rank)); 47 | DLOG("Decoding Attention start with %u CPU processes on the %u-th GPU: %s", FLAGS_cpu_procs, FLAGS_gpu_rank, 48 | dev_prop.name); 49 | 50 | int driver_version = 0; 51 | int runtime_version = 0; 52 | DA_CHECK_CUDART_ERROR(cudaDriverGetVersion(&driver_version)); 53 | DA_CHECK_CUDART_ERROR(cudaRuntimeGetVersion(&runtime_version)); 54 | DLOG("CUDA driver version / runtime version: %d.%d / %d.%d", driver_version / 1000, (driver_version % 100) / 10, 55 | runtime_version / 1000, (runtime_version % 100) / 10); 56 | 57 | DLOG("CUDA capability major/minor version number: %d.%d", dev_prop.major, dev_prop.minor); 58 | DLOG("%d multiprocessors, %d CUDA cores/MP: %d CUDA cores", dev_prop.multiProcessorCount, 59 | convert_SM_to_cores(dev_prop.major, dev_prop.minor), 60 | convert_SM_to_cores(dev_prop.major, dev_prop.minor) * dev_prop.multiProcessorCount); 61 | DLOG("GPU max clock rate: %.0f MHz (%0.2f GHz)", static_cast(dev_prop.clockRate) * 1e-3, 62 | static_cast(dev_prop.clockRate) * 1e-6); 63 | DLOG("Memory clock rate: %.0f MHz (%0.2f GHz)", static_cast(dev_prop.memoryClockRate) * 1e-3, 64 | static_cast(dev_prop.memoryClockRate) * 1e-6); 65 | DLOG("Memory bus width: %d-bit", dev_prop.memoryBusWidth); 66 | DLOG("Total amount of global memory: %.0f MBytes (%zu Bytes)", 67 | static_cast(dev_prop.totalGlobalMem) / 1048576, dev_prop.totalGlobalMem); 68 | DLOG("Total 
amount of constant memory: %.0f KBytes (%zu Bytes)", static_cast(dev_prop.totalConstMem) / 1024, 69 | dev_prop.totalConstMem); 70 | DLOG("Total amount of shared memory per block: %.0f KBytes (%zu Bytes)", 71 | static_cast(dev_prop.sharedMemPerBlock) / 1024, dev_prop.sharedMemPerBlock); 72 | DLOG("Total shared memory per multiprocessor: %.0f KBytes (%zu Bytes)", 73 | static_cast(dev_prop.sharedMemPerMultiprocessor) / 1024, dev_prop.sharedMemPerMultiprocessor); 74 | DLOG("L2 cache size: %.0f KBytes (%d Bytes)", static_cast(dev_prop.l2CacheSize) / 1024, 75 | dev_prop.l2CacheSize); 76 | DLOG("Total number of registers available per block: %d", dev_prop.regsPerBlock); 77 | DLOG("Warp size: %d", dev_prop.warpSize); 78 | DLOG("Max number of threads per multiprocessor: %d", dev_prop.maxThreadsPerMultiProcessor); 79 | DLOG("Max number of threads per block: %d", dev_prop.maxThreadsPerBlock); 80 | DLOG("Max dimension size of a thread block (x,y,z): (%d, %d, %d)", dev_prop.maxThreadsDim[0], 81 | dev_prop.maxThreadsDim[1], dev_prop.maxThreadsDim[2]); 82 | DLOG("Max dimension size of a grid size (x,y,z): (%d, %d, %d)", dev_prop.maxGridSize[0], dev_prop.maxGridSize[1], 83 | dev_prop.maxGridSize[2]); 84 | 85 | cudaStream_t stream = nullptr; 86 | 87 | DLOG( 88 | "DMHA: Softmax (Q (%u x %u x %u x %u) * K^T (%u x %u x %u x %u)) * V (%u x %u x %u x %u) = O (%u x %u x %u x " 89 | "%u)", 90 | FLAGS_b, FLAGS_sq, FLAGS_hq, FLAGS_d, FLAGS_b, FLAGS_sk, FLAGS_hk, FLAGS_d, FLAGS_b, FLAGS_sk, FLAGS_hk, 91 | FLAGS_dv, FLAGS_b, FLAGS_sq, FLAGS_hq, FLAGS_dv); 92 | DLOG( 93 | "Profiling: is causal: %d, stream: %p, is alibi: %d, is bf16: %d, warmup iterations: %u, profiling iterations: " 94 | "%u, sleep duration: %u ms, enable check: %d", 95 | FLAGS_is_causal, stream, FLAGS_is_alibi, FLAGS_is_bf16, FLAGS_warmup_iterations, FLAGS_profiling_iterations, 96 | FLAGS_sleep_duration, FLAGS_enable_check); 97 | 98 | if (FLAGS_is_bf16) { 99 | test_decoding_attn<__nv_bfloat16>(FLAGS_b, FLAGS_sq, FLAGS_sk, FLAGS_hq, FLAGS_hk, FLAGS_d, FLAGS_dv, 100 | FLAGS_is_causal, FLAGS_is_alibi, stream, FLAGS_warmup_iterations, 101 | FLAGS_profiling_iterations, FLAGS_sleep_duration, FLAGS_enable_check); 102 | } else { 103 | test_decoding_attn(FLAGS_b, FLAGS_sq, FLAGS_sk, FLAGS_hq, FLAGS_hk, FLAGS_d, FLAGS_dv, FLAGS_is_causal, 104 | FLAGS_is_alibi, stream, FLAGS_warmup_iterations, FLAGS_profiling_iterations, 105 | FLAGS_sleep_duration, FLAGS_enable_check); 106 | } 107 | 108 | GFLAGS_NAMESPACE::ShutDownCommandLineFlags(); 109 | 110 | DLOG("Done"); 111 | 112 | return 0; 113 | } 114 | -------------------------------------------------------------------------------- /csrc/kernel/decoding_attn/decoding_fwd_kernel.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
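The DMHA log line in benchmark_decoding_attn.cpp above spells out what every implementation under test computes: O = Softmax(Q K^T / sqrt(d)) V with seq_q == 1 per decoding step. A per-head NumPy reference (no GQA head mapping, causal mask, or ALiBi), useful as a mental model when reading the kernel below; shapes are illustrative:

import numpy as np

def decode_one_head(q, k, v):
    # q: [dim], k: [seq_k, dim], v: [seq_k, dim_v] for one (batch, head) pair.
    s = k @ q / np.sqrt(q.shape[0])   # S = Q * K^T, scaled by 1/sqrt(dim)
    p = np.exp(s - s.max())
    p /= p.sum()                      # P = Softmax(S), max-subtracted for stability
    return p @ v                      # O = P * V, shape [dim_v]

rng = np.random.default_rng(0)
o = decode_one_head(rng.standard_normal(128), rng.standard_normal((256, 128)),
                    rng.standard_normal((256, 128)))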
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:14:13 on Tue, Oct 31, 2023 4 | // 5 | // Description: decoding fwd kernel 6 | 7 | #pragma once 8 | 9 | #include "decoding_attn/block_info.h" 10 | #include "decoding_attn/decoding.h" 11 | #include "decoding_attn/kernel_traits.h" 12 | 13 | template 14 | __global__ void dmha_fwd_kernel(const DecodingParams params) { 15 | const DecodingBlockInfo binfo(params, blockIdx.x, blockIdx.y); 16 | if (binfo.actual_seq_k <= 0) { 17 | return; 18 | } 19 | 20 | constexpr size_t head_dim_v = KernelTraits::head_dim_v; 21 | constexpr size_t threads_per_block = KernelTraits::threads_per_block; 22 | constexpr size_t threads_per_group = KernelTraits::threads_per_group; 23 | 24 | constexpr size_t warp_size = KernelTraits::warp_size; 25 | constexpr size_t warps_per_block = KernelTraits::warps_per_block; 26 | 27 | constexpr size_t groups_per_warp = KernelTraits::groups_per_warp; 28 | constexpr size_t groups_per_block = KernelTraits::groups_per_block; 29 | 30 | constexpr size_t thread_copy_elem_nums = KernelTraits::thread_copy_elem_nums; 31 | 32 | constexpr size_t thread_qk_nums = KernelTraits::thread_qk_nums; 33 | constexpr size_t thread_copy_qk_iters = KernelTraits::thread_copy_qk_iters; 34 | 35 | constexpr size_t thread_vo_nums = KernelTraits::thread_vo_nums; 36 | constexpr size_t thread_copy_vo_iters = KernelTraits::thread_copy_vo_iters; 37 | 38 | constexpr unsigned int shfl_mask = KernelTraits::shfl_mask; 39 | 40 | const size_t warp_id = threadIdx.x / warp_size; 41 | const size_t lane_id = threadIdx.x % warp_size; 42 | const size_t group_id = lane_id / threads_per_group; 43 | const size_t group_lane_id = lane_id % threads_per_group; 44 | 45 | T *q_ptr = reinterpret_cast(params.q_ptr); 46 | T *k_ptr = reinterpret_cast(params.k_ptr); 47 | T *v_ptr = reinterpret_cast(params.v_ptr); 48 | T *o_ptr = reinterpret_cast(params.o_ptr); 49 | 50 | // S = Q * K^T 51 | T RQ[thread_qk_nums]; 52 | 53 | #pragma unroll 54 | for (size_t i = 0; i < thread_copy_qk_iters; ++i) { 55 | *(int4 *)(&RQ[i * thread_copy_elem_nums]) = 56 | *(int4 *)(&q_ptr[binfo.q_offset(params.q_row_stride, params.q_head_stride, 57 | (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]); 58 | } 59 | 60 | extern __shared__ float S_smem[]; 61 | float S_max = -std::numeric_limits::max(); 62 | 63 | #pragma unroll 64 | for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k; 65 | base_seq_k += groups_per_block) { 66 | size_t seq_k = base_seq_k + group_id; 67 | T RK[thread_qk_nums]; 68 | 69 | float acc = 0.0; 70 | if (seq_k < binfo.actual_seq_k) { 71 | #pragma unroll 72 | for (size_t i = 0; i < thread_copy_qk_iters; ++i) { 73 | *(int4 *)(&RK[i * thread_copy_elem_nums]) = 74 | *(int4 *)(&k_ptr[binfo.k_offset(seq_k, params.k_row_stride, params.k_head_stride, 75 | (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]); 76 | } 77 | 78 | #pragma unroll 79 | for (size_t i = 0; i < thread_qk_nums; ++i) { 80 | if constexpr (std::is_same_v) { 81 | acc += (__half2float(RQ[i]) * __half2float(RK[i])); 82 | } else { 83 | acc += (__bfloat162float(RQ[i]) * __bfloat162float(RK[i])); 84 | } 85 | } 86 | } 87 | 88 | #pragma unroll 89 | for (size_t i = threads_per_group / 2; i >= 1; i /= 2) { 90 | acc += __shfl_xor_sync(shfl_mask, acc, i); 91 | } 92 | 93 | if (group_lane_id == 0 && seq_k < binfo.actual_seq_k) { 94 | acc *= params.scale_softmax; 95 | 96 | if constexpr (IsAlibi) { 97 | acc += (binfo.h_slope * (static_cast(seq_k) - binfo.actual_seq_q - binfo.row_shift)); 98 | } 99 | 100 | 
S_smem[seq_k] = acc; 101 | S_max = fmaxf(acc, S_max); 102 | } 103 | } 104 | 105 | // P = Softmax(S) 106 | __shared__ float softmax_smem[warps_per_block]; 107 | 108 | #pragma unroll 109 | for (size_t i = warp_size / 2; i >= 1; i /= 2) { 110 | S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i)); 111 | } 112 | 113 | if (lane_id == 0) { 114 | softmax_smem[warp_id] = S_max; 115 | } 116 | 117 | __syncthreads(); 118 | 119 | if (lane_id < warps_per_block) { 120 | S_max = softmax_smem[lane_id]; 121 | } else { 122 | S_max = -std::numeric_limits::max(); 123 | } 124 | 125 | #pragma unroll 126 | for (size_t i = warps_per_block / 2; i >= 1; i /= 2) { 127 | S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i)); 128 | } 129 | 130 | S_max = __shfl_sync(shfl_mask, S_max, 0); 131 | 132 | float exp_sum = 0.0; 133 | #pragma unroll 134 | for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) { 135 | S_smem[seq_k] -= S_max; 136 | S_smem[seq_k] = exp(S_smem[seq_k]); 137 | exp_sum += S_smem[seq_k]; 138 | } 139 | 140 | #pragma unroll 141 | for (size_t i = warp_size / 2; i >= 1; i /= 2) { 142 | exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i); 143 | } 144 | 145 | if (lane_id == 0) { 146 | softmax_smem[warp_id] = exp_sum; 147 | } 148 | 149 | __syncthreads(); 150 | 151 | if (lane_id < warps_per_block) { 152 | exp_sum = softmax_smem[lane_id]; 153 | } 154 | 155 | #pragma unroll 156 | for (size_t i = warps_per_block / 2; i >= 1; i /= 2) { 157 | exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i); 158 | } 159 | exp_sum = __shfl_sync(shfl_mask, exp_sum, 0); 160 | 161 | #pragma unroll 162 | for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) { 163 | S_smem[seq_k] /= exp_sum; 164 | } 165 | 166 | __syncthreads(); 167 | 168 | // O = P * V 169 | T RV[thread_vo_nums]; 170 | float RO[thread_vo_nums]; 171 | 172 | memset(RO, 0, sizeof(RO)); 173 | 174 | #pragma unroll 175 | for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k; 176 | base_seq_k += groups_per_block) { 177 | size_t seq_k = base_seq_k + group_id; 178 | 179 | if (seq_k < binfo.actual_seq_k) { 180 | #pragma unroll 181 | for (size_t i = 0; i < thread_copy_vo_iters; ++i) { 182 | *(int4 *)(&RV[i * thread_copy_elem_nums]) = 183 | *(int4 *)(&v_ptr[binfo.k_offset(seq_k, params.v_row_stride, params.v_head_stride, 184 | (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]); 185 | } 186 | 187 | #pragma unroll 188 | for (size_t i = 0; i < thread_vo_nums; ++i) { 189 | if constexpr (std::is_same_v) { 190 | RO[i] += (S_smem[seq_k] * __half2float(RV[i])); 191 | } else { 192 | RO[i] += (S_smem[seq_k] * __bfloat162float(RV[i])); 193 | } 194 | } 195 | } 196 | } 197 | 198 | #pragma unroll 199 | for (size_t i = 0; i < thread_vo_nums; ++i) { 200 | #pragma unroll 201 | for (size_t j = threads_per_group; j <= warp_size / 2; j *= 2) { 202 | RO[i] += __shfl_xor_sync(shfl_mask, RO[i], j); 203 | } 204 | } 205 | 206 | __syncthreads(); 207 | 208 | #pragma unroll 209 | for (size_t i = threadIdx.x; i < head_dim_v; i += threads_per_block) { 210 | S_smem[i] = 0.0; 211 | } 212 | 213 | __syncthreads(); 214 | 215 | if (lane_id < threads_per_group) { 216 | #pragma unroll 217 | for (size_t i = 0; i < thread_copy_vo_iters; ++i) { 218 | #pragma unroll 219 | for (size_t j = 0; j < thread_copy_elem_nums; ++j) { 220 | atomicAdd(S_smem + (i * threads_per_group + lane_id) * thread_copy_elem_nums + j, 221 | RO[i * thread_copy_elem_nums + j]); 222 | } 223 | } 224 | } 225 | 226 | 
__syncthreads(); 227 | 228 | #pragma unroll 229 | for (size_t i = threadIdx.x; i < head_dim_v; i += threads_per_block) { 230 | if constexpr (std::is_same_v) { 231 | o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2half(S_smem[i]); 232 | } else { 233 | o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2bfloat16(S_smem[i]); 234 | } 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /benchmarks/python/benchmark_decoding_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023. All Rights Reserved. 2 | # Author: Bruce-Lee-LY 3 | # Date: 21:14:13 on Tue, Oct 31, 2023 4 | # 5 | # Description: benchmark decoding attn using python api 6 | 7 | # !/usr/bin/python3 8 | # coding=utf-8 9 | 10 | from __future__ import print_function 11 | from __future__ import division 12 | from __future__ import absolute_import 13 | from __future__ import with_statement 14 | 15 | import optparse 16 | import time 17 | import torch 18 | from torch.nn import functional as F 19 | 20 | from decoding_attn import decoding_attn_fwd 21 | 22 | try: 23 | # flash decoding 24 | from flash_attn import flash_attn_with_kvcache 25 | except ImportError: 26 | flash_attn_with_kvcache = None 27 | 28 | try: 29 | from flashinfer import single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper 30 | except ImportError: 31 | single_decode_with_kv_cache = None 32 | BatchDecodeWithPagedKVCacheWrapper = None 33 | 34 | try: 35 | from flash_mla import get_mla_metadata, flash_mla_with_kvcache 36 | except ImportError: 37 | get_mla_metadata = None 38 | flash_mla_with_kvcache = None 39 | 40 | 41 | def get_cu_seq(seqs: torch.tensor) -> torch.tensor: 42 | """ 43 | Arguments: 44 | seqs: [batch], dtype torch.int32, sequence length of each batch. 45 | Return: 46 | cu_seq: [batch + 1], dtype torch.int32. The cumulative sequence lengths of the sequences 47 | in the batch. 
48 | """ 49 | return F.pad(seqs.cumsum(dim=0, dtype=torch.int32), (1, 0)) 50 | 51 | 52 | def compute_flops_and_bandwidth(batch, seq_q, seq_k, head_q, head_k, dim, dim_v, time): 53 | throughput = (batch * seq_q * seq_k * head_q * (dim + dim_v) 54 | * 2 * 10**(-12)) / (time * 10**(-3)) 55 | bandwidth = ((batch * seq_q * head_q * dim + batch * seq_k * head_k * 56 | dim + batch * seq_q * head_q * dim_v) * 2 * 10**(-9)) / (time * 10**(-3)) 57 | return throughput, bandwidth 58 | 59 | 60 | def benchmark_flash_attn(q, k, v, warmup_iterations=1, profiling_iterations=10): 61 | batch = q.shape[0] 62 | seq_q = q.shape[1] 63 | head_q = q.shape[2] 64 | dim = q.shape[3] 65 | seq_k = k.shape[1] 66 | head_k = k.shape[2] 67 | dim_v = v.shape[3] 68 | 69 | start = torch.cuda.Event(enable_timing=True) 70 | end = torch.cuda.Event(enable_timing=True) 71 | 72 | # warm up 73 | for _ in range(warmup_iterations): 74 | output = flash_attn_with_kvcache(q, k, v) 75 | torch.cuda.synchronize() 76 | # print(f"Flash-Decoding output: {output}") 77 | 78 | start.record() 79 | for _ in range(profiling_iterations): 80 | __ = flash_attn_with_kvcache(q, k, v) 81 | end.record() 82 | torch.cuda.synchronize() 83 | elapsed_time = start.elapsed_time(end) / profiling_iterations 84 | throughput, bandwidth = compute_flops_and_bandwidth( 85 | batch, seq_q, seq_k, head_q, head_k, dim, dim_v, elapsed_time) 86 | print("Flash-Decoding {}-{} profiling time: {:.4f} ms, throughput: {:.4f} TFLOPS, bandwidth: {:.3f} GB/s".format( 87 | batch, seq_k, elapsed_time, throughput, bandwidth)) 88 | 89 | 90 | def benchmark_flashinfer_single(q, k, v, warmup_iterations=1, profiling_iterations=10): 91 | batch = 1 92 | seq_q = 1 93 | head_q = q.shape[0] 94 | dim = q.shape[1] 95 | seq_k = k.shape[0] // batch 96 | head_k = k.shape[1] 97 | dim_v = v.shape[2] 98 | 99 | start = torch.cuda.Event(enable_timing=True) 100 | end = torch.cuda.Event(enable_timing=True) 101 | 102 | # warm up 103 | for _ in range(warmup_iterations): 104 | output = single_decode_with_kv_cache(q, k, v) 105 | torch.cuda.synchronize() 106 | # print(f"FlashInfer-Single output: {output}") 107 | 108 | start.record() 109 | for _ in range(profiling_iterations): 110 | __ = single_decode_with_kv_cache(q, k, v) 111 | end.record() 112 | torch.cuda.synchronize() 113 | elapsed_time = start.elapsed_time(end) / profiling_iterations 114 | throughput, bandwidth = compute_flops_and_bandwidth( 115 | batch, seq_q, seq_k, head_q, head_k, dim, dim_v, elapsed_time) 116 | print("FlashInfer {}-{} profiling time: {:.4f} ms, throughput: {:.4f} TFLOPS, bandwidth: {:.3f} GB/s".format( 117 | batch, seq_k, elapsed_time, throughput, bandwidth)) 118 | 119 | 120 | def benchmark_flashinfer_batch(q, k, v, warmup_iterations=1, profiling_iterations=10): 121 | batch = q.shape[0] 122 | seq_q = 1 123 | head_q = q.shape[1] 124 | dim = q.shape[2] 125 | seq_k = k.shape[0] // batch 126 | head_k = k.shape[1] 127 | dim_v = v.shape[2] 128 | 129 | # page_size: the page size of the paged kv cache 130 | page_size = 1 131 | num_pages_per_seq = (seq_k + page_size - 1) // page_size 132 | total_num_pages = num_pages_per_seq * batch 133 | # NHD: the last 3 dimensions are organized as [seq_k, head_k, dim] 134 | kv_layout = "NHD" 135 | k = k.unsqueeze(1).unsqueeze(2) 136 | v = v.unsqueeze(1).unsqueeze(2) 137 | # kv_data: [total_num_pages, 2, page_size, head_k, dim] 138 | kv_data = torch.cat((k, v), 1) 139 | kv_indptr = torch.arange(0, batch + 1, dtype=torch.int32, 140 | device=torch.device('cuda')) * num_pages_per_seq 141 | kv_indices = 
torch.arange( 142 | 0, total_num_pages, dtype=torch.int32, device=torch.device('cuda')) 143 | kv_last_page_len = torch.full( 144 | (batch,), (seq_k - 1) % page_size + 1, dtype=torch.int32, device=torch.device('cuda')) 145 | # the device of the workspace buffer should be the same as the device of the input tensors 146 | # in the split-k algorithm 147 | workspace_buffer = torch.empty( 148 | batch * seq_q * head_q * dim, dtype=torch.float32, device=torch.device('cuda')) 149 | wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) 150 | wrapper.plan(kv_indptr, kv_indices, kv_last_page_len, head_q, 151 | head_k, dim, page_size, data_type=k.dtype, q_data_type=q.dtype) 152 | 153 | start = torch.cuda.Event(enable_timing=True) 154 | end = torch.cuda.Event(enable_timing=True) 155 | 156 | # warm up 157 | for _ in range(warmup_iterations): 158 | output = wrapper.run(q, kv_data) 159 | torch.cuda.synchronize() 160 | # print(f"FlashInfer-Batch output: {output}") 161 | 162 | start.record() 163 | for _ in range(profiling_iterations): 164 | __ = wrapper.run(q, kv_data) 165 | end.record() 166 | torch.cuda.synchronize() 167 | elapsed_time = start.elapsed_time(end) / profiling_iterations 168 | throughput, bandwidth = compute_flops_and_bandwidth( 169 | batch, seq_q, seq_k, head_q, head_k, dim, dim_v, elapsed_time) 170 | print("FlashInfer {}-{} profiling time: {:.4f} ms, throughput: {:.4f} TFLOPS, bandwidth: {:.3f} GB/s".format( 171 | batch, seq_k, elapsed_time, throughput, bandwidth)) 172 | 173 | 174 | def benchmark_flashmla(q, k, dim_v, warmup_iterations=1, profiling_iterations=10): 175 | batch = q.shape[0] 176 | seq_q = q.shape[1] 177 | head_q = q.shape[2] 178 | dim = q.shape[3] 179 | seq_k = k.shape[1] 180 | head_k = k.shape[2] 181 | 182 | cache_seqlens = torch.full( 183 | (batch,), seq_k, dtype=torch.int32, device=torch.device('cuda')) 184 | block_size = 64 185 | block_table = torch.arange(batch * seq_k // block_size, dtype=torch.int32, 186 | device=torch.device('cuda')).view(batch, seq_k // block_size) 187 | blocked_k = k.reshape(block_table.numel(), block_size, head_k, dim) 188 | for i in range(batch): 189 | blocked_k.view(batch, seq_k, head_k, dim)[i, cache_seqlens[i].item():] = ( 190 | float("nan") 191 | ) 192 | 193 | tile_scheduler_metadata, num_splits = get_mla_metadata( 194 | cache_seqlens, seq_q * head_q // head_k, head_k 195 | ) 196 | 197 | start = torch.cuda.Event(enable_timing=True) 198 | end = torch.cuda.Event(enable_timing=True) 199 | 200 | # warm up 201 | for _ in range(warmup_iterations): 202 | output, _ = flash_mla_with_kvcache( 203 | q, blocked_k, block_table, cache_seqlens, dim_v, tile_scheduler_metadata, num_splits, causal=True) 204 | torch.cuda.synchronize() 205 | # print(f"FlashMLA output: {output}") 206 | 207 | start.record() 208 | for _ in range(profiling_iterations): 209 | __, ___ = flash_mla_with_kvcache( 210 | q, blocked_k, block_table, cache_seqlens, dim_v, tile_scheduler_metadata, num_splits, causal=True) 211 | end.record() 212 | torch.cuda.synchronize() 213 | elapsed_time = start.elapsed_time(end) / profiling_iterations 214 | throughput, bandwidth = compute_flops_and_bandwidth( 215 | batch, seq_q, seq_k, head_q, head_k, dim, dim_v, elapsed_time) 216 | print("FlashMLA {}-{} profiling time: {:.4f} ms, throughput: {:.4f} TFLOPS, bandwidth: {:.3f} GB/s".format( 217 | batch, seq_k, elapsed_time, throughput, bandwidth)) 218 | 219 | 220 | def benchmark_decoding_attn(q, k, v, dim_v, is_alibi, warmup_iterations=1, profiling_iterations=10): 221 | batch = q.shape[0] 222 
| seq_q = 1 223 | head_q = q.shape[1] 224 | dim = q.shape[2] 225 | seq_k = k.shape[0] // batch 226 | head_k = k.shape[1] 227 | 228 | cu_seq_k = get_cu_seq(torch.full( 229 | (batch,), seq_k, dtype=torch.int32, device=torch.device('cuda'))) 230 | 231 | start = torch.cuda.Event(enable_timing=True) 232 | end = torch.cuda.Event(enable_timing=True) 233 | 234 | # warm up 235 | for _ in range(warmup_iterations): 236 | output = decoding_attn_fwd(q, k, v, cu_seq_k, seq_k, dim_v, is_alibi) 237 | torch.cuda.synchronize() 238 | # print(f"Decoding-Attention output: {output}") 239 | 240 | start.record() 241 | for _ in range(profiling_iterations): 242 | __ = decoding_attn_fwd(q, k, v, cu_seq_k, seq_k, dim_v, is_alibi) 243 | end.record() 244 | torch.cuda.synchronize() 245 | elapsed_time = start.elapsed_time(end) / profiling_iterations 246 | throughput, bandwidth = compute_flops_and_bandwidth( 247 | batch, seq_q, seq_k, head_q, head_k, dim, dim_v, elapsed_time) 248 | print("Decoding-Attention {}-{} profiling time: {:.4f} ms, throughput: {:.4f} TFLOPS, bandwidth: {:.3f} GB/s".format( 249 | batch, seq_k, elapsed_time, throughput, bandwidth)) 250 | 251 | 252 | def benchmark_forward(batch, seq_q, seq_k, head_q, head_k, dim, dim_v, is_alibi, is_bf16, warmup_iterations=1, profiling_iterations=10): 253 | torch.cuda.empty_cache() 254 | 255 | dtype = torch.bfloat16 if is_bf16 else torch.float16 256 | total_q = batch * seq_q 257 | total_k = batch * seq_k 258 | 259 | q = torch.randn(total_q, head_q, dim, 260 | device=torch.device('cuda'), dtype=dtype) 261 | k = torch.randn(total_k, head_k, dim, 262 | device=torch.device('cuda'), dtype=dtype) 263 | if dim == dim_v: 264 | v = torch.randn(total_k, head_k, dim_v, 265 | device=torch.device('cuda'), dtype=dtype) 266 | else: 267 | v = None 268 | 269 | if dim == dim_v and flash_attn_with_kvcache is not None: 270 | q4 = q.reshape(batch, seq_q, head_q, dim) 271 | k4 = k.reshape(batch, seq_k, head_k, dim) 272 | v4 = v.reshape(batch, seq_k, head_k, dim_v) 273 | 274 | time.sleep(0.1) 275 | benchmark_flash_attn(q4, k4, v4, warmup_iterations, 276 | profiling_iterations) 277 | 278 | if batch == 1 and dim == dim_v and single_decode_with_kv_cache is not None: 279 | q2 = q.reshape(head_q, dim) 280 | 281 | time.sleep(0.1) 282 | benchmark_flashinfer_single( 283 | q2, k, v, warmup_iterations, profiling_iterations) 284 | 285 | if batch > 1 and dim == dim_v and BatchDecodeWithPagedKVCacheWrapper is not None: 286 | time.sleep(0.1) 287 | benchmark_flashinfer_batch( 288 | q, k, v, warmup_iterations, profiling_iterations) 289 | 290 | if dim != dim_v and flash_mla_with_kvcache is not None: 291 | q4 = q.reshape(batch, seq_q, head_q, dim) 292 | k4 = k.reshape(batch, seq_k, head_k, dim) 293 | 294 | time.sleep(0.1) 295 | benchmark_flashmla(q4, k4, dim_v, warmup_iterations, 296 | profiling_iterations) 297 | 298 | time.sleep(0.1) 299 | benchmark_decoding_attn(q, k, v, dim_v, is_alibi, 300 | warmup_iterations, profiling_iterations) 301 | 302 | 303 | def benchmark_seq(head_q=32, head_k=32, dim=128, dim_v=128, is_alibi=False, is_bf16=False, warmup_iterations=1, profiling_iterations=10): 304 | print("------------------------------- Benchmark Seq -------------------------------") 305 | batch = 1 306 | seq_q = 1 307 | if dim == dim_v: 308 | seq_ks = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 309 | 2048, 3072, 4096, 5120, 6144, 7168, 8192] 310 | else: 311 | seq_ks = [256, 512, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192] 312 | 313 | for seq_k in seq_ks: 314 | benchmark_forward(batch, seq_q, seq_k, head_q, 
head_k, dim, dim_v, 315 | is_alibi, is_bf16, warmup_iterations, profiling_iterations) 316 | 317 | 318 | def benchmark_batch(head_q=32, head_k=32, dim=128, dim_v=128, is_alibi=False, is_bf16=False, warmup_iterations=1, profiling_iterations=10): 319 | print("------------------------------- Benchmark Batch -------------------------------") 320 | batchs = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 1536, 2048] 321 | seq_q = 1 322 | if dim == dim_v: 323 | seq_k = 128 324 | else: 325 | seq_k = 4096 326 | 327 | for batch in batchs: 328 | benchmark_forward(batch, seq_q, seq_k, head_q, head_k, dim, dim_v, 329 | is_alibi, is_bf16, warmup_iterations, profiling_iterations) 330 | 331 | 332 | def main(): 333 | usage = "python3 benchmark_decoding_attn.py --head_q 32 --head_k 32 --dim 128 --dim_v 128 --warmup_iterations 1 --profiling_iterations 10" 334 | parser = optparse.OptionParser(usage) 335 | parser.add_option("--head_q", dest="head_q", type="int", default="32") 336 | parser.add_option("--head_k", dest="head_k", type="int", default="32") 337 | parser.add_option("--dim", dest="dim", type="int", default="128") 338 | parser.add_option("--dim_v", dest="dim_v", type="int", default="128") 339 | parser.add_option("--is_alibi", action="store_true", 340 | dest="is_alibi", default=False) 341 | parser.add_option("--is_bf16", action="store_true", 342 | dest="is_bf16", default=False) 343 | parser.add_option("--warmup_iterations", 344 | dest="warmup_iterations", type="int", default="1") 345 | parser.add_option("--profiling_iterations", 346 | dest="profiling_iterations", type="int", default="10") 347 | 348 | options, args = parser.parse_args() 349 | head_q = options.head_q 350 | head_k = options.head_k 351 | dim = options.dim 352 | dim_v = options.dim_v 353 | is_alibi = options.is_alibi 354 | is_bf16 = options.is_bf16 355 | warmup_iterations = options.warmup_iterations 356 | profiling_iterations = options.profiling_iterations 357 | 358 | print( 359 | f"Benchmark Decoding Attention: head q: {head_q}, head k: {head_k}, dim: {dim}, dim v: {dim_v}, is alibi: {is_alibi}, is bf16: {is_bf16}, warmup iterations: {warmup_iterations}, profiling iterations: {profiling_iterations}") 360 | 361 | benchmark_seq(head_q, head_k, dim, dim_v, is_alibi, is_bf16, 362 | warmup_iterations, profiling_iterations) 363 | benchmark_batch(head_q, head_k, dim, dim_v, is_alibi, 364 | is_bf16, warmup_iterations, profiling_iterations) 365 | 366 | 367 | if __name__ == "__main__": 368 | main() 369 | -------------------------------------------------------------------------------- /csrc/common/tester.h: -------------------------------------------------------------------------------- 1 | // Copyright 2023. All Rights Reserved. 
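compute_flops_and_bandwidth() above and Tester::profile() in csrc/common/tester.h below use the same model: 2·b·sq·sk·hq·(d + dv) FLOPs for the Q·K^T and P·V products, and 2 bytes per element for the Q and K reads plus the O write (V is not counted separately, which matches the MLA case where V aliases K). A worked example with a hypothetical 0.05 ms kernel time:

def model_flops_and_bytes(b, sq, sk, hq, hk, d, dv):
    flops = 2 * b * sq * sk * hq * (d + dv)                                   # Q*K^T and P*V
    bytes_moved = 2 * (b * sq * hq * d + b * sk * hk * d + b * sq * hq * dv)  # fp16/bf16 elements
    return flops, bytes_moved

flops, bytes_moved = model_flops_and_bytes(1, 1, 4096, 32, 32, 128, 128)
t_s = 0.05e-3                                    # hypothetical elapsed time per call
print(flops / t_s / 1e12, "TFLOPS")              # ~1.34
print(bytes_moved / t_s / 1e9, "GB/s")           # ~671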
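benchmark_flashinfer_batch() above lays the cache out with page_size = 1, so the paging is trivial; a worked example of the index tensors it builds, for batch = 2 and seq_k = 3 (plain Python, no FlashInfer call, just the bookkeeping):

batch, seq_k, page_size = 2, 3, 1
num_pages_per_seq = (seq_k + page_size - 1) // page_size        # 3
total_num_pages = num_pages_per_seq * batch                     # 6
kv_indptr = [i * num_pages_per_seq for i in range(batch + 1)]   # [0, 3, 6]
kv_indices = list(range(total_num_pages))                       # [0, 1, 2, 3, 4, 5]
kv_last_page_len = [(seq_k - 1) % page_size + 1] * batch        # [1, 1]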
2 | // Author: Bruce-Lee-LY 3 | // Date: 21:08:30 on Sun, Aug 27, 2023 4 | // 5 | // Description: tester 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "cuda_timer.h" 12 | #include "tensor.h" 13 | 14 | template 15 | class Tester { 16 | public: 17 | explicit Tester(size_t batch = 2, size_t seq_q = 1, size_t seq_k = 256, size_t head_q = 32, size_t head_k = 32, 18 | size_t dim = 128, size_t dim_v = 128, bool is_causal = false, bool is_alibi = false, 19 | cudaStream_t stream = nullptr, size_t warmup_iterations = 1, size_t profiling_iterations = 10, 20 | size_t sleep_duration = 100, bool enable_check = false) 21 | : m_batch(batch), 22 | m_seq_q(seq_q), 23 | m_seq_k(seq_k), 24 | m_head_q(head_q), 25 | m_head_k(head_k), 26 | m_dim(dim), 27 | m_dim_v(dim_v), 28 | m_is_causal(is_causal), 29 | m_is_alibi(is_alibi), 30 | m_stream(stream), 31 | m_warmup_iterations(warmup_iterations), 32 | m_profiling_iterations(profiling_iterations), 33 | m_sleep_duration(sleep_duration), 34 | m_enable_check(enable_check) { 35 | DA_CHECK_GT(m_batch, 0); 36 | DA_CHECK_GT(m_seq_q, 0); 37 | DA_CHECK_GT(m_seq_k, 0); 38 | DA_CHECK_GT(m_head_q, 0); 39 | DA_CHECK_GT(m_head_k, 0); 40 | DA_CHECK_EQ(m_head_q % m_head_k, 0); 41 | DA_CHECK_GT(m_dim, 0); 42 | DA_CHECK_LE(m_dim, 576); 43 | DA_CHECK_GT(m_dim_v, 0); 44 | DA_CHECK_LE(m_dim_v, 512); 45 | DA_CHECK_GT(m_warmup_iterations, 0); 46 | DA_CHECK_GT(m_profiling_iterations, 0); 47 | DA_CHECK_GT(m_sleep_duration, 0); 48 | 49 | if constexpr (std::is_same_v) { 50 | m_is_bf16 = true; 51 | } 52 | 53 | m_total_q = m_batch * m_seq_q; 54 | m_total_k = m_batch * m_seq_k; 55 | 56 | m_Q = std::make_shared>(std::vector{m_total_q, m_head_q, m_dim}, "Tensor Q"); 57 | DA_CHECK(m_Q); 58 | m_Q_dev_ptr = reinterpret_cast(m_Q->getDevPtr()); 59 | DA_CHECK(m_Q_dev_ptr); 60 | m_K = std::make_shared>(std::vector{m_total_k, m_head_k, m_dim}, "Tensor K"); 61 | DA_CHECK(m_K); 62 | m_K_dev_ptr = reinterpret_cast(m_K->getDevPtr()); 63 | DA_CHECK(m_K_dev_ptr); 64 | if (m_dim == m_dim_v) { 65 | m_V = std::make_shared>(std::vector{m_total_k, m_head_k, m_dim_v}, "Tensor V"); 66 | DA_CHECK(m_V); 67 | m_V_dev_ptr = reinterpret_cast(m_V->getDevPtr()); 68 | DA_CHECK(m_V_dev_ptr); 69 | } 70 | m_O = std::make_shared>(std::vector{m_total_q, m_head_q, m_dim_v}, "Tensor O"); 71 | DA_CHECK(m_O); 72 | m_O_dev_ptr = reinterpret_cast(m_O->getDevPtr()); 73 | DA_CHECK(m_O_dev_ptr); 74 | m_base = std::make_shared>(std::vector{m_total_q, m_head_q, m_dim_v}, "Tensor Base"); 75 | DA_CHECK(m_base); 76 | 77 | m_cu_seq_k = std::make_shared>(std::vector{m_batch + 1}, "Tensor cu_seq_k"); 78 | DA_CHECK(m_cu_seq_k); 79 | m_cu_seq_k_dev_ptr = m_cu_seq_k->getDevPtr(); 80 | DA_CHECK(m_cu_seq_k_dev_ptr); 81 | 82 | get_cu_seq(m_cu_seq_k.get(), m_seq_k); 83 | m_cu_seq_k->moveToDevice(); 84 | 85 | m_cuda_timer = std::make_shared(m_stream); 86 | DA_CHECK(m_cuda_timer); 87 | 88 | if (m_enable_check) { 89 | clock_t start = clock(); 90 | attn_cpu(m_Q, m_K, m_V, m_base, m_cu_seq_k, m_seq_k, m_dim_v, m_is_causal, m_is_alibi); 91 | clock_t end = clock(); 92 | DLOG("MHA CPU use: %.3f ms", static_cast(end - start) / (CLOCKS_PER_SEC * 1e-3)); 93 | } 94 | } 95 | 96 | ~Tester() {} 97 | 98 | template 99 | void evaluate(Func &&dmha, const std::string &name) { 100 | DLOG("----------------- Evaluating %s -----------------", name.c_str()); 101 | usleep(m_sleep_duration * 1000); 102 | m_O->tearUp(m_base.get()); 103 | 104 | // warm up 105 | m_cuda_timer->start(); 106 | for (size_t i = 0; i < m_warmup_iterations; ++i) { 107 | dmha(m_Q_dev_ptr, 
m_K_dev_ptr, m_V_dev_ptr, m_O_dev_ptr, m_cu_seq_k_dev_ptr, m_seq_k, m_batch, m_head_q, 108 | m_head_k, m_dim, m_dim_v, m_is_alibi, m_is_bf16, m_stream); 109 | } 110 | m_warmup_time = static_cast(m_cuda_timer->end()) / static_cast(m_warmup_iterations); 111 | DLOG("Warm up time: %.3f ms", m_warmup_time); 112 | 113 | if (m_enable_check) { 114 | m_O->moveToHost(); 115 | m_O->checkValue(m_base.get()); 116 | } 117 | 118 | profile(std::forward(dmha), name); 119 | } 120 | 121 | private: 122 | void get_cu_seq(Tensor *cu_seq, size_t seq) { 123 | size_t batch = cu_seq->getShape()[0] - 1; 124 | int *cu_seq_ptr = cu_seq->getHostPtr(); 125 | 126 | for (size_t i = 0; i < batch + 1; ++i) { 127 | cu_seq_ptr[i] = i * seq; 128 | } 129 | } 130 | 131 | // sopprt MHA/MQA/GQA/MLA 132 | // MLA: K == kv_c_and_k_pe_cache, V == nullptr 133 | void attn_cpu(std::shared_ptr> Q, std::shared_ptr> K, std::shared_ptr> V, 134 | std::shared_ptr> O, std::shared_ptr> cu_seq_k, size_t max_seq_k, size_t dim_v, 135 | bool is_causal, bool is_alibi) { 136 | size_t batch = cu_seq_k->getShape()[0] - 1; 137 | const size_t seq_q = 1; 138 | size_t total_q = Q->getShape()[0]; 139 | size_t head_q = Q->getShape()[1]; 140 | size_t dim = Q->getShape()[2]; 141 | size_t head_k = K->getShape()[1]; 142 | size_t d_v = V ? V->getShape()[2] : dim_v; 143 | 144 | DA_CHECK_EQ(head_q % head_k, 0); 145 | const size_t head_ratio = head_q / head_k; 146 | 147 | T *q_ptr = Q->getHostPtr(); 148 | T *k_ptr = K->getHostPtr(); 149 | T *v_ptr = V ? V->getHostPtr() : K->getHostPtr(); 150 | T *o_ptr = O->getHostPtr(); 151 | 152 | int *cu_seq_k_ptr = cu_seq_k->getHostPtr(); 153 | 154 | // S = Q * K^T 155 | auto S = std::make_shared>(std::vector{total_q, head_q, max_seq_k}, "Tensor S"); 156 | DA_CHECK(S); 157 | float *s_ptr = S->getHostPtr(); 158 | for (size_t b = 0; b < batch; ++b) { 159 | size_t sum_seq_q = b; 160 | size_t sum_seq_k = static_cast(cu_seq_k_ptr[b]); 161 | size_t seq_k = static_cast(cu_seq_k_ptr[b + 1]) - sum_seq_k; 162 | for (size_t h = 0; h < head_q; ++h) { 163 | size_t h_k = h / head_ratio; 164 | for (size_t sq = 0; sq < seq_q; ++sq) { 165 | for (size_t sk = 0; sk < seq_k; ++sk) { 166 | float acc = 0.0; 167 | for (size_t d = 0; d < dim; ++d) { 168 | if constexpr (std::is_same_v) { 169 | acc += __half2float(q_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d]) * 170 | __half2float(k_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]); 171 | } else { 172 | acc += __bfloat162float(q_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d]) * 173 | __bfloat162float(k_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]); 174 | } 175 | } 176 | s_ptr[sum_seq_q * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = acc; 177 | } 178 | } 179 | } 180 | } 181 | 182 | // P = Softmax(S) 183 | auto P = std::make_shared>(std::vector{total_q, head_q, max_seq_k}, "Tensor P"); 184 | DA_CHECK(P); 185 | float *p_ptr = P->getHostPtr(); 186 | float scale = 1.0 / std::sqrt(dim); 187 | for (size_t b = 0; b < batch; ++b) { 188 | size_t sum_seq_q = b; 189 | size_t sum_seq_k = static_cast(cu_seq_k_ptr[b]); 190 | size_t seq_k = static_cast(cu_seq_k_ptr[b + 1]) - sum_seq_k; 191 | size_t row_shift = seq_k - seq_q; 192 | for (size_t h = 0; h < head_q; ++h) { 193 | float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0; 194 | for (size_t sq = 0; sq < seq_q; ++sq) { 195 | size_t col_limit = is_causal ? 
std::min(seq_k, sq + row_shift + 1) : seq_k; 196 | 197 | // Max(S) 198 | std::vector tmp_s(seq_k, 0.0); 199 | float max_s = -std::numeric_limits::max(); 200 | for (size_t sk = 0; sk < col_limit; ++sk) { 201 | tmp_s[sk] = s_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * scale; 202 | if (is_alibi && sk < sq + row_shift) { 203 | tmp_s[sk] += 204 | (h_slope * (static_cast(sk) - static_cast(sq) - static_cast(row_shift))); 205 | } 206 | max_s = std::max(max_s, tmp_s[sk]); 207 | } 208 | 209 | // Sum(S) 210 | float sum_s = 0.0; 211 | for (size_t sk = 0; sk < col_limit; ++sk) { 212 | tmp_s[sk] = std::exp(tmp_s[sk] - max_s); 213 | sum_s += tmp_s[sk]; 214 | } 215 | 216 | // Softmax(S) 217 | for (size_t sk = 0; sk < col_limit; ++sk) { 218 | p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = tmp_s[sk] / sum_s; 219 | } 220 | 221 | // Causal(S) 222 | if (is_causal) { 223 | for (size_t sk = col_limit; sk < seq_k; ++sk) { 224 | p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = 0.0; 225 | } 226 | } 227 | } 228 | } 229 | } 230 | 231 | // O = P * V 232 | for (size_t b = 0; b < batch; ++b) { 233 | size_t sum_seq_q = b; 234 | size_t sum_seq_k = static_cast(cu_seq_k_ptr[b]); 235 | size_t seq_k = static_cast(cu_seq_k_ptr[b + 1]) - sum_seq_k; 236 | for (size_t h = 0; h < head_q; ++h) { 237 | size_t h_k = h / head_ratio; 238 | for (size_t sq = 0; sq < seq_q; ++sq) { 239 | for (size_t d = 0; d < d_v; ++d) { 240 | float acc = 0.0; 241 | for (size_t sk = 0; sk < seq_k; ++sk) { 242 | if constexpr (std::is_same_v) { 243 | acc += p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * 244 | __half2float(v_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]); 245 | } else { 246 | acc += p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * 247 | __bfloat162float(v_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]); 248 | } 249 | } 250 | if constexpr (std::is_same_v) { 251 | o_ptr[(sum_seq_q + sq) * (head_q * d_v) + h * d_v + d] = __float2half(acc); 252 | } else { 253 | o_ptr[(sum_seq_q + sq) * (head_q * d_v) + h * d_v + d] = __float2bfloat16(acc); 254 | } 255 | } 256 | } 257 | } 258 | } 259 | } 260 | 261 | template 262 | void profile(Func &&dmha, const std::string &name) { 263 | m_cuda_timer->start(); 264 | for (size_t i = 0; i < m_profiling_iterations; ++i) { 265 | dmha(m_Q_dev_ptr, m_K_dev_ptr, m_V_dev_ptr, m_O_dev_ptr, m_cu_seq_k_dev_ptr, m_seq_k, m_batch, m_head_q, 266 | m_head_k, m_dim, m_dim_v, m_is_alibi, m_is_bf16, m_stream); 267 | } 268 | m_profiling_time = static_cast(m_cuda_timer->end()) / static_cast(m_profiling_iterations); 269 | 270 | m_throughput = static_cast(m_batch * m_seq_q * m_seq_k * m_head_q * (m_dim + m_dim_v) * 2) * 1e-12 / 271 | (static_cast(m_profiling_time) * 1e-3); 272 | 273 | m_bandwidth = static_cast((m_batch * m_seq_q * m_head_q * m_dim + m_batch * m_seq_k * m_head_k * m_dim + 274 | m_batch * m_seq_q * m_head_q * m_dim_v) * 275 | 2) * 276 | 1e-9 / (static_cast(m_profiling_time) * 1e-3); 277 | if (m_is_causal) { 278 | m_throughput /= 2; 279 | m_bandwidth /= 2; 280 | } 281 | 282 | if ((std::abs(m_base_time) <= 1e-6) && (std::abs(m_base_throughput) <= 1e-6)) { 283 | m_base_time = m_profiling_time; 284 | m_base_throughput = m_throughput; 285 | m_base_bandwidth = m_bandwidth; 286 | } 287 | 288 | DLOG( 289 | "%s exit, profiling time: %.4f ms (%.2f%%), throughput: %.4f TFLOPS (%.2f%%), bandwidth: %.3f GB/s " 290 | "(%.2f%%)", 291 | name.c_str(), m_profiling_time, m_profiling_time / m_base_time * 100, m_throughput, 292 | m_throughput / 
m_base_throughput * 100, m_bandwidth, m_bandwidth / m_base_bandwidth * 100); 293 | } 294 | 295 | const size_t m_batch = 2; 296 | const size_t m_seq_q = 1; 297 | const size_t m_seq_k = 256; 298 | const size_t m_head_q = 32; 299 | const size_t m_head_k = 32; 300 | const size_t m_dim = 128; 301 | const size_t m_dim_v = 128; 302 | const bool m_is_causal = false; 303 | const bool m_is_alibi = false; 304 | bool m_is_bf16 = false; 305 | const cudaStream_t m_stream = nullptr; 306 | 307 | const size_t m_warmup_iterations = 1; 308 | const size_t m_profiling_iterations = 10; 309 | const size_t m_sleep_duration = 100; 310 | const bool m_enable_check = false; 311 | 312 | size_t m_total_q = 0; 313 | size_t m_total_k = 0; 314 | 315 | std::shared_ptr> m_Q = nullptr; // total_q * head_q * dim 316 | std::shared_ptr> m_K = nullptr; // total_k * head_k * dim 317 | std::shared_ptr> m_V = nullptr; // total_k * head_k * dim_v 318 | std::shared_ptr> m_O = nullptr; // total_q * head_q * dim_v 319 | std::shared_ptr> m_base = 320 | nullptr; // total_q * head_q * dim_v, base result, init tensor O before each dmha 321 | 322 | void *m_Q_dev_ptr = nullptr; // total_q * head_q * dim 323 | void *m_K_dev_ptr = nullptr; // total_k * head_k * dim 324 | void *m_V_dev_ptr = nullptr; // total_k * head_k * dim_v 325 | void *m_O_dev_ptr = nullptr; // total_q * head_q * dim_v 326 | 327 | std::shared_ptr> m_cu_seq_k = nullptr; // batch + 1 328 | 329 | int *m_cu_seq_k_dev_ptr = nullptr; // batch + 1 330 | 331 | std::shared_ptr m_cuda_timer = nullptr; 332 | 333 | double m_warmup_time = 0.0; 334 | double m_profiling_time = 0.0; 335 | double m_throughput = 0.0; 336 | double m_bandwidth = 0.0; 337 | double m_base_time = 0.0; // decoding attn op 338 | double m_base_throughput = 0.0; // decoding attn op 339 | double m_base_bandwidth = 0.0; // decoding attn op 340 | 341 | DA_DISALLOW_COPY_AND_ASSIGN(Tester); 342 | }; 343 | --------------------------------------------------------------------------------
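Both the CUDA kernel and attn_cpu() above use an ALiBi slope of 1 / 2^(8·(h+1)/head_q) per query head and add slope·(key position − query position) to the pre-softmax score; a small sketch of the slope table (head count illustrative):

def alibi_slopes(head_q):
    # 1 / 2^(8 * (h + 1) / head_q), as in attn_cpu() and the kernel's binfo.h_slope.
    return [1.0 / 2.0 ** (8.0 * (h + 1) / head_q) for h in range(head_q)]

slopes = alibi_slopes(8)            # [0.5, 0.25, ..., 0.00390625]
bias = slopes[0] * (10 - 42)        # key at position 10, query at position 42 -> -16.0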