├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── benchmark └── bench_single_decode.ipynb ├── bit_decode ├── __init__.py └── bit_decode_interface.py ├── csrc └── bit_decode │ ├── CMakeLists.txt │ ├── decode_api.cpp │ └── src │ ├── bench_single_packdecode.cu │ ├── flash_api.h │ ├── flash_fwd_kernel.h │ ├── flash_fwd_launch_template.h │ ├── genfile │ ├── flash_fwd_hdim128_fp16_sm80.cu │ ├── flash_fwd_split_hdim128_fp16_sm80_2bit.cu │ ├── flash_fwd_split_hdim128_fp16_sm80_4bit.cu │ ├── flash_qpack_hdim128_fp16_sm80_2bit.cu │ └── flash_qpack_hdim128_fp16_sm80_4bit.cu │ ├── include │ ├── alibi.h │ ├── block_info.h │ ├── dequantize.h │ ├── dropout.h │ ├── flash.h │ ├── kernel_traits.h │ ├── mask.h │ ├── philox.cuh │ ├── qpack.h │ ├── rotary.h │ ├── softmax.h │ ├── static_switch.h │ └── utils.h │ ├── test_batch_packdecode.cu │ └── test_single_packdecode.cu ├── imgs ├── 4090.png ├── a100.png ├── overview.png └── scheme.png ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | env/ 9 | build/ 10 | dist/ 11 | *.log 12 | *.egg-info 13 | 14 | # pyenv 15 | .python-version 16 | 17 | # dotenv 18 | .env 19 | 20 | # virtualenv 21 | .venv/ 22 | venv/ 23 | ENV/ 24 | 25 | # VSCode settings 26 | .vscode 27 | 28 | # IDEA files 29 | .idea 30 | 31 | # OSX dir files 32 | .DS_Store 33 | 34 | # Sublime Text settings 35 | *.sublime-workspace 36 | *.sublime-project 37 | 38 | # PyTorch Source Files 39 | kernels/3rdparty/libtorch/ 40 | 41 | hf-models/ 42 | *.npy 43 | *.pt 44 | 45 | pred/ 46 | pred_e/ 47 | logs/ 48 | 49 | *.so 50 | 51 | libtorch/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "libs/cutlass"] 2 | path = libs/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Dayou Du 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BitDecoding 2 | [![arXiv](https://img.shields.io/badge/arXiv-2503.18773-b31b1b.svg)](https://arxiv.org/abs/2503.18773) 3 | [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) 4 | 5 | BitDecoding is a high-performance, GPU-optimized system 6 | designed to accelerate long-context LLM decoding with a low-bit KV 7 | cache. It achieves a **3-9x speedup** over Flash Attention v2. 8 | ![overview](imgs/overview.png) 9 | ![scheme](imgs/scheme.png) 10 | 11 | ## Benchmark 12 | * Kernel performance on RTX 4090 13 | ![overview](imgs/4090.png) 14 | * Kernel performance on A100 15 | ![overview](imgs/a100.png) 16 | 17 | ## Installation 18 | ``` 19 | git clone --recursive https://github.com/DD-DuDa/BitDecoding.git 20 | conda create -n bitdecode python=3.10 21 | conda activate bitdecode 22 | pip install -r requirements.txt 23 | python setup.py install 24 | ``` 25 | 26 | ## Quick Start 27 | 1. See `benchmark/bench_single_decode.ipynb` 28 | 2. (Optional) Play with the libtorch C++ API 29 | ``` 30 | # download libtorch 31 | 32 | cd BitDecoding/csrc/bit_decode 33 | mkdir build && cd build 34 | cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. 35 | make -j12 36 | ``` 37 | 38 | ## Release Progress 39 | 40 | - [x] Page Implementation 41 | - [ ] End-to-end LLM Inference 42 | - [ ] Hopper Implementation 43 | 44 | 45 | ## Citation 46 | If you find BitDecoding useful or want to use it in your projects, please kindly cite our paper: 47 | ``` 48 | @misc{du2025bitdecodingunlockingtensorcores, 49 | title={BitDecoding: Unlocking Tensor Cores for Long-Context LLMs Decoding with Low-Bit KV Cache}, 50 | author={Dayou Du and Shijie Cao and Jianyi Cheng and Ting Cao and Mao Yang}, 51 | year={2025}, 52 | eprint={2503.18773}, 53 | archivePrefix={arXiv}, 54 | primaryClass={cs.AR}, 55 | url={https://arxiv.org/abs/2503.18773}, 56 | } 57 | ``` 58 | 59 | ## Acknowledgement 60 | BitDecoding is inspired by many open-source libraries, including (but not limited to) [flash-attention](https://github.com/Dao-AILab/flash-attention/tree/main), [flute](https://github.com/HanGuo97/flute), [Atom](https://github.com/efeslab/Atom), [omniserve](https://github.com/mit-han-lab/omniserve), [KIVI](https://github.com/jy-yuan/KIVI). 61 | -------------------------------------------------------------------------------- /bit_decode/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0.post1" 2 | 3 | from bit_decode.bit_decode_interface import ( 4 | kvcache_pack_int, 5 | fwd_kvcache_int 6 | ) 7 | -------------------------------------------------------------------------------- /bit_decode/bit_decode_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Dayou Du. 
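# NOTE: Illustrative usage sketch (documentation only; nothing here runs on import).
# Buffer shapes follow the allocation pattern used in
# csrc/bit_decode/src/bench_single_packdecode.cu for quant_mode="k-channel",
# num_bits=4, group_size=128, and the dtypes mirror the torch::kUInt16 /
# torch::kFloat32 buffers allocated there. Names such as `q`, `k_cache`,
# `v_cache` are placeholders, not part of this module.
#
#   import math, torch
#   from bit_decode import kvcache_pack_int, fwd_kvcache_int
#
#   bs, seqlen_kv, nheads, dim, bits, group = 1, 4096, 32, 128, 4, 128
#   pack = 16 // bits
#   q       = torch.rand(bs, 1, nheads, dim, dtype=torch.half, device="cuda")
#   k_cache = torch.rand(bs, seqlen_kv, nheads, dim, dtype=torch.half, device="cuda")
#   v_cache = torch.rand(bs, seqlen_kv, nheads, dim, dtype=torch.half, device="cuda")
#
#   # Packed KV buffers and per-group quantization params (K quantized per channel).
#   k_pack   = torch.empty(bs, seqlen_kv // pack, nheads, dim, dtype=torch.uint16, device="cuda")
#   k_params = torch.empty(bs, seqlen_kv // group, nheads, dim, dtype=torch.float32, device="cuda")
#   v_pack   = torch.empty(bs, seqlen_kv, nheads, dim // pack, dtype=torch.uint16, device="cuda")
#   v_params = torch.empty(bs, dim // group, nheads, seqlen_kv, dtype=torch.float32, device="cuda")
#   cu_seqlens_k = torch.arange(0, (bs + 1) * seqlen_kv, seqlen_kv, dtype=torch.int32, device="cuda")
#
#   kvcache_pack_int(k_cache, k_pack, k_params, v_cache, v_pack, v_params,
#                    None, cu_seqlens_k, seqlen_kv, "k-channel", group, bits)
#   out = fwd_kvcache_int(q, k_pack, k_params, v_pack, v_params, None,
#                         1.0 / math.sqrt(dim), "k-channel", group, bits)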
2 | 3 | from typing import Optional, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import bit_decode_cuda as bit_decode_cuda 9 | 10 | def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torch.Tensor, 11 | v_cache: torch.Tensor, v_pack: torch.Tensor, v_params: torch.Tensor, 12 | opt_block_table: Optional[torch.Tensor] = None, 13 | cu_seqlens_k: torch.Tensor = None, 14 | seqlen_k: int = 0, 15 | quant_mode: str = "k-channel", 16 | group_size: int = 128, 17 | num_bits: int = 4): 18 | 19 | batch_size, seqlen_k, nheads_k, d = k_cache.shape 20 | 21 | K_unpad = k_cache.reshape(batch_size * seqlen_k, nheads_k, d) 22 | V_unpad = v_cache.reshape(batch_size * seqlen_k, nheads_k, d) 23 | 24 | if num_bits == 4: 25 | bit_decode_cuda.kvcache_pack_i4(K_unpad, k_pack, k_params, 26 | V_unpad, v_pack, v_params, 27 | opt_block_table, 28 | cu_seqlens_k, 29 | seqlen_k, 30 | quant_mode, 31 | group_size 32 | ) 33 | else: 34 | bit_decode_cuda.kvcache_pack_i2(K_unpad, k_pack, k_params, 35 | V_unpad, v_pack, v_params, 36 | opt_block_table, 37 | cu_seqlens_k, 38 | seqlen_k, 39 | quant_mode, 40 | group_size 41 | ) 42 | 43 | def fwd_kvcache_int(q: torch.Tensor, 44 | k_pack: torch.Tensor, k_params: torch.Tensor, 45 | v_pack: torch.Tensor, v_params: torch.Tensor, 46 | opt_block_table: Optional[torch.Tensor] = None, 47 | softmax_scale: float = 1.0, 48 | quant_mode: str = "k-channel", 49 | group_size: int = 128, 50 | num_bits: int = 4): 51 | 52 | if num_bits == 4: 53 | out_bit = bit_decode_cuda.fwd_kvcache_i4( 54 | q, 55 | k_pack, k_params, 56 | v_pack, v_params, 57 | opt_block_table, 58 | softmax_scale, 59 | quant_mode, 60 | group_size, 61 | False, # is_causal 62 | -1, # window_size_left 63 | -1, # window_size_right 64 | 0.0, # softcap 65 | True, # is_rotary_interleaved 66 | 0 # num_splits 67 | ) 68 | else: 69 | out_bit = bit_decode_cuda.fwd_kvcache_i2( 70 | q, 71 | k_pack, k_params, 72 | v_pack, v_params, 73 | opt_block_table, 74 | softmax_scale, 75 | quant_mode, 76 | group_size, 77 | False, # Added 78 | -1, # Added 79 | -1, # Added 80 | 0.0, # Added 81 | True, # Added 82 | 0 # Added 83 | ) 84 | 85 | 86 | return out_bit 87 | -------------------------------------------------------------------------------- /csrc/bit_decode/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22.1) 2 | project(bitdecoding CUDA CXX) 3 | 4 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CUDA_STANDARD 17) 7 | set(CMAKE_CUDA_ARCHITECTURES 80) 8 | 9 | set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/../../libs/cutlass/include) 10 | 11 | # Enable ccache if available 12 | find_program(CCACHE_PROGRAM ccache) 13 | if(CCACHE_PROGRAM) 14 | set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") 15 | set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") 16 | endif() 17 | 18 | find_package(Torch REQUIRED) 19 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 20 | 21 | message(STATUS "Compile testing packdecode kernel.") 22 | add_executable(test_single_packdecode 23 | ${PROJECT_SOURCE_DIR}/src/test_single_packdecode.cu 24 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_sm80.cu 25 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu 26 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu 27 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu 28 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu 29 | 
) 30 | target_link_libraries(test_single_packdecode "${TORCH_LIBRARIES}") 31 | target_include_directories(test_single_packdecode PRIVATE ${INCLUDE_DIR}) 32 | target_compile_options(test_single_packdecode PRIVATE $<$:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>) 33 | 34 | message(STATUS "Compile testing packdecode kernel.") 35 | add_executable(test_batch_packdecode 36 | ${PROJECT_SOURCE_DIR}/src/test_batch_packdecode.cu 37 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_sm80.cu 38 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu 39 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu 40 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu 41 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu 42 | ) 43 | target_link_libraries(test_batch_packdecode "${TORCH_LIBRARIES}") 44 | target_include_directories(test_batch_packdecode PRIVATE ${INCLUDE_DIR}) 45 | target_compile_options(test_batch_packdecode PRIVATE $<$:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>) 46 | 47 | message(STATUS "Compile benchmarking kernel.") 48 | add_executable(bench_single_packdecode 49 | ${PROJECT_SOURCE_DIR}/src/bench_single_packdecode.cu 50 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_sm80.cu 51 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu 52 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu 53 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu 54 | ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu 55 | ) 56 | target_link_libraries(bench_single_packdecode "${TORCH_LIBRARIES}") 57 | target_include_directories(bench_single_packdecode PRIVATE ${INCLUDE_DIR}) 58 | target_compile_options(bench_single_packdecode PRIVATE $<$:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>) 59 | 60 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/bench_single_packdecode.cu: -------------------------------------------------------------------------------- 1 | #include "flash_api.h" 2 | #include 3 | #include 4 | 5 | 6 | template 7 | double TestDecodingKernelPerformance(int seqlen_kv, const std::string& quant_mode, const int group_size, const int repeat) { 8 | const int bs = 1; 9 | const int seqlen_q = 1; 10 | const int pack_nums = 16 / num_bits; 11 | 12 | torch::Tensor Q_host = torch::rand({bs, seqlen_q, num_heads, head_dim}, torch::dtype(torch::kHalf)); 13 | torch::Tensor K_host = torch::ones({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 14 | torch::Tensor V_host = torch::ones({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 15 | 16 | torch::Tensor Q_device = Q_host.to(torch::kCUDA); 17 | torch::Tensor K_device = K_host.to(torch::kCUDA); 18 | torch::Tensor V_device = V_host.to(torch::kCUDA); 19 | 20 | at::Tensor k_pack, k_params, v_pack, v_params; 21 | if (quant_mode == "k-channel") { 22 | k_pack = torch::empty({bs, seqlen_kv / pack_nums, num_heads_kv, head_dim}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 23 | k_params = torch::empty({bs, seqlen_kv / group_size, num_heads_kv, head_dim}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 24 | } else { 25 | k_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 26 | k_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, 
torch::dtype(torch::kFloat32)).to(torch::kCUDA); 27 | } 28 | 29 | v_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 30 | v_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 31 | 32 | // Convert K, V to unpadded format 33 | torch::Tensor K_unpad = K_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim}); 34 | torch::Tensor V_unpad = V_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim}); 35 | 36 | auto cu_seqlens_k = torch::arange(0, (bs + 1) * seqlen_kv, seqlen_kv, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); 37 | std::optional opt_block_table = std::nullopt; 38 | 39 | kvcache_qpack( 40 | K_unpad, k_pack, k_params, 41 | V_unpad, v_pack, v_params, 42 | opt_block_table, 43 | cu_seqlens_k, 44 | seqlen_kv, 45 | quant_mode, 46 | group_size 47 | ); 48 | 49 | at::Tensor K_new_host, V_new_host, K_new_device, V_new_device, seqlens_k; 50 | 51 | const float sm_scale = 1 / std::sqrt(float(head_dim)); 52 | // Warm up 53 | for (int i = 0; i < 5; ++i) 54 | mha_fwd_kvcache(Q_device, 55 | k_pack, k_params, 56 | v_pack, v_params, 57 | opt_block_table, 58 | sm_scale, 59 | quant_mode, 60 | group_size); 61 | 62 | // Benchmark 63 | cudaEvent_t start, end; 64 | cudaEventCreate(&start); 65 | cudaEventCreate(&end); 66 | cudaEventRecord(start); 67 | for (int i = 0; i < repeat; i++) { 68 | mha_fwd_kvcache(Q_device, 69 | k_pack, k_params, 70 | v_pack, v_params, 71 | opt_block_table, 72 | sm_scale, 73 | quant_mode, 74 | group_size); 75 | } 76 | cudaEventRecord(end); 77 | cudaEventSynchronize(end); 78 | 79 | float msec, sec; 80 | cudaEventElapsedTime(&msec, start, end); 81 | msec = msec / repeat; 82 | 83 | return msec; 84 | } 85 | 86 | int main() { 87 | const int num_heads = 32; 88 | const int num_heads_kv = 32; 89 | const int head_dim = 128; 90 | 91 | const std::string quant_mode = "k-channel"; 92 | const int num_bits = 4; 93 | const int group_size = 128; 94 | 95 | const int test_num = 10; 96 | int len_list[test_num]; 97 | len_list[0] = 1024; 98 | for (int i = 1; i < test_num; i++) { 99 | len_list[i] = len_list[i - 1] * 2; 100 | } 101 | 102 | const int outer_repeat = 3, inner_repeat = 3; 103 | printf("\n######## Benchmark single decode ########\n"); 104 | for (int j = 0; j < test_num; j++) { 105 | 106 | int seqlen_kv = len_list[j]; 107 | double max_msec = 0.0; 108 | double min_msec = DBL_MAX; 109 | double total_msec = 0.0; 110 | 111 | for (int k = 0; k < outer_repeat; k++) { 112 | double this_sec = TestDecodingKernelPerformance(seqlen_kv, quant_mode, group_size, inner_repeat); 113 | max_msec = max(max_msec, this_sec); 114 | min_msec = min(min_msec, this_sec); 115 | total_msec += this_sec; 116 | } 117 | 118 | double avg_msec = total_msec / outer_repeat; 119 | printf("seqlen_kv num_heads head_dim = %6d %6d %6d, ", seqlen_kv, num_heads, head_dim); 120 | printf("Time = %12.8lf %12.8lf %12.8lf ms, \n", min_msec, avg_msec, max_msec); 121 | } 122 | 123 | return 0; 124 | } -------------------------------------------------------------------------------- /csrc/bit_decode/src/flash_fwd_launch_template.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include "include/static_switch.h" 10 | #include "include/flash.h" 11 | #include "flash_fwd_kernel.h" 12 | 13 | // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers 14 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 15 | #define ARCH_SUPPORTS_FLASH 16 | #define KERNEL_PARAM_MODIFIER __grid_constant__ 17 | #else 18 | #define KERNEL_PARAM_MODIFIER 19 | #endif 20 | 21 | // Define a macro for unsupported architecture handling to centralize the error message 22 | #define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); 23 | 24 | // Use a macro to clean up kernel definitions 25 | #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ 26 | template \ 27 | __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) 28 | 29 | #define DEFINE_FLASH_QPACK_KERNEL(kernelName, ...) \ 30 | template \ 31 | __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) 32 | 33 | DEFINE_FLASH_QPACK_KERNEL(flash_qpack_kernel) { 34 | #if defined(ARCH_SUPPORTS_FLASH) 35 | flash::compute_qpack(params); 36 | #else 37 | FLASH_UNSUPPORTED_ARCH 38 | #endif 39 | } 40 | 41 | DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { 42 | #if defined(ARCH_SUPPORTS_FLASH) 43 | static_assert(!(Is_causal && Is_local)); // Enforce constraints 44 | flash::compute_attn(params); 45 | #else 46 | FLASH_UNSUPPORTED_ARCH 47 | #endif 48 | } 49 | 50 | DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Paged_KV) { 51 | #if defined(ARCH_SUPPORTS_FLASH) 52 | flash::compute_attn_splitkv(params); 53 | #else 54 | FLASH_UNSUPPORTED_ARCH 55 | #endif 56 | } 57 | 58 | DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { 59 | static_assert(Log_max_splits >= 1); 60 | flash::combine_attn_seqk_parallel(params); 61 | } 62 | 63 | template 64 | void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { 65 | constexpr size_t smem_size = Kernel_traits::kSmemSize; 66 | // printf("smem_size = %d\n", smem_size); 67 | 68 | // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 69 | // https://github.com/kokkos/kokkos-kernels/issues/349 70 | // https://github.com/HazyResearch/flash-attention/issues/21 71 | 72 | const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; 73 | dim3 grid(num_m_block, params.b, params.h); 74 | const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; 75 | const bool is_even_K = params.d == Kernel_traits::kHeadDim; 76 | const bool return_softmax = params.p_ptr != nullptr; 77 | // BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { 78 | // EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { 79 | 80 | // Will only return softmax if dropout, to reduce compilation time. 81 | // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
82 | // If return_softmax, set IsEvenMNConst to false to reduce number of templates 83 | // If head dim > 128, set IsEvenMNConst to false to reduce number of templates 84 | // If Is_local, set Is_causal to false 85 | auto kernel = &flash_fwd_kernel; 86 | // auto kernel = &flash_fwd_kernel; 87 | // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); 88 | // auto kernel = &flash_fwd_kernel; 89 | if (smem_size >= 48 * 1024) { 90 | C10_CUDA_CHECK(cudaFuncSetAttribute( 91 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 92 | } 93 | // int ctas_per_sm; 94 | // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( 95 | // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); 96 | // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); 97 | kernel<<>>(params); 98 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 99 | // }); 100 | // }); 101 | } 102 | 103 | template 104 | void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { 105 | static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); 106 | static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); 107 | constexpr size_t smem_size = Kernel_traits::kSmemSize; 108 | const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; 109 | 110 | dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); 111 | const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; 112 | const bool is_even_K = params.d == Kernel_traits::kHeadDim; 113 | // BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { 114 | // EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { 115 | // LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { 116 | BOOL_SWITCH(params.num_splits > 1, Split, [&] { 117 | // BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { 118 | BOOL_SWITCH(params.block_table != nullptr, Paged_KV, [&] { 119 | // ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { 120 | // SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { 121 | // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. 122 | // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 123 | // If Is_local, set Is_causal to false 124 | // IsEvenMNConst: 0 125 | // IsEvenKConst: 1 126 | // Is_local: 0 127 | // Split: 1 128 | // Append_KV: 129 | // Has_alibi: 0 130 | // Is_softcap: 0 131 | auto kernel = &flash_fwd_splitkv_kernel; 132 | // auto kernel = &flash_fwd_splitkv_kernel; 133 | // auto kernel = &flash_fwd_splitkv_kernel; 134 | if (smem_size >= 48 * 1024) { 135 | C10_CUDA_CHECK(cudaFuncSetAttribute( 136 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 137 | } 138 | kernel<<>>(params); 139 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 140 | // }); 141 | }); 142 | // }); 143 | }); 144 | // }); 145 | // }); 146 | // }); 147 | if (params.num_splits > 1) { 148 | // We want kBlockM to be as small as possible for more parallelism. 
149 | // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. 150 | // If headdim is divisible by 64, then we set kBlockM = 8, etc. 151 | constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); 152 | dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); 153 | EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { 154 | if (params.num_splits <= 2) { 155 | flash_fwd_splitkv_combine_kernel<<>>(params); 156 | } else if (params.num_splits <= 4) { 157 | flash_fwd_splitkv_combine_kernel<<>>(params); 158 | } else if (params.num_splits <= 8) { 159 | flash_fwd_splitkv_combine_kernel<<>>(params); 160 | } else if (params.num_splits <= 16) { 161 | flash_fwd_splitkv_combine_kernel<<>>(params); 162 | } else if (params.num_splits <= 32) { 163 | flash_fwd_splitkv_combine_kernel<<>>(params); 164 | } else if (params.num_splits <= 64) { 165 | flash_fwd_splitkv_combine_kernel<<>>(params); 166 | } else if (params.num_splits <= 128) { 167 | flash_fwd_splitkv_combine_kernel<<>>(params); 168 | } 169 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 170 | }); 171 | } 172 | } 173 | 174 | template 175 | void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { 176 | constexpr static int kBlockM = 16; // Fixed for all head dimensions 177 | // constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); 178 | constexpr static int kBlockN = 256; 179 | 180 | run_flash_splitkv_fwd, Is_causal>(params, stream); 181 | } 182 | 183 | 184 | template 185 | void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { 186 | // constexpr static int Headdim = 128; 187 | // auto dprops = at::cuda::getCurrentDeviceProperties(); 188 | // bool is_sm8x = dprops->major == 8 && dprops->minor > 0; 189 | // DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { 190 | // if constexpr(!Is_dropout) { 191 | // // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), 192 | // // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
193 | // if (is_sm8x) { 194 | // if constexpr(!Is_causal) { 195 | // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 196 | // } else { 197 | // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 198 | // } 199 | // } else { 200 | // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 201 | // } 202 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 203 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 204 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 205 | // // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k 206 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 207 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 208 | // // 1st ones are good for H100, A100 209 | // // 2nd one is good for A6000 bc we get slightly better occupancy 210 | // } else { 211 | // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 212 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 213 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 214 | // // run_flash_fwd, Is_dropout, Is_causal>(params, stream); 215 | // } 216 | // }); 217 | } 218 | 219 | 220 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 221 | // QPack 222 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 223 | 224 | template 225 | void run_flash_qpack(Flash_fwd_params ¶ms, cudaStream_t stream) { 226 | constexpr size_t smem_size = Kernel_traits::kSmemSize; 227 | 228 | const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; 229 | dim3 grid(num_n_block, params.b, params.h); 230 | 231 | auto kernel = &flash_qpack_kernel; 232 | 233 | if (smem_size >= 48 * 1024) { 234 | C10_CUDA_CHECK(cudaFuncSetAttribute( 235 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 236 | } 237 | 238 | kernel<<>>(params); 239 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 240 | 241 | 242 | } 243 | 244 | template 245 | void run_kvcache_qpack_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { 246 | constexpr static int Headdim = 128; 247 | constexpr static int kBlockN = num_bits == 4 ? 128 : 256; 248 | 249 | run_flash_qpack>(params, stream); 250 | } -------------------------------------------------------------------------------- /csrc/bit_decode/src/genfile/flash_fwd_hdim128_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "../flash_fwd_launch_template.h" 6 | 7 | template<> 8 | void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { 9 | run_mha_fwd_hdim128(params, stream); 10 | } 11 | 12 | // template<> 13 | // void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { 14 | // run_mha_fwd_hdim128(params, stream); 15 | // } 16 | 17 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. 
See "generate_kernels.py" 4 | 5 | #include "../flash_fwd_launch_template.h" 6 | 7 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 9 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 10 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "../flash_fwd_launch_template.h" 6 | 7 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 9 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 10 | 11 | 12 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 13 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 14 | // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 15 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "../flash_fwd_launch_template.h" 6 | 7 | // template<> 8 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 9 | // run_kvcache_qpack_hdim128(params, stream); 10 | // } 11 | // template<> 12 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 13 | // run_kvcache_qpack_hdim128(params, stream); 14 | // } 15 | // template<> 16 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 17 | // run_kvcache_qpack_hdim128(params, stream); 18 | // } -------------------------------------------------------------------------------- /csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. 
See "generate_kernels.py" 4 | 5 | #include "../flash_fwd_launch_template.h" 6 | 7 | // template<> 8 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 9 | // run_kvcache_qpack_hdim128(params, stream); 10 | // } 11 | // template<> 12 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 13 | // run_kvcache_qpack_hdim128(params, stream); 14 | // } 15 | template<> 16 | void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 17 | run_kvcache_qpack_hdim128(params, stream); 18 | } 19 | 20 | 21 | // template<> 22 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 23 | // run_kvcache_qpack_hdim128(params, stream); 24 | // } 25 | // template<> 26 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 27 | // run_kvcache_qpack_hdim128(params, stream); 28 | // } 29 | // template<> 30 | // void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { 31 | // run_kvcache_qpack_hdim128(params, stream); 32 | // } 33 | 34 | 35 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/alibi.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "utils.h" 9 | 10 | namespace flash { 11 | 12 | using namespace cute; 13 | 14 | //////////////////////////////////////////////////////////////////////////////////////////////////// 15 | 16 | template 17 | struct Alibi { 18 | 19 | const float alibi_slope; 20 | const int max_seqlen_k, max_seqlen_q; 21 | 22 | __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) 23 | : alibi_slope(alibi_slope) 24 | , max_seqlen_k(max_seqlen_k) 25 | , max_seqlen_q(max_seqlen_q) { 26 | }; 27 | 28 | 29 | template 30 | __forceinline__ __device__ void apply_alibi(Tensor &tensor, 31 | const int col_idx_offset_, 32 | const int row_idx_offset, 33 | const int warp_row_stride) { 34 | // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) 35 | static_assert(Layout::rank == 2, "Only support 2D Tensor"); 36 | const int lane_id = threadIdx.x % 32; 37 | const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; 38 | if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows 39 | #pragma unroll 40 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 41 | const int col_idx_base = col_idx_offset + nj * 8; 42 | #pragma unroll 43 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 44 | const int col_idx = col_idx_base + j; 45 | #pragma unroll 46 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 47 | tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; 48 | } 49 | } 50 | } 51 | } else { // Bias depends on both row_idx and col_idx 52 | #pragma unroll 53 | for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { 54 | const int row_idx_base = row_idx_offset + mi * warp_row_stride; 55 | #pragma unroll 56 | for (int i = 0; i < size<0, 0>(tensor); ++i) { 57 | const int row_idx = row_idx_base + i * 8; 58 | #pragma unroll 59 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 60 | const int col_idx_base = col_idx_offset + nj * 8; 61 | #pragma unroll 62 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 63 | const int col_idx = col_idx_base + j; 64 | tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); 65 | } 66 | } 67 | } 68 | } 69 | } 70 | } 71 | 72 | }; 73 | 74 | } // namespace flash 75 | 
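The register-fragment loops in `apply_alibi` above implement a simple per-element bias on the attention scores. As a point of reference only (a scalar sketch, not used by the kernels; it mirrors the `Is_causal` branches of `apply_alibi`):
```
// Scalar reference for the bias added by flash::Alibi::apply_alibi (sketch only):
// causal kernels add slope * col to every row, while the non-causal path
// subtracts slope * |row + seqlen_k - seqlen_q - col|.
inline float alibi_bias_ref(bool is_causal, float slope,
                            int row_idx, int col_idx,
                            int max_seqlen_q, int max_seqlen_k) {
    if (is_causal) {
        return slope * float(col_idx);
    }
    int dist = row_idx + max_seqlen_k - max_seqlen_q - col_idx;
    return -slope * float(dist < 0 ? -dist : dist);
}
```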
-------------------------------------------------------------------------------- /csrc/bit_decode/src/include/block_info.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | namespace flash { 8 | 9 | //////////////////////////////////////////////////////////////////////////////////////////////////// 10 | 11 | template 12 | struct BlockInfo { 13 | 14 | template 15 | __device__ BlockInfo(const Params ¶ms, const int bidb) 16 | : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) 17 | , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) 18 | , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) 19 | // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. 20 | // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. 21 | , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) 22 | , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) 23 | , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) 24 | { 25 | } 26 | 27 | template 28 | __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 29 | return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; 30 | } 31 | 32 | template 33 | __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 34 | return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; 35 | 36 | // return bidb * batch_stride; 37 | } 38 | 39 | const int sum_s_q; 40 | const int sum_s_k; 41 | const int actual_seqlen_q; 42 | // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 43 | const int leftpad_k; 44 | const int seqlen_k_cache; 45 | const int actual_seqlen_k; 46 | }; 47 | 48 | //////////////////////////////////////////////////////////////////////////////////////////////////// 49 | 50 | } // namespace flash 51 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/dequantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define PRINT(name, content) \ 9 | print(name); \ 10 | print(" : "); \ 11 | print(content); \ 12 | print("\n"); 13 | 14 | #define PRINTTENSOR(name, content) \ 15 | print(name); \ 16 | print(" : "); \ 17 | print_tensor(content); \ 18 | print("\n"); 19 | 20 | namespace quant { 21 | 22 | // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core 23 | // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we 24 | // extensively use `#pragma unroll` throughout the kernel code to guarantee this. 
25 | template 26 | struct Vec { 27 | T elems[n]; 28 | __device__ T& operator[](int i) { 29 | return elems[i]; 30 | } 31 | }; 32 | 33 | 34 | using I4 = Vec; 35 | 36 | // Matrix fragments for tensor core instructions; their precise layout is documented here: 37 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type 38 | using FragA = Vec; 39 | using FragB = Vec; 40 | using FragC = Vec; 41 | using FragS = Vec; // quantization scales 42 | 43 | 44 | // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to 45 | // automatically recognize it in all cases. 46 | template 47 | __device__ inline int lop3(int a, int b, int c) { 48 | int res; 49 | asm volatile( 50 | "lop3.b32 %0, %1, %2, %3, %4;\n" 51 | : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) 52 | ); 53 | return res; 54 | } 55 | 56 | // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. 57 | // We mostly follow the strategy in the link below, with some small changes: 58 | // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h 59 | __device__ inline FragA lop3_dequant(int q) { 60 | const int LO = 0x000f000f; 61 | const int HI = 0x00f000f0; 62 | const int EX = 0x64006400; 63 | 64 | // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue 65 | // immediately before required. 66 | const uint32_t top_i4s = q >> 8; 67 | 68 | // Guarantee that the `(a & b) | c` operations are LOP3s. 69 | int lo_1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); // 0,4 70 | int hi_1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // 1,5 71 | int lo_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, LO, EX); // 2,6 72 | int hi_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 3,7 73 | 74 | // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 75 | const int SUB = 0x64006400; // 0x64086408 76 | const int MUL = 0x2c002c00; // {1/16, 1/16} 77 | const int ADD = 0xd400d400; // 0xd480d480 78 | 79 | FragA frag_a; 80 | frag_a[0] = __hsub2( 81 | *reinterpret_cast(&lo_1), 82 | *reinterpret_cast(&SUB) 83 | ); // 0,4 84 | frag_a[1] = __hfma2( 85 | *reinterpret_cast(&hi_1), 86 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 87 | ); // 1,5 88 | frag_a[2] = __hsub2( 89 | *reinterpret_cast(&lo_2), 90 | *reinterpret_cast(&SUB) 91 | ); // 2,6 92 | frag_a[3] = __hfma2( 93 | *reinterpret_cast(&hi_2), 94 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 95 | ); // 3,7 96 | 97 | return frag_a; 98 | } 99 | 100 | 101 | // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. 102 | // We mostly follow the strategy in the link below, with some small changes: 103 | // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h 104 | __device__ inline FragB lop3_dequant_2bit(int q) { 105 | const int LO = 0x00030003; 106 | const int HI = 0x00300030; 107 | const int EX = 0x64006400; 108 | 109 | // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue 110 | // immediately before required. 111 | const uint32_t top_i4s = q >> 8; 112 | 113 | // Guarantee that the `(a & b) | c` operations are LOP3s. 
114 | int lo_1_a = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); // 0,8 115 | int lo_1_b = lop3<(0xf0 & 0xcc) | 0xaa>(q >> 2, LO, EX); // 1,9 116 | int hi_1_a = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // 2,10 117 | int hi_1_b = lop3<(0xf0 & 0xcc) | 0xaa>(q >> 2, HI, EX); // 3,11 118 | int lo_2_a = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, LO, EX); // 4,12 119 | int lo_2_b = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s >> 2, LO, EX); // 5,13 120 | int hi_2_a = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 6,14 121 | int hi_2_b = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s >> 2, HI, EX); // 7,15 122 | 123 | 124 | // int hi_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 3,7 125 | 126 | // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 127 | const int SUB = 0x64006400; // {1024, 1024} 0x64086408 128 | const int MUL = 0x2c002c00; // {1/16, 1/16} 129 | const int ADD = 0xd400d400; // {-64, -64} 0xd480d480 130 | 131 | FragB frag_b; 132 | frag_b[0] = __hsub2( 133 | *reinterpret_cast(&lo_1_a), 134 | *reinterpret_cast(&SUB) 135 | ); // 0,8 136 | frag_b[1] = __hsub2( 137 | *reinterpret_cast(&lo_1_b), 138 | *reinterpret_cast(&SUB) 139 | ); // 1,9 140 | frag_b[2] = __hfma2( 141 | *reinterpret_cast(&hi_1_a), 142 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 143 | ); // 2,10 144 | frag_b[3] = __hfma2( 145 | *reinterpret_cast(&hi_1_b), 146 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 147 | ); // 3,11 148 | frag_b[4] = __hsub2( 149 | *reinterpret_cast(&lo_2_a), 150 | *reinterpret_cast(&SUB) 151 | ); // 4,12 152 | frag_b[5] = __hsub2( 153 | *reinterpret_cast(&lo_2_b), 154 | *reinterpret_cast(&SUB) 155 | ); // 5,13 156 | frag_b[6] = __hfma2( 157 | *reinterpret_cast(&hi_2_a), 158 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 159 | ); // 6,14 160 | frag_b[7] = __hfma2( 161 | *reinterpret_cast(&hi_2_b), 162 | *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) 163 | ); // 7,15 164 | 165 | return frag_b; 166 | } 167 | 168 | 169 | 170 | ////////////////////////////////////////////////////////////////////////////// 171 | // Loading params 172 | ////////////////////////////////////////////////////////////////////////////// 173 | 174 | template 175 | __forceinline__ __device__ 176 | void 177 | load_params_Kchannel( 178 | Tensor0 & scales, 179 | Tensor1 & zeros, 180 | Tensor2 const& params, 181 | int tidx, 182 | int i, 183 | const int num_params 184 | ) { 185 | CUTE_UNROLL 186 | for (int m = 0; m < size<1>(scales); ++m) { 187 | CUTE_UNROLL 188 | for (int j = 0; j < size<0>(scales); ++j) { 189 | // seems no one can know why is this offset ... 
190 | scales(j, m, i) = params(m * num_params + j % num_params, 0 + 8 * i + 4 * (j / num_params) + tidx % 4); 191 | zeros(j, m, i) = params(m * num_params + j % num_params, 64 + 8 * i + 4 * (j / num_params) + tidx % 4); 192 | } 193 | } 194 | } 195 | 196 | template 197 | __forceinline__ __device__ 198 | void 199 | load_params_Ktensor( 200 | Tensor0_g & scales, 201 | Tensor1_g & zeros, 202 | Tensor2_g const& params, 203 | int tidx, 204 | const int num_params 205 | ) { 206 | CUTE_UNROLL 207 | for (int j = 0; j < size<0>(scales); ++j) { 208 | scales(j) = params(128 * (j / num_params / 2) + 0 + 32 * ((j / num_params) % 2) + tidx / 4, j % num_params); 209 | zeros(j) = params(128 * (j / num_params / 2) + 64 + 32 * ((j / num_params) % 2) + tidx / 4, j % num_params); 210 | // scales(j) = params(0 + 32 * (j / num_params) + tidx / 4, j % num_params); 211 | // zeros(j) = params(64 + 32 * (j / num_params) + tidx / 4, j % num_params); 212 | } 213 | 214 | // CUTE_UNROLL 215 | // for (int j = 0; j < size<0>(scales); ++j) { 216 | // params(0 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = scales(j); 217 | // params(64 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = zeros(j); 218 | // } 219 | } 220 | 221 | template 222 | __forceinline__ __device__ 223 | void 224 | load_params_Vtensor( 225 | Tensor0 & scales, 226 | Tensor1 & zeros, 227 | Tensor2 const& params, 228 | int tidx, 229 | int i, 230 | const int num_params 231 | ) { 232 | const int num_params_2 = num_bits == 2 ? num_params / 2 : num_params; 233 | CUTE_UNROLL 234 | for (int j = 0; j < size<0>(scales); ++j) { 235 | // seems no one can know why is this offset ... 236 | scales(j, i) = params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); 237 | zeros(j, i) = params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); 238 | } 239 | } 240 | 241 | ////////////////////////////////////////////////////////////////////////////// 242 | // Dequantization 243 | ////////////////////////////////////////////////////////////////////////////// 244 | 245 | 246 | template 251 | struct dequant_kc_vt; 252 | 253 | template 257 | struct dequant_kc_vt<2, SourceEngine, SourceLayout, TargetEngine, TargetLayout, ScaleEngine, ScaleLayout, ZeroEngine, ZeroLayout> { 258 | static constexpr int num_bits = 2; 259 | CUTE_DEVICE static 260 | void apply(cute::Tensor const& source, 261 | cute::Tensor const& target, 262 | cute::Tensor const& scales, 263 | cute::Tensor const& zeros, 264 | const int num_params) { 265 | using TQ = cute::uint16_t; 266 | using TQ2 = cute::uint32_t; 267 | using T = typename TargetEngine::value_type; 268 | using T2 = __half2; 269 | const int num_params_ = num_params / 2; // TODO: only for g128 270 | const int pack_num = 4 / num_params_; // TODO: check 4 271 | 272 | // vectorize the source and target 273 | auto scales_vec = cute::recast(scales); 274 | auto zeros_vec = cute::recast(zeros); 275 | auto source_vec = cute::recast(source); 276 | auto target_vec = cute::recast(target); 277 | 278 | const int channel_stride = size<0>(source_vec); 279 | 280 | CUTE_UNROLL 281 | for (int i = 0; i < cute::size<0>(source_vec); ++i) { 282 | 283 | CUTE_UNROLL 284 | for (int p = 0; p < cute::size<1>(source_vec); ++p) { 285 | auto src_crd = cute::make_coord(i, p); 286 | auto src_raw = source_vec(src_crd); 287 | auto src_val = lop3_dequant_2bit(src_raw); 288 | 289 | CUTE_UNROLL 290 | for (int j = 0; j < size<1>(target_vec); ++j) { 291 | target_vec(i, j) = __hfma2(src_val[j], 
scales_vec(i + j / pack_num * channel_stride), zeros_vec(i + j / pack_num * channel_stride)); 292 | } 293 | 294 | // target_vec(i,0) = __hfma2(src_val[0], scales_vec(i), zeros_vec(i)); 295 | // target_vec(i,1) = __hfma2(src_val[1], scales_vec(i + 1 / pack_num * channel_stride), zeros_vec(i + 1 / pack_num * channel_stride)); 296 | // target_vec(i,2) = __hfma2(src_val[2], scales_vec(i + 2 / pack_num * channel_stride), zeros_vec(i + 2 / pack_num * channel_stride)); 297 | // target_vec(i,3) = __hfma2(src_val[3], scales_vec(i + 3 / pack_num * channel_stride), zeros_vec(i + 3 / pack_num * channel_stride)); 298 | // target_vec(i,4) = __hfma2(src_val[4], scales_vec(i + 4 / pack_num * channel_stride), zeros_vec(i + 4 / pack_num * channel_stride)); 299 | // target_vec(i,5) = __hfma2(src_val[5], scales_vec(i + 5 / pack_num * channel_stride), zeros_vec(i + 5 / pack_num * channel_stride)); 300 | // target_vec(i,6) = __hfma2(src_val[6], scales_vec(i + 6 / pack_num * channel_stride), zeros_vec(i + 6 / pack_num * channel_stride)); 301 | // target_vec(i,7) = __hfma2(src_val[7], scales_vec(i + 7 / pack_num * channel_stride), zeros_vec(i + 7 / pack_num * channel_stride)); 302 | 303 | // target_vec(i,0) = __hfma2(src_val[0], scales_vec(0), zeros_vec(0)); 304 | // target_vec(i,1) = __hfma2(src_val[1], scales_vec(0), zeros_vec(0)); 305 | // target_vec(i,2) = __hfma2(src_val[2], scales_vec(0), zeros_vec(0)); 306 | // target_vec(i,3) = __hfma2(src_val[3], scales_vec(0), zeros_vec(0)); 307 | // target_vec(i,4) = __hfma2(src_val[4], scales_vec(0), zeros_vec(0)); 308 | // target_vec(i,5) = __hfma2(src_val[5], scales_vec(0), zeros_vec(0)); 309 | // target_vec(i,6) = __hfma2(src_val[6], scales_vec(0), zeros_vec(0)); 310 | // target_vec(i,7) = __hfma2(src_val[7], scales_vec(0), zeros_vec(0)); 311 | 312 | // target_vec(i,0) = src_val[0]; 313 | // target_vec(i,1) = src_val[1]; 314 | // target_vec(i,2) = src_val[2]; 315 | // target_vec(i,3) = src_val[3]; 316 | // target_vec(i,4) = src_val[4]; 317 | // target_vec(i,5) = src_val[5]; 318 | // target_vec(i,6) = src_val[6]; 319 | // target_vec(i,7) = src_val[7]; 320 | } 321 | } 322 | } 323 | }; 324 | 325 | template 329 | struct dequant_kc_vt<4, SourceEngine, SourceLayout, TargetEngine, TargetLayout, ScaleEngine, ScaleLayout, ZeroEngine, ZeroLayout> { 330 | static constexpr int num_bits = 4; 331 | CUTE_DEVICE static 332 | void apply(cute::Tensor const& source, 333 | cute::Tensor const& target, 334 | cute::Tensor const& scales, 335 | cute::Tensor const& zeros, 336 | const int num_params) { 337 | using TQ = cute::uint16_t; 338 | using TQ2 = cute::uint32_t; 339 | using T = typename TargetEngine::value_type; 340 | using T2 = __half2; 341 | const int pack_num = 4 / num_params; 342 | 343 | // vectorize the source and target 344 | auto scales_vec = cute::recast(scales); 345 | auto zeros_vec = cute::recast(zeros); 346 | auto source_vec = cute::recast(source); 347 | auto target_vec = cute::recast(target); 348 | 349 | const int channel_stride = cute::size<0>(source_vec); 350 | const int scales_stride = cute::size<0>(scales_vec); 351 | 352 | CUTE_UNROLL 353 | for (int i = 0; i < cute::size<0>(source_vec); ++i) // 2 354 | { 355 | CUTE_UNROLL 356 | for (int p = 0; p < cute::size<1>(source_vec); ++p) // 1 357 | { 358 | auto src_crd = cute::make_coord(i, p); 359 | auto src_raw = source_vec(src_crd); 360 | auto src_val = lop3_dequant(src_raw); 361 | 362 | auto col_offset = p * num_bits; 363 | 364 | auto tgt0_crd = cute::make_coord(i, col_offset + 0); 365 | auto tgt1_crd = cute::make_coord(i, 
col_offset + 1); 366 | auto tgt2_crd = cute::make_coord(i, col_offset + 2); 367 | auto tgt3_crd = cute::make_coord(i, col_offset + 3); 368 | 369 | // TODO: hard code for now 2 370 | int params_crd = i; 371 | 372 | target_vec(tgt0_crd) = __hfma2(src_val[0], scales_vec(params_crd + p * scales_stride), zeros_vec(params_crd + p * scales_stride)); 373 | target_vec(tgt1_crd) = __hfma2(src_val[1], scales_vec(params_crd + 1 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 1 / pack_num * channel_stride + p * scales_stride)); 374 | target_vec(tgt2_crd) = __hfma2(src_val[2], scales_vec(params_crd + 2 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 2 / pack_num * channel_stride + p * scales_stride)); 375 | target_vec(tgt3_crd) = __hfma2(src_val[3], scales_vec(params_crd + 3 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 3 / pack_num * channel_stride + p * scales_stride)); 376 | 377 | // target_vec(tgt0_crd) = src_val[0]; 378 | // target_vec(tgt1_crd) = src_val[1]; 379 | // target_vec(tgt2_crd) = src_val[2]; 380 | // target_vec(tgt3_crd) = src_val[3]; 381 | } 382 | } 383 | } 384 | }; 385 | 386 | template 391 | CUTE_DEVICE 392 | void 393 | dequant_Kchannel_Vtensor( 394 | cute::Tensor const& source, 395 | cute::Tensor const& target, 396 | cute::Tensor const& scales_vec, 397 | cute::Tensor const& zeros_vec, 398 | const int num_params=1 399 | ) { 400 | dequant_kc_vt::apply(source, target, scales_vec, zeros_vec, num_params); 401 | } 402 | 403 | template 406 | CUTE_DEVICE 407 | void 408 | dequantize_Ktensor( 409 | cute::Tensor const& source, 410 | cute::Tensor & target, 411 | TensorParamsG1 & scales_k_g_vec, 412 | TensorParamsG2 & zeros_k_g_vec, 413 | int num_bits, 414 | int group_size, 415 | int ii 416 | ) { 417 | using TQ = cute::uint16_t; 418 | using TQ2 = cute::uint32_t; 419 | using T = typename TargetEngine::value_type; 420 | using T2 = __half2; 421 | 422 | static constexpr int kNumBits = 4; 423 | const int num_params = 128 / group_size; 424 | const int ki = size<2>(target) / num_params; 425 | 426 | // vectorize the source and target 427 | auto scales_k_g = cute::recast(scales_k_g_vec); 428 | auto zeros_k_g = cute::recast(zeros_k_g_vec); 429 | auto source_vec = cute::recast(source); 430 | auto target_vec = cute::recast(target); 431 | 432 | const int tile_j = size<2>(target) != size<2>(source) ? 2 : 1; 433 | 434 | CUTE_UNROLL 435 | for (int i = 0; i < cute::size<0>(source_vec); ++i) 436 | { 437 | auto src_crd = cute::make_coord(0, 0, 0); 438 | for (int p = 0; p < tile_j; ++p) { 439 | src_crd = tile_j == 1 ? 
cute::make_coord(i, 0, ii) : cute::make_coord(i, 0, 8 * (ii / 4) + ii % 4 + p * 4); 440 | auto src_raw = source_vec(src_crd); 441 | auto src_val = lop3_dequant(src_raw); 442 | 443 | auto col_offset = p * kNumBits; 444 | 445 | auto tgt0_crd = cute::make_coord(i, col_offset + 0, ii); 446 | auto tgt1_crd = cute::make_coord(i, col_offset + 1, ii); 447 | auto tgt2_crd = cute::make_coord(i, col_offset + 2, ii); 448 | auto tgt3_crd = cute::make_coord(i, col_offset + 3, ii); 449 | 450 | // Create half2 values for scales and zeros 451 | half2 scales_k_g_0 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 0 * num_params))); 452 | half2 scales_k_g_1 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 1 * num_params))); 453 | half2 scales_k_g_2 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 2 * num_params))); 454 | half2 scales_k_g_3 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 3 * num_params))); 455 | 456 | half2 zeros_k_g_0 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 0 * num_params))); 457 | half2 zeros_k_g_1 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 1 * num_params))); 458 | half2 zeros_k_g_2 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 2 * num_params))); 459 | half2 zeros_k_g_3 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 3 * num_params))); 460 | 461 | target_vec(tgt0_crd) = __hfma2(src_val[0], scales_k_g_0, zeros_k_g_0); 462 | target_vec(tgt1_crd) = __hfma2(src_val[1], scales_k_g_1, zeros_k_g_1); 463 | target_vec(tgt2_crd) = __hfma2(src_val[2], scales_k_g_2, zeros_k_g_2); 464 | target_vec(tgt3_crd) = __hfma2(src_val[3], scales_k_g_3, zeros_k_g_3); 465 | 466 | // target_vec(tgt0_crd) = src_val[0]; 467 | // target_vec(tgt1_crd) = src_val[1]; 468 | // target_vec(tgt2_crd) = src_val[2]; 469 | // target_vec(tgt3_crd) = src_val[3]; 470 | } 471 | 472 | 473 | } 474 | 475 | } 476 | 477 | } // namespace quant -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/dropout.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include "philox.cuh" 8 | #include "utils.h" 9 | 10 | namespace flash { 11 | 12 | struct Dropout { 13 | 14 | const unsigned long long seed, offset; 15 | const uint8_t p_dropout_in_uint8_t; 16 | 17 | __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, 18 | const uint8_t p_dropout_in_uint8_t, 19 | const int bid, const int hid, const int tid, const int nheads) 20 | : seed(seed) 21 | , offset(offset + (bid * nheads + hid) * 32 + tid % 32) 22 | , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { 23 | } 24 | 25 | template 26 | __forceinline__ __device__ void apply_dropout(Tensor &tensor_, 27 | int block_row_start, int block_col_start, int block_row_stride) { 28 | // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) 29 | Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout())); 30 | using T = typename Engine::value_type; 31 | auto encode_dropout = [](bool keep, T val) { 32 | return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0)); 33 | }; 34 | static_assert(decltype(size<2>(tensor))::value % 2 == 0); 35 | const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); 36 | const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); 37 | // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } 38 | #pragma unroll 39 | for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { 40 | uint2 rowcol = make_uint2(block_row_start, block_col_start); 41 | #pragma unroll 42 | for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { 43 | // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} 44 | uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); 45 | // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} 46 | uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); 47 | // Special implementation for 16-bit types: we duplicate the threshold to the 48 | // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction 49 | // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, 50 | // and the high 16 bits will be either 0xffff or 0x0000, depending on whether 51 | // the random value is less than the threshold. 52 | // We then do a bit-wise AND between the mask and the original value (in 32-bit). 53 | // We're exploiting the fact that floating point comparison is equivalent to integer 54 | // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 55 | if (!encode_dropout_in_sign_bit 56 | && (std::is_same::value || std::is_same::value)) { 57 | uint16_t rnd_16[16]; 58 | #pragma unroll 59 | for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } 60 | uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); 61 | #pragma unroll 62 | for (int j = 0; j < 2; j++) { 63 | Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); 64 | // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } 65 | // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } 66 | #pragma unroll 67 | for (int i = 0; i < 4; i++) { 68 | uint32_t mask; 69 | asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); 70 | tensor_uint32(i) &= mask; 71 | } 72 | // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } 73 | } 74 | } else { 75 | #pragma unroll 76 | for (int j = 0; j < 2; j++) { 77 | #pragma unroll 78 | for (int i = 0; i < 8; i++) { 79 | tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); 80 | } 81 | Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); 82 | // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } 83 | } 84 | } 85 | // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { 86 | // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); 87 | // // } 88 | } 89 | } 90 | } 91 | 92 | }; 93 | 94 | } // namespace flash 95 | 
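
A note on the dropout path above: the keep probability is pre-encoded as an 8-bit threshold (`p_dropout_in_uint8_t`), each Philox output byte is kept iff `rnd_8[i] <= threshold`, and kept values are later rescaled by `rp_dropout`, the reciprocal of the keep probability (see `flash.h` below). A minimal host-side sketch of that predicate; the `threshold = keep_prob * 255` encoding is an assumption borrowed from upstream FlashAttention's setup code, which is not part of this repository dump:

```
#include <cstdint>
#include <cstdio>
#include <random>

int main() {
    float keep_prob = 0.9f;                       // params.p_dropout stores the keep probability
    // Assumed encoding (mirrors upstream FlashAttention, not shown in this repo):
    uint8_t threshold = (uint8_t)(keep_prob * 255.0f);
    float rp_dropout = 1.0f / keep_prob;          // rescale factor for kept activations

    std::mt19937 rng(0);                          // stand-in for the Philox byte stream
    std::uniform_int_distribution<int> byte(0, 255);

    int kept = 0;
    const int total = 100000;
    for (int i = 0; i < total; ++i) {
        uint8_t r = (uint8_t)byte(rng);
        float x = 1.0f;
        // Same predicate as apply_dropout: keep iff the random byte <= threshold.
        float y = (r <= threshold) ? x * rp_dropout : 0.0f;
        if (y != 0.0f) ++kept;
    }
    std::printf("kept fraction ~= %.3f (target %.3f)\n", kept / (float)total, keep_prob);
    return 0;
}
```
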
-------------------------------------------------------------------------------- /csrc/bit_decode/src/include/flash.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | #ifdef OLD_GENERATOR_PATH 11 | #include 12 | #else 13 | #include 14 | #endif 15 | 16 | #include // For at::cuda::philox::unpack 17 | 18 | constexpr int TOTAL_DIM = 0; 19 | constexpr int H_DIM = 1; 20 | constexpr int D_DIM = 2; 21 | 22 | //////////////////////////////////////////////////////////////////////////////////////////////////// 23 | 24 | struct Qkv_params { 25 | using index_t = int64_t; 26 | 27 | // The QKV matrices. 28 | void *__restrict__ q_ptr; 29 | void *__restrict__ k_ptr; 30 | void *__restrict__ K_pack_ptr; 31 | void *__restrict__ k_pack_new_ptr; 32 | void *__restrict__ k_params_new_ptr; 33 | void *__restrict__ k_params_ptr; 34 | void *__restrict__ v_ptr; 35 | void *__restrict__ v_pack_ptr; 36 | void *__restrict__ v_pack_new_ptr; 37 | void *__restrict__ v_params_ptr; 38 | void *__restrict__ v_params_new_ptr; 39 | 40 | // The stride between rows of the Q, K and V matrices. 41 | index_t q_batch_stride; 42 | 43 | index_t k_batch_stride; 44 | index_t K_pack_batch_stride; 45 | index_t k_pack_new_batch_stride; 46 | index_t k_params_batch_stride; 47 | index_t k_params_new_batch_stride; 48 | 49 | index_t v_batch_stride; 50 | index_t v_pack_batch_stride; 51 | index_t v_pack_new_batch_stride; 52 | index_t v_params_batch_stride; 53 | index_t v_params_new_batch_stride; 54 | 55 | index_t q_row_stride; 56 | 57 | index_t k_row_stride; 58 | index_t K_pack_row_stride; 59 | index_t k_pack_new_row_stride; 60 | index_t k_params_row_stride; 61 | index_t k_params_new_row_stride; 62 | 63 | index_t v_row_stride; 64 | index_t v_pack_row_stride; 65 | index_t v_pack_new_row_stride; 66 | index_t v_params_row_stride; 67 | index_t v_params_new_row_stride; 68 | 69 | index_t q_head_stride; 70 | 71 | index_t k_head_stride; 72 | index_t K_pack_head_stride; 73 | index_t k_pack_new_head_stride; 74 | index_t k_params_head_stride; 75 | index_t k_params_new_head_stride; 76 | 77 | index_t v_head_stride; 78 | index_t v_pack_head_stride; 79 | index_t v_pack_new_head_stride; 80 | index_t v_params_head_stride; 81 | index_t v_params_new_head_stride; 82 | 83 | index_t k_params_dim_stride; 84 | index_t k_params_new_dim_stride; 85 | 86 | index_t v_params_dim_stride; 87 | index_t v_params_new_dim_stride; 88 | 89 | // The number of heads. 90 | int h, h_k; 91 | // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be 92 | // different from nheads (query). 93 | int h_h_k_ratio; // precompute h / h_k, 94 | 95 | std::string quant_mode; 96 | int group_size; 97 | int new_lens; 98 | }; 99 | 100 | //////////////////////////////////////////////////////////////////////////////////////////////////// 101 | 102 | struct Flash_fwd_params : public Qkv_params { 103 | 104 | // The O matrix (output). 105 | void * __restrict__ o_ptr; 106 | void * __restrict__ oaccum_ptr; 107 | 108 | // The stride between rows of O. 109 | index_t o_batch_stride; 110 | index_t o_row_stride; 111 | index_t o_head_stride; 112 | 113 | // The pointer to the P matrix. 114 | void * __restrict__ p_ptr; 115 | 116 | // The pointer to the softmax sum. 
117 | void * __restrict__ softmax_lse_ptr; 118 | void * __restrict__ softmax_lseaccum_ptr; 119 | 120 | // The dimensions. 121 | int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; 122 | 123 | // The scaling factors for the kernel. 124 | float scale_softmax; 125 | float scale_softmax_log2; 126 | 127 | // array of length b+1 holding starting offset of each sequence. 128 | int * __restrict__ cu_seqlens_q; 129 | int * __restrict__ cu_seqlens_k; 130 | int * __restrict__ leftpad_k; 131 | 132 | // If provided, the actual length of each k sequence. 133 | int * __restrict__ seqused_k; 134 | 135 | int *__restrict__ blockmask; 136 | 137 | // The K_new and V_new matrices. 138 | void * __restrict__ knew_ptr; 139 | void * __restrict__ vnew_ptr; 140 | 141 | // The stride between rows of the Q, K and V matrices. 142 | index_t knew_batch_stride; 143 | index_t vnew_batch_stride; 144 | index_t knew_row_stride; 145 | index_t vnew_row_stride; 146 | index_t knew_head_stride; 147 | index_t vnew_head_stride; 148 | 149 | // The cos and sin matrices for rotary embedding. 150 | void * __restrict__ rotary_cos_ptr; 151 | void * __restrict__ rotary_sin_ptr; 152 | 153 | // The indices to index into the KV cache. 154 | int * __restrict__ cache_batch_idx; 155 | 156 | // Paged KV cache 157 | int * __restrict__ block_table; 158 | index_t block_table_batch_stride; 159 | int page_block_size; 160 | int page_block_size_pack; 161 | 162 | // The dropout probability (probability of keeping an activation). 163 | float p_dropout; 164 | // uint32_t p_dropout_in_uint; 165 | // uint16_t p_dropout_in_uint16_t; 166 | uint8_t p_dropout_in_uint8_t; 167 | 168 | // Scale factor of 1 / (1 - p_dropout). 169 | float rp_dropout; 170 | float scale_softmax_rp_dropout; 171 | 172 | // Local window size 173 | int window_size_left, window_size_right; 174 | float softcap; 175 | 176 | // Random state. 177 | at::PhiloxCudaState philox_args; 178 | 179 | // Pointer to the RNG seed (idx 0) and offset (idx 1). 180 | uint64_t * rng_state; 181 | 182 | bool is_bf16; 183 | bool is_causal; 184 | 185 | // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. 186 | // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. 187 | bool is_seqlens_k_cumulative; 188 | 189 | bool is_rotary_interleaved; 190 | 191 | int num_splits; // For split-KV version 192 | 193 | void * __restrict__ alibi_slopes_ptr; 194 | index_t alibi_slopes_batch_stride; 195 | 196 | bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. 197 | bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). 198 | }; 199 | 200 | 201 | //////////////////////////////////////////////////////////////////////////////////////////////////// 202 | 203 | template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); 204 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 205 | template void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream); 206 | 207 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/mask.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 
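
The `*_batch_stride` / `*_row_stride` / `*_head_stride` fields in `Qkv_params` above describe how a logically 4-D `[batch, seqlen, nheads, head_dim]` buffer is addressed: `offset = b * batch_stride + s * row_stride + h * head_stride + d`. A small sketch of that addressing for a contiguous BSHD cache; contiguity is an assumption for illustration, since in practice the strides are whatever the incoming PyTorch tensors carry:

```
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t = int64_t;   // matches Qkv_params::index_t

int main() {
    // Assumed contiguous [batch, seqlen, nheads, head_dim] layout.
    const int b = 2, s = 4, h = 3, d = 8;
    const index_t head_stride  = d;
    const index_t row_stride   = (index_t)h * d;
    const index_t batch_stride = (index_t)s * h * d;

    std::vector<float> k_cache(b * s * h * d);
    for (size_t i = 0; i < k_cache.size(); ++i) k_cache[i] = (float)i;

    // Address element (batch=1, pos=2, head=0, dim=5) the way the kernels do.
    const index_t off = 1 * batch_stride + 2 * row_stride + 0 * head_stride + 5;
    std::printf("offset = %lld, value = %.0f\n", (long long)off, k_cache[off]);
    return 0;
}
```
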
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace flash { 10 | 11 | using namespace cute; 12 | 13 | template 14 | __forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, 15 | const int col_idx_offset_ = 0) { 16 | // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) 17 | static_assert(Layout::rank == 2, "Only support 2D Tensor"); 18 | const int lane_id = threadIdx.x % 32; 19 | const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; 20 | #pragma unroll 21 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 22 | const int col_idx_base = col_idx_offset + nj * 8; 23 | #pragma unroll 24 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 25 | const int col_idx = col_idx_base + j; 26 | if (col_idx >= max_seqlen_k) { 27 | // Without the "make_coord" we get wrong results 28 | #pragma unroll 29 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 30 | tensor(mi, make_coord(j, nj)) = -INFINITY; 31 | } 32 | } 33 | } 34 | } 35 | } 36 | 37 | template 38 | __forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, 39 | const int max_seqlen_k, const int row_idx_offset, 40 | const int max_seqlen_q, const int warp_row_stride, 41 | const int window_size_left, const int window_size_right) { 42 | // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) 43 | static_assert(Layout::rank == 2, "Only support 2D Tensor"); 44 | const int lane_id = threadIdx.x % 32; 45 | const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; 46 | #pragma unroll 47 | for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { 48 | const int row_idx_base = row_idx_offset + mi * warp_row_stride; 49 | #pragma unroll 50 | for (int i = 0; i < size<0, 0>(tensor); ++i) { 51 | const int row_idx = row_idx_base + i * 8; 52 | const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); 53 | const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); 54 | #pragma unroll 55 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 56 | const int col_idx_base = col_idx_offset + nj * 8; 57 | #pragma unroll 58 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 59 | const int col_idx = col_idx_base + j; 60 | if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { 61 | tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; 62 | } 63 | } 64 | } 65 | // if (cute::thread0()) { 66 | // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); 67 | // print(tensor(make_coord(i, mi), _)); 68 | // // print(tensor(_, j + nj * size<1, 0>(tensor))); 69 | // } 70 | } 71 | } 72 | } 73 | 74 | template 75 | __forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, 76 | const int max_seqlen_k, const int row_idx_offset, 77 | const int max_seqlen_q, const int warp_row_stride) { 78 | // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 79 | apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, 80 | max_seqlen_q, warp_row_stride, -1, 0); 81 | } 82 | 83 | template 84 | __forceinline__ __device__ void apply_mask_causal_w_idx( 85 | Tensor &tensor, Tensor const &idx_rowcol, 86 | const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) 87 | { 88 | // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) 89 | static_assert(Layout0::rank == 2, "Only 
support 2D Tensor"); 90 | static_assert(Layout1::rank == 2, "Only support 2D Tensor"); 91 | CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); 92 | CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); 93 | #pragma unroll 94 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 95 | const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); 96 | #pragma unroll 97 | for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { 98 | if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { 99 | tensor(mi, ni) = -INFINITY; 100 | } 101 | } 102 | // if (cute::thread0()) { 103 | // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); 104 | // print(tensor(_, make_coord(j, ni))); 105 | // // print(tensor(_, j + ni * size<1, 0>(tensor))); 106 | // } 107 | } 108 | } 109 | 110 | template 111 | struct Mask { 112 | 113 | const int max_seqlen_k, max_seqlen_q; 114 | const int window_size_left, window_size_right; 115 | const float alibi_slope; 116 | 117 | __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, 118 | const int window_size_left, const int window_size_right, 119 | const float alibi_slope=0.f) 120 | : max_seqlen_k(max_seqlen_k) 121 | , max_seqlen_q(max_seqlen_q) 122 | , window_size_left(window_size_left) 123 | , window_size_right(window_size_right) 124 | , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { 125 | }; 126 | 127 | // Causal_mask: whether this particular iteration needs causal masking 128 | template 129 | __forceinline__ __device__ void apply_mask(Tensor &tensor_, 130 | const int col_idx_offset_, 131 | const int row_idx_offset, 132 | const int warp_row_stride) { 133 | static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); 134 | static_assert(Layout::rank == 3, "Only support 3D Tensor"); 135 | static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); 136 | static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; 137 | // if (cute::thread0()) { printf("max_seqlen_k = %d, Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", max_seqlen_k, Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } 138 | if constexpr (Need_masking) { 139 | // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) 140 | Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); 141 | // Do we need both row and column indices, or just column incides? 
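
Both `apply_mask_local` earlier in this file and the branch that follows derive the valid key-column range for a query row from the same two limits: `col_idx_limit_left = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)` and `col_idx_limit_right = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right)`; causal attention is the special case of an unbounded left window with `window_size_right = 0`. A tiny host-side check of those limits for a single-token decode step (the sequence lengths and window below are arbitrary example values):

```
#include <algorithm>
#include <cstdio>

int main() {
    const int max_seqlen_q = 1, max_seqlen_k = 16;   // single-token decode step
    const int window_left = 4, window_right = 0;     // sliding-window attention
    for (int row_idx = 0; row_idx < max_seqlen_q; ++row_idx) {
        int left  = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_left);
        int right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_right);
        // Columns in [left, right) keep their scores; the rest are set to -INFINITY.
        std::printf("query row %d attends to key columns [%d, %d)\n", row_idx, left, right);
    }
    return 0;
}
```
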
142 | static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; 143 | const int lane_id = threadIdx.x % 32; 144 | const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2 + ((threadIdx.x) / 32) * 8; 145 | 146 | if constexpr (Col_idx_only) { 147 | #pragma unroll 148 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 149 | // const int col_idx_base = col_idx_offset + nj * 8; 150 | const int col_idx_base = col_idx_offset + nj * 32; 151 | #pragma unroll 152 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 153 | const int col_idx = col_idx_base + j; 154 | #pragma unroll 155 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 156 | // No causal, no local 157 | if constexpr (Has_alibi) { 158 | tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; 159 | } 160 | if constexpr (!Is_even_MN) { 161 | if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } 162 | } 163 | } 164 | } 165 | } 166 | } else { 167 | #pragma unroll 168 | for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { 169 | const int row_idx_base = row_idx_offset + mi * warp_row_stride; 170 | #pragma unroll 171 | for (int i = 0; i < size<0, 0>(tensor); ++i) { 172 | const int row_idx = row_idx_base + i * 8; 173 | const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); 174 | const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); 175 | #pragma unroll 176 | for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { 177 | const int col_idx_base = col_idx_offset + nj * 8; 178 | #pragma unroll 179 | for (int j = 0; j < size<1, 0>(tensor); ++j) { 180 | const int col_idx = col_idx_base + j; 181 | if constexpr (Has_alibi) { 182 | if constexpr (Is_causal) { 183 | tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; 184 | } else { 185 | tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); 186 | 187 | } 188 | } 189 | if constexpr (Causal_mask) { 190 | if (col_idx >= col_idx_limit_right) { 191 | tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; 192 | } 193 | } 194 | if constexpr (Is_local) { 195 | if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { 196 | tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; 197 | } 198 | } 199 | if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { 200 | // Causal and Local already handles MN masking 201 | if (col_idx >= max_seqlen_k) { 202 | tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | } 209 | } 210 | } 211 | 212 | 213 | }; 214 | 215 | }; 216 | 217 | } // namespace flash 218 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/philox.cuh: -------------------------------------------------------------------------------- 1 | // Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h 2 | #pragma once 3 | // Philox CUDA. 
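
`philox.cuh`, whose body follows, is a counter-based RNG: the 128 output bits are a pure function of `(seed, subsequence, offset)`, so a dropout mask can be regenerated on demand instead of being stored. The host-side mirror below is only for illustrating that determinism; it assumes little-endian packing of the 64-bit seed/subsequence/offset into 32-bit lanes (as on the GPU), and note this variant runs 7 rounds rather than the 10 of reference Philox4x32-10:

```
#include <cstdint>
#include <cstdio>

struct U4 { uint32_t x, y, z, w; };
struct U2 { uint32_t x, y; };

static U2 mulhilo32(uint32_t a, uint32_t b) {
    uint64_t p = (uint64_t)a * b;
    return { (uint32_t)p, (uint32_t)(p >> 32) };   // {lo, hi}, matching the mul.wide.u32 path
}

static U2 make_lanes(uint64_t v) { return { (uint32_t)v, (uint32_t)(v >> 32) }; }

static U4 single_round(U4 c, U2 k) {
    const uint32_t A = 0xD2511F53, B = 0xCD9E8D57;
    U2 r0 = mulhilo32(A, c.x), r1 = mulhilo32(B, c.z);
    return { r1.y ^ c.y ^ k.x, r1.x, r0.y ^ c.w ^ k.y, r0.x };
}

static U4 philox_host(uint64_t seed, uint64_t subsequence, uint64_t offset) {
    U2 key = make_lanes(seed);
    U2 off = make_lanes(offset), sub = make_lanes(subsequence);
    U4 ctr = { off.x, off.y, sub.x, sub.y };
    for (int i = 0; i < 6; ++i) {
        ctr = single_round(ctr, key);
        key.x += 0x9E3779B9; key.y += 0xBB67AE85;
    }
    return single_round(ctr, key);                 // 7th and final round
}

int main() {
    // Same (seed, subsequence, offset) always yields the same 128 random bits.
    U4 a = philox_host(42, 7, 3), b = philox_host(42, 7, 3);
    std::printf("%08x %08x %08x %08x (deterministic: %s)\n",
                a.x, a.y, a.z, a.w, (a.x == b.x && a.w == b.w) ? "yes" : "no");
    return 0;
}
```
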
4 | 5 | namespace flash { 6 | 7 | struct ull2 { 8 | unsigned long long x; 9 | unsigned long long y; 10 | }; 11 | 12 | __forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { 13 | uint2 *res; 14 | unsigned long long tmp; 15 | asm ("mul.wide.u32 %0, %1, %2;\n\t" 16 | : "=l"(tmp) 17 | : "r"(a), "r"(b)); 18 | res = (uint2*)(&tmp); 19 | return *res; 20 | } 21 | 22 | __forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { 23 | constexpr unsigned long kPhiloxSA = 0xD2511F53; 24 | constexpr unsigned long kPhiloxSB = 0xCD9E8D57; 25 | uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); 26 | uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); 27 | uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; 28 | return ret; 29 | } 30 | 31 | __forceinline__ __device__ uint4 philox(unsigned long long seed, 32 | unsigned long long subsequence, 33 | unsigned long long offset) { 34 | constexpr unsigned long kPhilox10A = 0x9E3779B9; 35 | constexpr unsigned long kPhilox10B = 0xBB67AE85; 36 | uint2 key = reinterpret_cast(seed); 37 | uint4 counter; 38 | ull2 *tmp = reinterpret_cast(&counter); 39 | tmp->x = offset; 40 | tmp->y = subsequence; 41 | #pragma unroll 42 | for (int i = 0; i < 6; i++) { 43 | counter = philox_single_round(counter, key); 44 | key.x += (kPhilox10A); 45 | key.y += (kPhilox10B); 46 | } 47 | uint4 output = philox_single_round(counter, key); 48 | return output; 49 | } 50 | 51 | } // namespace flash 52 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/qpack.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "utils.h" 6 | 7 | namespace quant { 8 | 9 | using namespace cute; 10 | 11 | template 12 | CUTE_DEVICE 13 | void thread_reduce_(Tensor0 const& tensor, Tensor1& summary, Operator& op, const int num_params) { 14 | const int pack_num = size<1>(tensor) / num_params; 15 | 16 | CUTE_UNROLL 17 | for (int mi = 0; mi < size<0>(summary); ++mi) { 18 | int col_start = (mi / 4) * pack_num; 19 | summary(mi) = tensor(mi % 4, col_start); 20 | 21 | CUTE_UNROLL 22 | for (int ni = col_start; ni < col_start + pack_num; ++ni) { 23 | summary(mi) = op(summary(mi), tensor(mi % 4, ni)); 24 | } 25 | 26 | } 27 | 28 | } 29 | 30 | template 31 | __device__ __forceinline__ T warp_reduce(T val, Operator op) { 32 | // Get the thread's position within its group of 4 33 | const int lane_id = threadIdx.x % 32; // Lane ID within warp 34 | const int group_pos = lane_id % 4; // Position within group of 4 35 | 36 | // Only reduce with threads that have the same position in their group of 4 37 | // Using butterfly pattern with xor 38 | for (int mask = 16; mask > 0; mask >>= 1) { 39 | T other = __shfl_xor_sync(0xffffffff, val, mask); 40 | // Only combine if the other thread has the same group_pos 41 | if ((lane_id ^ mask) < 32 && ((lane_id ^ mask) % 4 == group_pos)) { 42 | val = op(val, other); 43 | } 44 | } 45 | return val; 46 | } 47 | 48 | template 49 | CUTE_DEVICE 50 | void allreduce_(Tensor0 &dst, Tensor1 &src, Tensor2 &reduce_tmp, Operator &op) { 51 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 52 | 53 | const int warp_id = threadIdx.x / 32; 54 | const int lane_id = threadIdx.x % 32; 55 | 56 | #pragma unroll 57 | for (int i = 0; i < size(dst); i++) { 58 | // First do reduction within each group of 4 threads 59 | float val = quant::warp_reduce(src(i), op); 60 | // Write the result to shared memory for each 
group's leader 61 | if (lane_id < 4) { 62 | reduce_tmp(i,warp_id * 4 + lane_id) = val; 63 | } 64 | __syncthreads(); 65 | 66 | // First thread in the first group reads all values and reduces them 67 | if (lane_id < 4) { 68 | float final_val = reduce_tmp(i,0 + lane_id); 69 | #pragma unroll 70 | for (int w = 1; w < 4; w++) { // For 4 warps 71 | final_val = op(final_val, reduce_tmp(i,w * 4 + lane_id)); 72 | } 73 | // Write back the final result 74 | reduce_tmp(i, 0 + lane_id) = final_val; 75 | } 76 | __syncthreads(); 77 | 78 | // All threads read the final result 79 | dst(i) = reduce_tmp(i,0 + lane_id % 4); 80 | 81 | } 82 | 83 | 84 | } 85 | 86 | template 87 | CUTE_DEVICE 88 | void reduce_(Tensor const& tensor, Tensor& summary, Tensor2 &reduce_tmp, Operator& op, const int num_params) { 89 | quant::thread_reduce_(tensor, summary, op, num_params); 90 | quant::allreduce_(summary, summary, reduce_tmp, op); 91 | } 92 | 93 | template 94 | CUTE_DEVICE 95 | void reduce_max(Tensor const& tensor, Tensor &max, Tensor2 &reduce_tmp, const int num_params) { 96 | flash::MaxOp max_op; 97 | quant::reduce_(tensor, max, reduce_tmp, max_op, num_params); // Use the existing reduce_q function 98 | } 99 | 100 | template 101 | CUTE_DEVICE 102 | void reduce_min(Tensor const& tensor, Tensor &min, Tensor2 &reduce_tmp, const int num_params) { 103 | flash::MinOp min_op; 104 | quant::reduce_(tensor, min, reduce_tmp, min_op, num_params); // Use the existing reduce_q function 105 | } 106 | 107 | template 108 | struct qpack_kc_vt; 109 | 110 | template 111 | struct qpack_kc_vt<2, Tensor1, Tensor2, Tensor3, Tensor4, Tensor5> { 112 | static constexpr int num_bits = 2; // Add this line 113 | CUTE_DEVICE static 114 | void apply(Tensor1 &src, Tensor2 &dst, Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, const int num_params) { 115 | const float max_val = float((1 << num_bits) - 1); 116 | const int pack_num = 4 / (num_params / 2); // TODO: check 4 117 | const int num_params_2 = size<1>(src) == 4 ? num_params / 2 : num_params; // TODO: change name? seems hard code? 118 | const int channel_stride = size<0>(src); 119 | 120 | // Declare per-channel tensors 121 | using TensorChannel = decltype(make_fragment_like(scales_k(_, 0))); 122 | TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; 123 | 124 | CUTE_UNROLL 125 | for (int k = 0; k < size<2>(src); ++k) { 126 | // Perform per-channel max and min reductions 127 | quant::reduce_max(src(_, _, k), channel_max, reduce_tmp, num_params_2); 128 | quant::reduce_min(src(_, _, k), channel_min, reduce_tmp, num_params_2); 129 | 130 | // Compute per-channel scale inverses and zeros 131 | CUTE_UNROLL 132 | for (int i = 0; i < size(channel_max); ++i) { 133 | float max_i = float(channel_max(i)); 134 | float min_i = float(channel_min(i)); 135 | float range = max_i - min_i; 136 | // Avoid division by zero 137 | float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; 138 | channel_scales_inv(i) = scale_inv; 139 | channel_zeros(i) = min_i; 140 | // Store scales and zeros 141 | scales_k(i, k) = scale_inv == 0 ? 
0.0f : 1.0f / scale_inv; // Store actual scale 142 | zeros_k(i, k) = min_i; 143 | } 144 | 145 | // Quantize and pack the tensor 146 | CUTE_UNROLL 147 | for (int i = 0; i < size<0>(src); ++i) { 148 | 149 | CUTE_UNROLL 150 | for (int jj = 0; jj < size<1>(src); jj += 8) { 151 | // float val0 = float(src(i, jj, k)); 152 | // float val1 = float(src(i, jj + 1, k)); 153 | // float val2 = float(src(i, jj + 2, k)); 154 | // float val3 = float(src(i, jj + 3, k)); 155 | // float val4 = float(src(i, jj + 4, k)); 156 | // float val5 = float(src(i, jj + 5, k)); 157 | // float val6 = float(src(i, jj + 6, k)); 158 | // float val7 = float(src(i, jj + 7, k)); 159 | 160 | // Load 4 values and convert to float 161 | float val0 = float(src(i, jj, k)) - channel_zeros(i + (jj ) / pack_num * channel_stride); 162 | float val1 = float(src(i, jj + 1, k)) - channel_zeros(i + (jj + 1) / pack_num * channel_stride); 163 | float val2 = float(src(i, jj + 2, k)) - channel_zeros(i + (jj + 2) / pack_num * channel_stride); 164 | float val3 = float(src(i, jj + 3, k)) - channel_zeros(i + (jj + 3) / pack_num * channel_stride); 165 | float val4 = float(src(i, jj + 4, k)) - channel_zeros(i + (jj + 4) / pack_num * channel_stride); 166 | float val5 = float(src(i, jj + 5, k)) - channel_zeros(i + (jj + 5) / pack_num * channel_stride); 167 | float val6 = float(src(i, jj + 6, k)) - channel_zeros(i + (jj + 6) / pack_num * channel_stride); 168 | float val7 = float(src(i, jj + 7, k)) - channel_zeros(i + (jj + 7) / pack_num * channel_stride); 169 | 170 | // Apply scale inverses 171 | val0 *= channel_scales_inv(i + (jj ) / pack_num * channel_stride); 172 | val1 *= channel_scales_inv(i + (jj + 1) / pack_num * channel_stride); 173 | val2 *= channel_scales_inv(i + (jj + 2) / pack_num * channel_stride); 174 | val3 *= channel_scales_inv(i + (jj + 3) / pack_num * channel_stride); 175 | val4 *= channel_scales_inv(i + (jj + 4) / pack_num * channel_stride); 176 | val5 *= channel_scales_inv(i + (jj + 5) / pack_num * channel_stride); 177 | val6 *= channel_scales_inv(i + (jj + 6) / pack_num * channel_stride); 178 | val7 *= channel_scales_inv(i + (jj + 7) / pack_num * channel_stride); 179 | 180 | // Round and clamp the values 181 | val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); 182 | val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); 183 | val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); 184 | val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); 185 | val4 = fminf(fmaxf(roundf(val4), 0.0f), max_val); 186 | val5 = fminf(fmaxf(roundf(val5), 0.0f), max_val); 187 | val6 = fminf(fmaxf(roundf(val6), 0.0f), max_val); 188 | val7 = fminf(fmaxf(roundf(val7), 0.0f), max_val); 189 | 190 | // Pack 8 values (2-bit each) into a 16-bit integer 191 | uint16_t packed = 0; 192 | packed |= (static_cast(static_cast(val7)) & 0x3); // 2 bits 193 | packed <<= 2; 194 | packed |= (static_cast(static_cast(val6)) & 0x3); 195 | packed <<= 2; 196 | packed |= (static_cast(static_cast(val5)) & 0x3); 197 | packed <<= 2; 198 | packed |= (static_cast(static_cast(val4)) & 0x3); 199 | packed <<= 2; 200 | packed |= (static_cast(static_cast(val3)) & 0x3); 201 | packed <<= 2; 202 | packed |= (static_cast(static_cast(val2)) & 0x3); 203 | packed <<= 2; 204 | packed |= (static_cast(static_cast(val1)) & 0x3); 205 | packed <<= 2; 206 | packed |= (static_cast(static_cast(val0)) & 0x3); 207 | 208 | // Store the packed value 209 | dst(i, jj / 8, k) = packed; 210 | } 211 | } 212 | } 213 | 214 | } 215 | 216 | 217 | }; 218 | 219 | template 220 | struct qpack_kc_vt<4, Tensor1, Tensor2, Tensor3, Tensor4, 
Tensor5> { 221 | static constexpr int num_bits = 4; // Add this line 222 | CUTE_DEVICE static 223 | void apply(Tensor1 &src, Tensor2 &dst, Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, const int num_params) { 224 | const float max_val = float((1 << num_bits) - 1); 225 | const int pack_num = size<1>(src) / num_params; 226 | const int channel_stride = size<0>(src); 227 | 228 | // Declare per-channel tensors 229 | using TensorChannel = decltype(make_fragment_like(scales_k(_, 0))); 230 | TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; 231 | 232 | 233 | CUTE_UNROLL 234 | for (int k = 0; k < size<2>(src); ++k) { 235 | // Perform per-channel max and min reductions 236 | quant::reduce_max(src(_, _, k), channel_max, reduce_tmp, num_params); 237 | quant::reduce_min(src(_, _, k), channel_min, reduce_tmp, num_params); 238 | 239 | // Compute per-channel scale inverses and zeros 240 | CUTE_UNROLL 241 | for (int i = 0; i < size(channel_max); ++i) { 242 | float max_i = float(channel_max(i)); 243 | float min_i = float(channel_min(i)); 244 | float range = max_i - min_i; 245 | // Avoid division by zero 246 | float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; 247 | channel_scales_inv(i) = scale_inv; 248 | channel_zeros(i) = min_i; 249 | // Store scales and zeros 250 | scales_k(i, k) = scale_inv == 0 ? 0.0f : 1.0f / scale_inv; // Store actual scale 251 | zeros_k(i, k) = min_i; 252 | } 253 | 254 | // Quantize and pack the tensor 255 | CUTE_UNROLL 256 | for (int i = 0; i < size<0>(src); ++i) { 257 | 258 | CUTE_UNROLL 259 | for (int jj = 0; jj < size<1>(src); jj += 4) { 260 | // float val0 = float(src(i, jj, k)); 261 | // float val1 = float(src(i, jj + 1, k)); 262 | // float val2 = float(src(i, jj + 2, k)); 263 | // float val3 = float(src(i, jj + 3, k)); 264 | 265 | // Load 4 values and convert to float 266 | float val0 = float(src(i, jj, k)) - channel_zeros(i + (jj ) / pack_num * channel_stride); 267 | float val1 = float(src(i, jj + 1, k)) - channel_zeros(i + (jj + 1) / pack_num * channel_stride); 268 | float val2 = float(src(i, jj + 2, k)) - channel_zeros(i + (jj + 2) / pack_num * channel_stride); 269 | float val3 = float(src(i, jj + 3, k)) - channel_zeros(i + (jj + 3) / pack_num * channel_stride); 270 | 271 | // Apply scale inverses 272 | val0 *= channel_scales_inv(i + (jj ) / pack_num * channel_stride); 273 | val1 *= channel_scales_inv(i + (jj + 1) / pack_num * channel_stride); 274 | val2 *= channel_scales_inv(i + (jj + 2) / pack_num * channel_stride); 275 | val3 *= channel_scales_inv(i + (jj + 3) / pack_num * channel_stride); 276 | 277 | // Round and clamp the values 278 | val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); 279 | val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); 280 | val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); 281 | val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); 282 | 283 | // Pack the 4 quantized values into a 16-bit integer 284 | uint16_t packed = 0; 285 | packed |= (static_cast(static_cast(val3)) & 0xF); 286 | packed <<= 4; 287 | packed |= (static_cast(static_cast(val2)) & 0xF); 288 | packed <<= 4; 289 | packed |= (static_cast(static_cast(val1)) & 0xF); 290 | packed <<= 4; 291 | packed |= (static_cast(static_cast(val0)) & 0xF); 292 | 293 | // Store the packed value 294 | dst(i, jj / 4, k) = packed; 295 | } 296 | } 297 | } 298 | 299 | } 300 | }; 301 | 302 | template 303 | CUTE_DEVICE 304 | void qpack_Kchannel_Vtensor(Tensor1 &src, Tensor2 &dst, 305 | Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, 306 | const 
int num_params = 1) { 307 | 308 | qpack_kc_vt::apply(src, dst, scales_k, zeros_k, reduce_tmp, num_params); 309 | 310 | } 311 | 312 | //////////////////////////////////////////////////////////////////////////////////////////////////// 313 | 314 | template 315 | CUTE_DEVICE 316 | void quad_allreduce_g(TensorParamsG0 &dst, Tensor1 &src, Operator &op, int k, int num_params) { 317 | CUTE_UNROLL 318 | for (int i = k * num_params; i < (k + 1) * num_params; i++) { 319 | 320 | // Calculate which group of 4 this thread belongs to 321 | const int group_id = threadIdx.x / 4; 322 | const int group_base = group_id * 4; 323 | 324 | // Start with the value from the first thread in our group 325 | auto val = __shfl_sync(uint32_t(-1), src(i), group_base); 326 | 327 | // Reduce with the other 3 threads in our group 328 | #pragma unroll 329 | for (int offset = 1; offset < 4; offset++) { 330 | val = op(val, __shfl_sync(uint32_t(-1), src(i), group_base + offset)); 331 | } 332 | 333 | // Broadcast the final result back to all threads in the group 334 | dst(i) = val; 335 | 336 | } 337 | } 338 | 339 | template 340 | CUTE_DEVICE 341 | void thread_reduce_g(Tensor0 const& tensor, TensorParamsG0& summary, Operator& op, int k, int num_params) { 342 | CUTE_UNROLL 343 | for (int i = k * num_params, j = 0; i < (k + 1) * num_params; i++, j++) { 344 | int ii = size<1>(tensor) / num_params; 345 | summary(i) = tensor(0, j * ii); 346 | 347 | CUTE_UNROLL 348 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 349 | CUTE_UNROLL 350 | for (int ni = j * ii; ni < (j + 1) * ii; ++ni) { 351 | summary(i) = op(summary(i), tensor(mi, ni)); 352 | } 353 | } 354 | } 355 | } 356 | 357 | template 358 | CUTE_DEVICE 359 | void reduce_g(Tensor const& tensor, TensorParamsG0& summary, Operator& op, int k, int num_params) { 360 | quant::thread_reduce_g(tensor, summary, op, k, num_params); 361 | quant::quad_allreduce_g(summary, summary, op, k, num_params); 362 | } 363 | 364 | template 365 | CUTE_DEVICE 366 | void reduce_max_g(Tensor const& tensor, TensorParamsG0 &max, int k, int num_params) { 367 | flash::MaxOp max_op; 368 | quant::reduce_g(tensor, max, max_op, k, num_params); // Use the existing reduce_q function 369 | } 370 | 371 | template 372 | CUTE_DEVICE 373 | void reduce_min_g(Tensor const& tensor, TensorParamsG0 &min, int k, int num_params) { 374 | flash::MinOp min_op; 375 | quant::reduce_g(tensor, min, min_op, k, num_params); // Use the existing reduce_q function 376 | } 377 | 378 | template 379 | CUTE_DEVICE 380 | void quant_Ktensor(Tensor1 &src, Tensor2 &dst, 381 | TensorParamsG1 &scales_k_g, TensorParamsG2 &zeros_k_g, 382 | const int num_params) { 383 | 384 | const int num_bits = 4; 385 | 386 | const float max_val = float((1 << num_bits) - 1); 387 | // const int num_params = 128 / group_size; 388 | const int ki = size<2>(src) / num_params; 389 | 390 | // Declare per-channel tensors 391 | using TensorChannel = decltype(make_fragment_like(scales_k_g)); 392 | TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; 393 | 394 | CUTE_UNROLL 395 | for (int k = 0; k < size<1>(src); ++k) { 396 | quant::reduce_max_g(src(_, k, _), channel_max, k, num_params); // TODO:check 128 397 | quant::reduce_min_g(src(_, k, _), channel_min, k, num_params); 398 | } 399 | 400 | // Compute per-channel scale inverses and zeros 401 | CUTE_UNROLL 402 | for (int i = 0; i < size(channel_max); ++i) { 403 | float max_i = float(channel_max(i)); 404 | float min_i = float(channel_min(i)); 405 | float range = max_i - min_i; 406 | // Avoid 
division by zero 407 | float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; 408 | channel_scales_inv(i) = scale_inv; 409 | channel_zeros(i) = min_i; 410 | // Store scales and zeros 411 | scales_k_g(i) = scale_inv == 0 ? 0.0f : 1.0f / scale_inv; // Store actual scale 412 | zeros_k_g(i) = min_i; 413 | } 414 | 415 | // Pack the tensor 416 | CUTE_UNROLL 417 | for (int k = 0; k < size<2>(src); ++k) { 418 | 419 | CUTE_UNROLL 420 | for (int i = 0; i < size<0>(src); ++i) { 421 | 422 | CUTE_UNROLL 423 | for (int jj = 0; jj < size<1>(src); jj += 4) { 424 | float zero0 = float(channel_zeros(k / ki + jj + 0 * num_params)); 425 | float zero1 = float(channel_zeros(k / ki + jj + 1 * num_params)); 426 | float zero2 = float(channel_zeros(k / ki + jj + 2 * num_params)); 427 | float zero3 = float(channel_zeros(k / ki + jj + 3 * num_params)); 428 | 429 | float scale_inv0 = float(channel_scales_inv(k / ki + jj + 0 * num_params)); 430 | float scale_inv1 = float(channel_scales_inv(k / ki + jj + 1 * num_params)); 431 | float scale_inv2 = float(channel_scales_inv(k / ki + jj + 2 * num_params)); 432 | float scale_inv3 = float(channel_scales_inv(k / ki + jj + 3 * num_params)); 433 | 434 | // float val0 = float(src(i, jj, k)); 435 | // float val1 = float(src(i, jj + 1, k)); 436 | // float val2 = float(src(i, jj + 2, k)); 437 | // float val3 = float(src(i, jj + 3, k)); 438 | 439 | float val0 = float(src(i, jj, k)) - zero0; 440 | float val1 = float(src(i, jj + 1, k)) - zero1; 441 | float val2 = float(src(i, jj + 2, k)) - zero2; 442 | float val3 = float(src(i, jj + 3, k)) - zero3; 443 | 444 | val0 *= scale_inv0; 445 | val1 *= scale_inv1; 446 | val2 *= scale_inv2; 447 | val3 *= scale_inv3; 448 | 449 | // Round and clamp the values 450 | val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); 451 | val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); 452 | val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); 453 | val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); 454 | 455 | // Pack the 4 quantized values into a 16-bit integer 456 | uint16_t packed = 0; 457 | packed |= (static_cast(static_cast(val3)) & 0xF); 458 | packed <<= 4; 459 | packed |= (static_cast(static_cast(val2)) & 0xF); 460 | packed <<= 4; 461 | packed |= (static_cast(static_cast(val1)) & 0xF); 462 | packed <<= 4; 463 | packed |= (static_cast(static_cast(val0)) & 0xF); 464 | 465 | // Store the packed value 466 | dst(i, jj / 4, k) = packed; 467 | } 468 | } 469 | } 470 | } 471 | 472 | //////////////////////////////////////////////////////////////////////////////////////////////////// 473 | 474 | template 477 | CUTE_DEVICE 478 | void pack_Ktensor_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, Tensor1 &dst_r2s, 479 | TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, 480 | Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, 481 | const int num_params) { 482 | // copy from register to shared memory 483 | cute::copy(smem_tiled_copy, src_r2s, dst_r2s); 484 | __syncthreads(); 485 | 486 | // copy from shared memory to global memory 487 | cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); 488 | __syncthreads(); 489 | 490 | // copy params from register to global memory 491 | CUTE_UNROLL 492 | for (int j = 0; j < size<0>(scales); ++j) { 493 | params(0 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = scales(j); 494 | params(64 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = zeros(j); 495 | } 496 | __syncthreads(); 497 | } 498 | 499 | template 502 | CUTE_DEVICE 503 | void pack_Kchannel_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, 
Tensor1 &dst_r2s, 504 | TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, 505 | Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, 506 | const int num_params) { 507 | // copy from register to shared memory 508 | cute::copy(smem_tiled_copy, src_r2s, dst_r2s); 509 | __syncthreads(); 510 | 511 | // copy from shared memory to global memory 512 | cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); 513 | __syncthreads(); 514 | 515 | // copy params from register to global memory 516 | CUTE_UNROLL 517 | for (int i = 0; i < size<1>(scales); ++i) { 518 | CUTE_UNROLL 519 | for (int j = 0; j < size<0>(scales); ++j) { 520 | params(j % num_params, 0 + 8 * i + 4 * (j / num_params) + threadIdx.x % 4) = scales(j, i); 521 | params(j % num_params, 64 + 8 * i + 4 * (j / num_params) + threadIdx.x % 4) = zeros(j, i); 522 | } 523 | } 524 | __syncthreads(); 525 | } 526 | 527 | template 531 | CUTE_DEVICE 532 | void pack_Vtensor_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, Tensor1 &dst_r2s, 533 | TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, 534 | Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, 535 | const int num_params) { 536 | if (kHeadDim == 128 && num_bits == 2) { 537 | if (threadIdx.x < 64) { 538 | cute::copy(smem_tiled_copy, src_r2s, dst_r2s); 539 | } 540 | } else { 541 | cute::copy(smem_tiled_copy, src_r2s, dst_r2s); 542 | } 543 | __syncthreads(); 544 | 545 | // copy from shared memory to global memory 546 | cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); 547 | __syncthreads(); 548 | 549 | // copy params from register to global memory 550 | const int num_params_2 = num_bits == 2 ? num_params / 2 : num_params; 551 | CUTE_UNROLL 552 | for (int i = 0; i < size<1>(scales); ++i) { 553 | CUTE_UNROLL 554 | for (int j = 0; j < size<0>(scales); ++j) { 555 | params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + threadIdx.x % 4, j % num_params_2) = scales(j, i); 556 | params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + threadIdx.x % 4, j % num_params_2) = zeros(j, i); 557 | } 558 | } 559 | __syncthreads(); 560 | } 561 | 562 | } // namespace quant 563 | 564 | 565 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/rotary.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 
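
Tying `qpack.h` back to `dequantize.h`: packing computes per-channel `scale_inv = (2^bits - 1) / (max - min)` and `zero = min`, quantizes with `q = clamp(round((x - zero) * scale_inv), 0, 2^bits - 1)`, stores `scale = 1 / scale_inv` together with `zero`, and packs the codes LSB-first into `uint16_t` words; dequantization later reconstructs `q * scale + zero` (the `__hfma2` path). A standalone roundtrip sketch of the 4-bit case over one 4-element group; the real kernels quantize whole tiles through the CUTe layouts shown above, so this only checks the arithmetic and the packing order:

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int bits = 4;
    const float max_q = float((1 << bits) - 1);               // 15 codes above zero
    std::vector<float> x = {0.11f, -0.52f, 0.37f, 0.90f};     // one 4-element group

    float mn = *std::min_element(x.begin(), x.end());
    float mx = *std::max_element(x.begin(), x.end());
    float range = mx - mn;
    float scale_inv = range > 0.0f ? max_q / range : 0.0f;
    float scale = scale_inv == 0.0f ? 0.0f : 1.0f / scale_inv; // what scales_k stores
    float zero  = mn;                                          // what zeros_k stores

    // Quantize and pack four 4-bit codes into one uint16_t, val0 in the lowest
    // nibble (same order as the packed |= ...; packed <<= 4; loop in qpack_kc_vt<4>).
    uint16_t packed = 0;
    for (int j = 3; j >= 0; --j) {
        int q = (int)std::fmin(std::fmax(std::round((x[j] - zero) * scale_inv), 0.0f), max_q);
        packed = (uint16_t)((packed << 4) | (uint16_t)(q & 0xF));
    }

    // Unpack and dequantize as q * scale + zero (the __hfma2 in dequantize.h).
    for (int j = 0; j < 4; ++j) {
        int q = (packed >> (4 * j)) & 0xF;
        float x_hat = q * scale + zero;
        std::printf("x=%+.3f  q=%2d  x_hat=%+.3f\n", x[j], q, x_hat);
    }
    std::printf("packed = 0x%04x\n", packed);
    return 0;
}
```

Storing the actual scale rather than its inverse is what lets the decode-side reconstruction collapse into a single fused multiply-add per pair of elements.
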
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include "utils.h" 10 | 11 | //////////////////////////////////////////////////////////////////////////////////////////////////// 12 | 13 | namespace flash { 14 | 15 | using namespace cute; 16 | 17 | //////////////////////////////////////////////////////////////////////////////////////////////////// 18 | 19 | template 22 | __forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, 23 | Tensor &D, 24 | Tensor const &Cos, 25 | Tensor const &Sin, 26 | Tensor const &identity_MN, 27 | const int max_MN, const int min_MN, 28 | const int dim, const int rotary_dim) { 29 | CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); 30 | CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); 31 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA 32 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M 33 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K 34 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M 35 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K 36 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M 37 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K 38 | CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K 39 | static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); 40 | static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 41 | Tensor rCos = make_fragment_like(Cos); 42 | Tensor rSin = make_fragment_like(Sin); 43 | Tensor rS = make_fragment_like(S); 44 | #pragma unroll 45 | for (int m = 0; m < size<1>(S); ++m) { 46 | if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { 47 | #pragma unroll 48 | for (int k = 0; k < size<2>(S); ++k) { 49 | if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { 50 | cute::copy(S(_, m, k), rS(_, m, k)); 51 | if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { 52 | cute::copy(Cos(_, m, k), rCos(_, m, k)); 53 | cute::copy(Sin(_, m, k), rSin(_, m, k)); 54 | Tensor S_fp32 = convert_type(rS(_, m, k)); 55 | Tensor cos_fp32 = convert_type(rCos(_, m, k)); 56 | Tensor sin_fp32 = convert_type(rSin(_, m, k)); 57 | #pragma unroll 58 | for (int i = 0; i < size<0>(rS) / 2; ++i) { 59 | float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); 60 | float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); 61 | S_fp32(2 * i) = real; 62 | S_fp32(2 * i + 1) = imag; 63 | } 64 | // Idk but I need to copy for the convert_type to work 65 | Tensor S_fp32_copy = make_fragment_like(S_fp32); 66 | cute::copy(S_fp32, S_fp32_copy); 67 | using T = typename Engine0::value_type; 68 | Tensor S_og_type = convert_type(S_fp32_copy); 69 | cute::copy(S_og_type, rS(_, m, k)); 70 | } 71 | cute::copy(rS(_, m, k), D(_, m, k)); 72 | } else if (Clear_OOB_K) { 73 | cute::clear(D(_, m, k)); 74 | } 75 | } 76 | } 77 | } 78 | } 79 | 80 | //////////////////////////////////////////////////////////////////////////////////////////////////// 81 | 82 | template 85 | __forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, 86 | Tensor &D, 87 | Tensor const &Cos, 88 | Tensor const &Sin, 89 | Tensor const &identity_MN, 90 | const int max_MN, const int min_MN, 91 | const int dim, const int rotary_dim) { 92 | CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); 93 | CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); 94 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA 95 | 
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M 96 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K 97 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M 98 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K 99 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M 100 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K 101 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA 102 | CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); 103 | static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 104 | Tensor rCos = make_fragment_like(Cos); 105 | Tensor rSin = make_fragment_like(Sin); 106 | Tensor rS = make_fragment_like(S); 107 | Tensor rS_other = make_fragment_like(rS(_, 0, 0)); 108 | #pragma unroll 109 | for (int m = 0; m < size<1>(S); ++m) { 110 | if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { 111 | #pragma unroll 112 | for (int k = 0; k < size<2>(S); ++k) { 113 | if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { 114 | cute::copy(S(_, m, k), rS(_, m, k)); 115 | if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { 116 | const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; 117 | Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); 118 | cute::copy(gS_other, rS_other); 119 | // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } 120 | Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); 121 | Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); 122 | cute::copy(gCos, rCos(_, m, k)); 123 | cute::copy(gSin, rSin(_, m, k)); 124 | // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } 125 | Tensor S_fp32 = convert_type(rS(_, m, k)); 126 | Tensor S_other_fp32 = convert_type(rS_other); 127 | Tensor cos_fp32 = convert_type(rCos(_, m, k)); 128 | Tensor sin_fp32 = convert_type(rSin(_, m, k)); 129 | #pragma unroll 130 | for (int i = 0; i < size<0>(rS); ++i) { 131 | S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); 132 | } 133 | // Idk but I need to copy for the convert_type to work 134 | Tensor S_fp32_copy = make_fragment_like(S_fp32); 135 | cute::copy(S_fp32, S_fp32_copy); 136 | using T = typename Engine0::value_type; 137 | Tensor S_og_type = convert_type(S_fp32_copy); 138 | cute::copy(S_og_type, rS(_, m, k)); 139 | // if (cute::thread0()) { print_tensor(rS(_, m, k)); } 140 | } 141 | cute::copy(rS(_, m, k), D(_, m, k)); 142 | } else if (Clear_OOB_K) { 143 | cute::clear(D(_, m, k)); 144 | } 145 | } 146 | } 147 | } 148 | } 149 | 150 | //////////////////////////////////////////////////////////////////////////////////////////////////// 151 | 152 | } // namespace flash 153 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/softmax.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 
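
In `copy_rotary_interleaved` above, each adjacent pair `(x[2i], x[2i+1])` is rotated by the angle whose cosine and sine arrive in the `Cos`/`Sin` fragments, i.e. the complex multiply `(x0 + i*x1) * (cos + i*sin)`; `copy_rotary_contiguous` applies the same rotation but pairs element `i` with element `i +/- rotary_dim/2` (GPT-NeoX layout). A scalar sketch of the interleaved variant; the `theta = pos / 10000^(2i/rotary_dim)` frequency schedule is the usual RoPE convention and is an assumption here, since the kernel only consumes precomputed cos/sin tables:

```
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int rotary_dim = 8;
    const int pos = 5;                       // token position in the sequence
    std::vector<float> x = {0.1f, -0.4f, 0.3f, 0.8f, -0.2f, 0.5f, 0.9f, -0.7f};

    for (int i = 0; i < rotary_dim / 2; ++i) {
        // Assumed RoPE frequency schedule; the kernel just reads precomputed Cos/Sin.
        float theta = pos * std::pow(10000.0f, -2.0f * i / rotary_dim);
        float c = std::cos(theta), s = std::sin(theta);
        float x0 = x[2 * i], x1 = x[2 * i + 1];
        // Same pairwise rotation as copy_rotary_interleaved:
        x[2 * i]     = x0 * c - x1 * s;      // real part
        x[2 * i + 1] = x0 * s + x1 * c;      // imaginary part
    }
    for (float v : x) std::printf("%+.4f ", v);
    std::printf("\n");
    return 0;
}
```
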
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | 13 | #include "philox.cuh" 14 | #include "utils.h" 15 | 16 | namespace flash { 17 | 18 | using namespace cute; 19 | 20 | //////////////////////////////////////////////////////////////////////////////////////////////////// 21 | 22 | template 23 | __device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { 24 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 25 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 26 | CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); 27 | #pragma unroll 28 | for (int mi = 0; mi < size<0>(tensor); mi++) { 29 | summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); 30 | #pragma unroll 31 | for (int ni = 1; ni < size<1>(tensor); ni++) { 32 | summary(mi) = op(summary(mi), tensor(mi, ni)); 33 | } 34 | } 35 | 36 | } 37 | 38 | // template 39 | // __device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { 40 | // CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 41 | // #pragma unroll 42 | // for (int i = 0; i < size(dst); i++){ 43 | // dst(i) = Allreduce<4>::run(src(i), op); 44 | // } 45 | // } 46 | 47 | template 48 | __device__ __forceinline__ float warp_reduce_acc(float &val, Operator &op) { 49 | // Get the thread's position within its group of 4 50 | const int group_id = threadIdx.x / 4; // Which group of 4 this thread belongs to 51 | const int local_id = threadIdx.x % 4; // Position within group of 4 (0-3) 52 | 53 | // Only reduce within groups of 4 threads 54 | // Using butterfly pattern 55 | #pragma unroll 56 | for (int offset = 2; offset > 0; offset >>= 1) { 57 | float other = __shfl_down_sync(0xffffffff, val, offset); 58 | if (local_id < offset) { 59 | val = op(val, other); 60 | } 61 | } 62 | 63 | // Broadcast the result from thread 0 to all threads in the group 64 | val = __shfl_sync(0xffffffff, val, group_id * 4); 65 | 66 | return val; 67 | } 68 | 69 | template 70 | __device__ __forceinline__ void quad_allreduce_2(Tensor &dst, Tensor &src, Tensor2 &reduce_tmp, Operator &op) { 71 | // __shared__ float smem[4]; // For 4 warps, we need 4 elements 72 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 73 | 74 | const int warp_id = threadIdx.x / 32; 75 | const int lane_id = threadIdx.x % 32; 76 | const int row = (threadIdx.x % 32) / 4; 77 | 78 | // #if DEBUG 79 | // if (threadIdx.x == 103 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 7) { 80 | // PRINTTENSOR("dst", dst); 81 | // } 82 | // #endif 83 | 84 | #pragma unroll 85 | for (int i = 0; i < 1; i++) { 86 | // First do reduction within each group of 4 threads 87 | float val = warp_reduce_acc(src(i), op); 88 | 89 | // Write the result to shared memory for each group's leader 90 | if (lane_id % 4 == 0) { 91 | reduce_tmp(row,warp_id) = val; 92 | } 93 | __syncthreads(); 94 | 95 | // Check if thread is one of the first threads in each group of 4 (0,4,8,12,16,20,24,28) 96 | if ((lane_id % 4) == 0) { 97 | // This thread is responsible for reducing its group's values 98 | float group_val = reduce_tmp(row, 0); 99 | #pragma unroll 100 | for (int w = 1; w < 4; w++) { 101 | group_val = op(group_val, reduce_tmp(row, w)); 102 | } 103 | reduce_tmp(row, 0) = group_val; 104 | } 105 | __syncthreads(); 106 | 107 | // All threads read the final result 108 | dst(i) = reduce_tmp(row,0); 109 | 110 | // #if DEBUG 111 | // if 
(threadIdx.x == 103 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 7) { 112 | // printf("val: %f\n", val); 113 | // PRINTTENSOR("reduce_tmp", reduce_tmp); 114 | // } 115 | // #endif 116 | 117 | } 118 | } 119 | 120 | template 121 | __device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Tensor2 &reduce_tmp, Operator &op) { 122 | // __shared__ float smem[4]; // For 4 warps, we need 4 elements 123 | CUTE_STATIC_ASSERT_V(size(dst) == size(src)); 124 | 125 | const int warp_id = threadIdx.x / 32; 126 | const int lane_id = threadIdx.x % 32; 127 | 128 | #pragma unroll 129 | for (int i = 0; i < size(dst); i++) { 130 | // First do reduction within each group of 4 threads 131 | float val = Allreduce<4>::run(src(i), op); 132 | 133 | // Write the result to shared memory for each group's leader 134 | if (lane_id % 4 == 0) { 135 | reduce_tmp(i,warp_id) = val; 136 | } 137 | __syncthreads(); 138 | 139 | // First thread in the first group reads all values and reduces them 140 | if (lane_id == 0) { 141 | float final_val = reduce_tmp(0,0); 142 | #pragma unroll 143 | for (int w = 1; w < 4; w++) { // For 4 warps 144 | final_val = op(final_val, reduce_tmp(i,w)); 145 | } 146 | // Write back the final result 147 | reduce_tmp(i,0) = final_val; 148 | } 149 | __syncthreads(); 150 | 151 | // All threads read the final result 152 | // cute::copy(reduce_tmp(0,0), dst(i)); 153 | dst(i) = reduce_tmp(i,0); 154 | } 155 | } 156 | 157 | template 158 | __device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Tensor2 &reduce_tmp, Operator &op) { 159 | thread_reduce_(tensor, summary, op); 160 | quad_allreduce_(summary, summary, reduce_tmp, op); 161 | } 162 | 163 | template 164 | __device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max, Tensor2 &reduce_tmp){ 165 | MaxOp max_op; 166 | reduce_(tensor, max, reduce_tmp, max_op); 167 | } 168 | 169 | template 170 | __device__ __forceinline__ void reduce_2(Tensor const& tensor, Tensor &summary, Tensor2 &reduce_tmp, Operator &op) { 171 | thread_reduce_(tensor, summary, op); 172 | quad_allreduce_2(summary, summary, reduce_tmp, op); 173 | } 174 | 175 | template 176 | __device__ __forceinline__ void reduce_max_2(Tensor const& tensor, Tensor &max, Tensor2 &reduce_tmp){ 177 | MaxOp max_op; 178 | reduce_2(tensor, max, reduce_tmp, max_op); 179 | } 180 | 181 | template 182 | __device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ 183 | SumOp sum_op; 184 | thread_reduce_(tensor, sum, sum_op); 185 | } 186 | 187 | // Apply the exp to all the elements. 188 | template 189 | __forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { 190 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 191 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 192 | CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); 193 | #pragma unroll 194 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 195 | // If max is -inf, then all elements must have been -inf (possibly due to masking). 196 | // We don't want (-inf - (-inf)) since that would give NaN. 197 | // If we don't have float around M_LOG2E the multiplication is done in fp64. 198 | const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? 
scale : float(M_LOG2E)); 199 | #pragma unroll 200 | for (int ni = 0; ni < size<1>(tensor); ++ni) { 201 | // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - 202 | // max * log_2(e)) This allows the compiler to use the ffma 203 | // instruction instead of fadd and fmul separately. 204 | // The following macro will disable the use of fma. 205 | // See: https://github.com/pytorch/pytorch/issues/121558 for more details 206 | // This macro is set in PyTorch and not FlashAttention 207 | #ifdef UNFUSE_FMA 208 | tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); 209 | #else 210 | tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); 211 | #endif 212 | } 213 | } 214 | } 215 | 216 | // Apply the exp to all the elements. 217 | template 218 | __forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { 219 | static_assert(Layout0::rank == 2, "Only support 2D Tensor"); 220 | static_assert(Layout1::rank == 1, "Only support 1D Tensor"); 221 | CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); 222 | #pragma unroll 223 | for (int mi = 0; mi < size<0>(tensor); ++mi) { 224 | MaxOp max_op; 225 | max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); 226 | #pragma unroll 227 | for (int ni = 1; ni < size<1>(tensor); ni++) { 228 | max(mi) = max_op(max(mi), tensor(mi, ni)); 229 | } 230 | max(mi) = Allreduce<4>::run(max(mi), max_op); 231 | // If max is -inf, then all elements must have been -inf (possibly due to masking). 232 | // We don't want (-inf - (-inf)) since that would give NaN. 233 | const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; 234 | sum(mi) = 0; 235 | #pragma unroll 236 | for (int ni = 0; ni < size<1>(tensor); ++ni) { 237 | // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - 238 | // max * log_2(e)) This allows the compiler to use the ffma 239 | // instruction instead of fadd and fmul separately. 
240 | tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); 241 | sum(mi) += tensor(mi, ni); 242 | } 243 | SumOp sum_op; 244 | sum(mi) = Allreduce<4>::run(sum(mi), sum_op); 245 | } 246 | } 247 | 248 | //////////////////////////////////////////////////////////////////////////////////////////////////// 249 | 250 | template 251 | struct Softmax { 252 | 253 | using TensorT = decltype(make_tensor(Shape>{})); 254 | TensorT row_max, row_sum; 255 | 256 | __forceinline__ __device__ Softmax() {}; 257 | 258 | template 259 | __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &reduce_tmp, float softmax_scale_log2) { 260 | // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) 261 | Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); 262 | static_assert(decltype(size<0>(scores))::value == kNRows); 263 | if (Is_first) { 264 | flash::template reduce_max_2(scores, row_max, reduce_tmp); 265 | flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); 266 | flash::reduce_sum(scores, row_sum); 267 | } else { 268 | Tensor scores_max_prev = make_fragment_like(row_max); 269 | cute::copy(row_max, scores_max_prev); 270 | flash::template reduce_max_2(scores, row_max, reduce_tmp); 271 | // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) 272 | Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); 273 | static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); 274 | #pragma unroll 275 | for (int mi = 0; mi < size(row_max); ++mi) { 276 | float scores_max_cur = !Check_inf 277 | ? row_max(mi) 278 | : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); 279 | float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); 280 | row_sum(mi) *= scores_scale; 281 | #pragma unroll 282 | for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } 283 | } 284 | 285 | flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); 286 | // We don't do the reduce across threads here since we don't need to use the row_sum. 287 | // We do that reduce at the end when we need to normalize the softmax. 288 | flash::reduce_sum(scores, row_sum); 289 | } 290 | }; 291 | 292 | template 293 | __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1 &reduce_tmp, float softmax_scale, float rp_dropout=1.0) { 294 | SumOp sum_op; 295 | quad_allreduce_2(row_sum, row_sum, reduce_tmp, sum_op); 296 | TensorT lse = make_fragment_like(row_sum); 297 | Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); 298 | static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); 299 | #pragma unroll 300 | for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { 301 | float sum = row_sum(mi); 302 | float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; 303 | lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); 304 | float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; 305 | #pragma unroll 306 | for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } 307 | } 308 | return lse; 309 | }; 310 | }; 311 | 312 | } // namespace flash 313 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 4 | 5 | #pragma once 6 | 7 | /// @param COND - a boolean expression to switch by 8 | /// @param CONST_NAME - a name given for the constexpr bool variable. 9 | /// @param ... - code to execute for true and false 10 | /// 11 | /// Usage: 12 | /// ``` 13 | /// BOOL_SWITCH(flag, BoolConst, [&] { 14 | /// some_function(...); 15 | /// }); 16 | /// ``` 17 | 18 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 19 | [&] { \ 20 | if (COND) { \ 21 | constexpr static bool CONST_NAME = true; \ 22 | return __VA_ARGS__(); \ 23 | } else { \ 24 | constexpr static bool CONST_NAME = false; \ 25 | return __VA_ARGS__(); \ 26 | } \ 27 | }() 28 | 29 | #ifdef FLASHATTENTION_DISABLE_DROPOUT 30 | #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ 31 | [&] { \ 32 | constexpr static bool CONST_NAME = false; \ 33 | return __VA_ARGS__(); \ 34 | }() 35 | #else 36 | #define DROPOUT_SWITCH BOOL_SWITCH 37 | #endif 38 | 39 | #ifdef FLASHATTENTION_DISABLE_ALIBI 40 | #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ 41 | [&] { \ 42 | constexpr static bool CONST_NAME = false; \ 43 | return __VA_ARGS__(); \ 44 | }() 45 | #else 46 | #define ALIBI_SWITCH BOOL_SWITCH 47 | #endif 48 | 49 | #ifdef FLASHATTENTION_DISABLE_UNEVEN_K 50 | #define EVENK_SWITCH(COND, CONST_NAME, ...) \ 51 | [&] { \ 52 | constexpr static bool CONST_NAME = true; \ 53 | return __VA_ARGS__(); \ 54 | }() 55 | #else 56 | #define EVENK_SWITCH BOOL_SWITCH 57 | #endif 58 | 59 | #ifdef FLASHATTENTION_DISABLE_SOFTCAP 60 | #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ 61 | [&] { \ 62 | constexpr static bool CONST_NAME = false; \ 63 | return __VA_ARGS__(); \ 64 | }() 65 | #else 66 | #define SOFTCAP_SWITCH BOOL_SWITCH 67 | #endif 68 | 69 | #ifdef FLASHATTENTION_DISABLE_LOCAL 70 | #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ 71 | [&] { \ 72 | constexpr static bool CONST_NAME = false; \ 73 | return __VA_ARGS__(); \ 74 | }() 75 | #else 76 | #define LOCAL_SWITCH BOOL_SWITCH 77 | #endif 78 | 79 | #define FP16_SWITCH(COND, ...) \ 80 | [&] { \ 81 | if (COND) { \ 82 | using elem_type = cutlass::half_t; \ 83 | return __VA_ARGS__(); \ 84 | } else { \ 85 | using elem_type = cutlass::bfloat16_t; \ 86 | return __VA_ARGS__(); \ 87 | } \ 88 | }() 89 | 90 | // TODO 91 | // #define HEADDIM_SWITCH(HEADDIM, ...) 
\ 92 | // [&] { \ 93 | // if (HEADDIM <= 32) { \ 94 | // constexpr static int kHeadDim = 32; \ 95 | // return __VA_ARGS__(); \ 96 | // } else if (HEADDIM <= 64) { \ 97 | // constexpr static int kHeadDim = 64; \ 98 | // return __VA_ARGS__(); \ 99 | // } else if (HEADDIM <= 96) { \ 100 | // constexpr static int kHeadDim = 96; \ 101 | // return __VA_ARGS__(); \ 102 | // } else if (HEADDIM <= 128) { \ 103 | // constexpr static int kHeadDim = 128; \ 104 | // return __VA_ARGS__(); \ 105 | // } else if (HEADDIM <= 160) { \ 106 | // constexpr static int kHeadDim = 160; \ 107 | // return __VA_ARGS__(); \ 108 | // } else if (HEADDIM <= 192) { \ 109 | // constexpr static int kHeadDim = 192; \ 110 | // return __VA_ARGS__(); \ 111 | // } else if (HEADDIM <= 224) { \ 112 | // constexpr static int kHeadDim = 224; \ 113 | // return __VA_ARGS__(); \ 114 | // } else if (HEADDIM <= 256) { \ 115 | // constexpr static int kHeadDim = 256; \ 116 | // return __VA_ARGS__(); \ 117 | // } \ 118 | // }() 119 | 120 | // #define HEADDIM_SWITCH(HEADDIM, ...) \ 121 | // [&] { \ 122 | // if (HEADDIM <= 32) { \ 123 | // constexpr static int kHeadDim = 128; \ 124 | // return __VA_ARGS__(); \ 125 | // } else if (HEADDIM <= 64) { \ 126 | // constexpr static int kHeadDim = 128; \ 127 | // return __VA_ARGS__(); \ 128 | // } else if (HEADDIM <= 96) { \ 129 | // constexpr static int kHeadDim = 128; \ 130 | // return __VA_ARGS__(); \ 131 | // } else if (HEADDIM <= 128) { \ 132 | // constexpr static int kHeadDim = 128; \ 133 | // return __VA_ARGS__(); \ 134 | // } else if (HEADDIM <= 160) { \ 135 | // constexpr static int kHeadDim = 128; \ 136 | // return __VA_ARGS__(); \ 137 | // } else if (HEADDIM <= 192) { \ 138 | // constexpr static int kHeadDim = 128; \ 139 | // return __VA_ARGS__(); \ 140 | // } else if (HEADDIM <= 224) { \ 141 | // constexpr static int kHeadDim = 128; \ 142 | // return __VA_ARGS__(); \ 143 | // } else if (HEADDIM <= 256) { \ 144 | // constexpr static int kHeadDim = 128; \ 145 | // return __VA_ARGS__(); \ 146 | // } \ 147 | // }() 148 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/include/utils.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 14 | #include 15 | #endif 16 | 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "dequantize.h" 24 | 25 | #define PRINT(name, content) \ 26 | print(name); \ 27 | print(" : "); \ 28 | print(content); \ 29 | print("\n"); 30 | 31 | #define PRINTTENSOR(name, content) \ 32 | print(name); \ 33 | print(" : "); \ 34 | print_tensor(content); \ 35 | print("\n"); 36 | 37 | //////////////////////////////////////////////////////////////////////////////////////////////////// 38 | 39 | namespace flash { 40 | 41 | //////////////////////////////////////////////////////////////////////////////////////////////////// 42 | 43 | template 44 | __forceinline__ __device__ uint32_t relu2(const uint32_t x); 45 | 46 | template<> 47 | __forceinline__ __device__ uint32_t relu2(const uint32_t x) { 48 | uint32_t res; 49 | const uint32_t zero = 0u; 50 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 51 | asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); 52 | #else 53 | asm volatile( \ 54 | "{\n" \ 55 | "\t .reg .f16x2 sela;\n" \ 56 | "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ 57 | "\t and.b32 %0, sela, %1;\n" 58 | "}\n" : "=r"(res) : "r"(x), "r"(zero)); 59 | #endif 60 | return res; 61 | } 62 | 63 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 64 | template<> 65 | __forceinline__ __device__ uint32_t relu2(const uint32_t x) { 66 | uint32_t res; 67 | const uint32_t zero = 0u; 68 | asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); 69 | return res; 70 | } 71 | #endif 72 | 73 | //////////////////////////////////////////////////////////////////////////////////////////////////// 74 | 75 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 76 | 77 | template 78 | __forceinline__ __device__ uint32_t convert_relu2(const float2 x); 79 | 80 | template<> 81 | __forceinline__ __device__ uint32_t convert_relu2(const float2 x) { 82 | uint32_t res; 83 | const uint32_t a = reinterpret_cast(x.x); 84 | const uint32_t b = reinterpret_cast(x.y); 85 | asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); 86 | return res; 87 | } 88 | 89 | template<> 90 | __forceinline__ __device__ uint32_t convert_relu2(const float2 x) { 91 | uint32_t res; 92 | const uint32_t a = reinterpret_cast(x.x); 93 | const uint32_t b = reinterpret_cast(x.y); 94 | asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); 95 | return res; 96 | } 97 | 98 | #endif 99 | 100 | //////////////////////////////////////////////////////////////////////////////////////////////////// 101 | 102 | template 103 | struct MaxOp { 104 | __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } 105 | }; 106 | 107 | template <> 108 | struct MaxOp { 109 | // This is slightly faster 110 | __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } 111 | }; 112 | 113 | template 114 | struct MinOp { 115 | __device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? 
x : y; } 116 | }; 117 | 118 | template <> 119 | struct MinOp { 120 | // This is slightly faster 121 | __device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); } 122 | }; 123 | 124 | //////////////////////////////////////////////////////////////////////////////////////////////////// 125 | 126 | template 127 | struct SumOp { 128 | __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } 129 | }; 130 | 131 | //////////////////////////////////////////////////////////////////////////////////////////////////// 132 | 133 | template 134 | struct Allreduce { 135 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 136 | template 137 | static __device__ __forceinline__ T run(T x, Operator &op) { 138 | constexpr int OFFSET = THREADS / 2; 139 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 140 | return Allreduce::run(x, op); 141 | } 142 | }; 143 | 144 | //////////////////////////////////////////////////////////////////////////////////////////////////// 145 | 146 | template<> 147 | struct Allreduce<2> { 148 | template 149 | static __device__ __forceinline__ T run(T x, Operator &op) { 150 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 151 | return x; 152 | } 153 | }; 154 | 155 | //////////////////////////////////////////////////////////////////////////////////////////////////// 156 | 157 | template 161 | __forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, 162 | Tensor4 const& tCsB, TiledMma tiled_mma, 163 | TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 164 | ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { 165 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 166 | CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N 167 | CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K 168 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 169 | CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 170 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 171 | CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N 172 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } 173 | if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } 174 | #pragma unroll 175 | for (int i = 0; i < size<2>(tCrA); ++i) { 176 | if (i < size<2>(tCrA) - 1) { 177 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } 178 | if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } 179 | } 180 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 181 | } 182 | } 183 | 184 | //////////////////////////////////////////////////////////////////////////////////////////////////// 185 | 186 | template 198 | __forceinline__ __device__ void gemm_Vtensor(Tensor0 &acc, Tensor1 &tCrA, 199 | Tensor2_i4 &tCrB_i4, Tensor2_dequant &tCrB_dequant, 200 | Tensor2_scales &tCrB_scales, Tensor2_zeros &tCrB_zeros, Tensor2_params &sV_params, 201 | Tensor3 const& tCsA, 202 | Tensor4_i4 const& tCsB_i4, 203 | TiledMma tiled_mma, 204 | TiledCopyA smem_tiled_copy_A, 205 | TiledCopyB_i4 smem_tiled_copy_B_i4, 206 | ThrCopyA smem_thr_copy_A, 207 | ThrCopyB_i4 smem_thr_copy_B_i4, 208 | const int num_params) { 209 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 210 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 211 | 
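    // Dequantize-on-the-fly pipeline (summary of the code below): the prologue copies the first
    // K-slice of A and of the packed low-bit B from shared memory into registers, loads the
    // per-group scales/zeros, and dequantizes that slice; the main loop then overlaps the copy
    // and dequantization of slice i+1 with the tensor-core MMA on slice i.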
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 212 | Tensor tCrB_i4_copy_view = smem_thr_copy_B_i4.retile_D(tCrB_i4); 213 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } 214 | if (!B_in_regs) { 215 | cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, _0{}), tCrB_i4_copy_view(_, _, _0{})); 216 | quant::load_params_Vtensor(tCrB_scales, tCrB_zeros, sV_params, threadIdx.x, 0, num_params); 217 | quant::dequant_Kchannel_Vtensor(tCrB_i4(_,_,_0{}), tCrB_dequant(_,_,_0{}), tCrB_scales(_,_0{}), tCrB_zeros(_,_0{}), num_params); 218 | } 219 | 220 | #pragma unroll 221 | for (int i = 0; i < size<2>(tCrA); ++i) { 222 | if (i < size<2>(tCrA) - 1) { 223 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } 224 | if (!B_in_regs) { 225 | cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, i + 1), tCrB_i4_copy_view(_, _, i + 1)); 226 | quant::load_params_Vtensor(tCrB_scales, tCrB_zeros, sV_params, threadIdx.x, i + 1, num_params); 227 | quant::dequant_Kchannel_Vtensor(tCrB_i4(_,_, i + 1), tCrB_dequant(_,_, i + 1), tCrB_scales(_,i + 1), tCrB_zeros(_, i + 1), num_params); 228 | } 229 | } 230 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB_dequant(_, _, i), acc); 231 | } 232 | 233 | } 234 | 235 | //////////////////////////////////////////////////////////////////////////////////////////////////// 236 | 237 | template 249 | __forceinline__ __device__ void gemm_Kchannel(Tensor0 &acc, Tensor1 &tCrA, 250 | Tensor2_i4 &tCrB_i4, Tensor2_dequant &tCrB_dequant, 251 | Tensor2_scales &tCrB_scales, Tensor2_zeros &tCrB_zeros, Tensor2_params &sK_params, 252 | Tensor3 const& tCsA, 253 | Tensor4_i4 const& tCsB_i4, 254 | TiledMma tiled_mma, 255 | TiledCopyA smem_tiled_copy_A, 256 | TiledCopyB_i4 smem_tiled_copy_B_i4, 257 | ThrCopyA smem_thr_copy_A, 258 | ThrCopyB_i4 smem_thr_copy_B_i4, 259 | const int num_params) { 260 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 261 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 262 | CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 263 | Tensor tCrB_i4_copy_view = smem_thr_copy_B_i4.retile_D(tCrB_i4); 264 | 265 | #pragma unroll 266 | for (int i = 0; i < size<2>(tCrA); ++i) { 267 | quant::load_params_Kchannel(tCrB_scales, tCrB_zeros, sK_params, threadIdx.x, i, num_params); 268 | } 269 | 270 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } 271 | if (!B_in_regs) { 272 | cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, _0{}), tCrB_i4_copy_view(_, _, _0{})); 273 | quant::dequant_Kchannel_Vtensor(tCrB_i4(_,_,_0{}), tCrB_dequant(_,_,_0{}), tCrB_scales(_,_,_0{}), tCrB_zeros(_,_,_0{}), num_params); 274 | } 275 | 276 | #pragma unroll 277 | for (int i = 0; i < size<2>(tCrA); ++i) { 278 | if (i < size<2>(tCrA) - 1) { 279 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } 280 | if (!B_in_regs) { 281 | cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, i + 1), tCrB_i4_copy_view(_, _, i + 1)); 282 | quant::dequant_Kchannel_Vtensor(tCrB_i4(_, _, i + 1), tCrB_dequant(_, _, i + 1), tCrB_scales(_, _, i + 1), tCrB_zeros(_, _, i + 1), num_params); 283 | } 284 | } 285 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB_dequant(_, _, i), acc); 286 | } 287 | } 288 | 289 | //////////////////////////////////////////////////////////////////////////////////////////////////// 290 | 291 | template 302 | __forceinline__ __device__ void gemm_Ktensor(Tensor0 &acc, Tensor1 &tCrA, 303 | 
Tensor2_i4 &tCrB_i4, Tensor2_dequant &tCrB_dequant, 304 | Tensor2_scales &tCrB_scales, Tensor2_zeros &tCrB_zeros, 305 | Tensor3 const& tCsA, 306 | Tensor4_i4 const& tCsB_i4, 307 | TiledMma tiled_mma, 308 | TiledCopyA smem_tiled_copy_A, 309 | TiledCopyB_i4 smem_tiled_copy_B_i4, 310 | ThrCopyA smem_thr_copy_A, 311 | ThrCopyB_i4 smem_thr_copy_B_i4, 312 | const int group_size) { 313 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 314 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 315 | CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 316 | Tensor tCrB_i4_copy_view = smem_thr_copy_B_i4.retile_D(tCrB_i4); 317 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } 318 | if (!B_in_regs) { 319 | // cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, _0{}), tCrB_i4_copy_view(_, _, _0{})); 320 | quant::dequantize_Ktensor(tCrB_i4, tCrB_dequant, tCrB_scales, tCrB_zeros, 4, group_size, 0); 321 | } 322 | #pragma unroll 323 | for (int i = 0; i < size<2>(tCrA); ++i) { 324 | if (i < size<2>(tCrA) - 1) { 325 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } 326 | if (!B_in_regs) { 327 | // cute::copy(smem_tiled_copy_B_i4, tCsB_i4(_, _, i + 1), tCrB_i4_copy_view(_, _, i + 1)); 328 | quant::dequantize_Ktensor(tCrB_i4, tCrB_dequant, tCrB_scales, tCrB_zeros, 4, group_size, i + 1); 329 | } 330 | } 331 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB_dequant(_, _, i), acc); 332 | } 333 | } 334 | 335 | //////////////////////////////////////////////////////////////////////////////////////////////////// 336 | 337 | template 341 | __forceinline__ __device__ void gemm_residual(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, 342 | Tensor4 const& tCsB, TiledMma tiled_mma, 343 | TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 344 | ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { 345 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 346 | CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N 347 | CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K 348 | Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); 349 | CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M 350 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 351 | CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N 352 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } 353 | if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } 354 | 355 | #pragma unroll 356 | for (int i = 0; i < size<2>(tCrA); ++i) { 357 | if (i < size<2>(tCrA) - 1) { 358 | if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } 359 | if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } 360 | } 361 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 362 | } 363 | } 364 | 365 | //////////////////////////////////////////////////////////////////////////////////////////////////// 366 | 367 | template 369 | __forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, 370 | TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, 371 | ThrCopy smem_thr_copy_B) { 372 | CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M 373 | CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N 374 | CUTE_STATIC_ASSERT_V(size<2>(tCrA) == 
size<2>(tCrB)); // MMA_K 375 | Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); 376 | CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N 377 | cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); 378 | 379 | #pragma unroll 380 | for (int i = 0; i < size<2>(tCrA); ++i) { 381 | if (i < size<2>(tCrA) - 1) { 382 | cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); 383 | } 384 | cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); 385 | } 386 | } 387 | 388 | //////////////////////////////////////////////////////////////////////////////////////////////////// 389 | 390 | // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) 391 | template 392 | __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { 393 | static_assert(decltype(size<0>(acc_layout))::value == 4); 394 | static_assert(decltype(rank(acc_layout))::value == 3); 395 | auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) 396 | return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); 397 | }; 398 | 399 | //////////////////////////////////////////////////////////////////////////////////////////////////// 400 | 401 | // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) 402 | // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. 403 | template 404 | __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { 405 | using X = Underscore; 406 | static_assert(decltype(size<0>(acc_layout))::value == 4); 407 | static_assert(decltype(rank(acc_layout))::value == 3); 408 | constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); 409 | static_assert(mma_shape_K == 8 || mma_shape_K == 16); 410 | if constexpr (mma_shape_K == 8) { 411 | return acc_layout; 412 | } else { 413 | auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) 414 | return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); 415 | } 416 | }; 417 | 418 | //////////////////////////////////////////////////////////////////////////////////////////////////// 419 | 420 | // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) 421 | template 422 | __forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { 423 | using X = Underscore; 424 | static_assert(decltype(size<0>(acc_layout))::value == 4); 425 | static_assert(decltype(rank(acc_layout))::value == 3); 426 | auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) 427 | return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); 428 | }; 429 | 430 | //////////////////////////////////////////////////////////////////////////////////////////////////// 431 | 432 | template 433 | __forceinline__ __device__ auto convert_type(Tensor const &tensor) { 434 | using From_type = typename Engine::value_type; 435 | constexpr int numel = decltype(size(tensor))::value; 436 | cutlass::NumericArrayConverter convert_op; 437 | // HACK: this requires tensor to be "contiguous" 438 | auto frag = convert_op(*reinterpret_cast *>(tensor.data())); 439 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 440 | } 441 | 442 | //////////////////////////////////////////////////////////////////////////////////////////////////// 443 | 444 | template 445 | __forceinline__ __device__ void relu_(Tensor &tensor) { 446 | constexpr int numel = decltype(size(tensor))::value; 447 | 
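    // relu2 below operates on two half/bf16 values packed into each 32-bit register,
    // so the element count of the tensor must be even (checked by the static_assert).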
static_assert(numel % 2 == 0);
448 |     using value_t = typename Engine::value_type;
449 |     // HACK: this requires tensor to be "contiguous"
450 |     Tensor tensor_uint32 = recast<uint32_t>(tensor);
451 |     #pragma unroll
452 |     for (int i = 0; i < size(tensor_uint32); ++i) {
453 |         tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
454 |     }
455 | }
456 | 
457 | ////////////////////////////////////////////////////////////////////////////////////////////////////
458 | 
459 | // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
460 | template <typename To_type, typename Engine, typename Layout>
461 | __forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
462 |     using From_type = typename Engine::value_type;
463 |     static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
464 |     static_assert(std::is_same_v<float, From_type>);
465 |     constexpr int numel = decltype(size(tensor))::value;
466 |     static_assert(numel % 2 == 0);
467 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
468 |     // HACK: this requires tensor to be "contiguous"
469 |     Tensor tensor_float2 = recast<float2 const>(tensor);
470 |     Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
471 |     #pragma unroll
472 |     for (int i = 0; i < size(out_uint32); ++i) {
473 |         out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
474 |     }
475 |     Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
476 | #else
477 |     Tensor out = flash::convert_type<To_type>(tensor);
478 |     flash::relu_(out);
479 | #endif
480 |     return out;
481 | }
482 | 
483 | ////////////////////////////////////////////////////////////////////////////////////////////////////
484 | 
485 | // Blocks until all but N previous cp.async.commit_group operations have committed.
486 | // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
487 | // (which is equivalent to commit_group then wait_group 0).
488 | // Instead we just call cp.async.wait_group 0, which is slightly faster.
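// Typical usage (illustrative): issue the cp.async copies for the next tile, commit them with
// cute::cp_async_fence(), then call flash::cp_async_wait<0>() followed by __syncthreads()
// before reading that tile from shared memory.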
489 | // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 490 | template 491 | CUTE_HOST_DEVICE 492 | void cp_async_wait() { 493 | #if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) 494 | asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); 495 | #endif 496 | } 497 | 498 | //////////////////////////////////////////////////////////////////////////////////////////////////// 499 | 500 | template 503 | __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, 504 | Tensor &D, Tensor const &identity_MN, 505 | Tensor const &predicate_K, const int max_MN=0) { 506 | CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); 507 | CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); 508 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA 509 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M 510 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K 511 | // There's no case where !Clear_OOB_K && Clear_OOB_MN 512 | static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); 513 | #pragma unroll 514 | for (int m = 0; m < size<1>(S); ++m) { 515 | if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { 516 | #pragma unroll 517 | for (int k = 0; k < size<2>(S); ++k) { 518 | if (Is_even_K || predicate_K(k)) { 519 | cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); 520 | } else if (Clear_OOB_K) { 521 | cute::clear(D(_, m, k)); 522 | } 523 | } 524 | } else if (Clear_OOB_MN) { 525 | cute::clear(D(_, m, _)); 526 | } 527 | } 528 | } 529 | 530 | //////////////////////////////////////////////////////////////////////////////////////////////////// 531 | 532 | template 535 | __forceinline__ __device__ void copy_w_min_idx(Tensor const &S, 536 | Tensor &D, Tensor const &identity_MN, 537 | Tensor const &predicate_K, 538 | const int max_MN=0, const int min_MN=0) { 539 | CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); 540 | CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); 541 | CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA 542 | CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M 543 | CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K 544 | // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } 545 | #pragma unroll 546 | for (int m = 0; m < size<1>(S); ++m) { 547 | // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } 548 | if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { 549 | // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } 550 | #pragma unroll 551 | for (int k = 0; k < size<2>(S); ++k) { 552 | if (Is_even_K || predicate_K(k)) { 553 | cute::copy(S(_, m, k), D(_, m, k)); 554 | } 555 | } 556 | } 557 | } 558 | } 559 | 560 | //////////////////////////////////////////////////////////////////////////////////////////////////// 561 | 562 | } // namespace flash 563 | -------------------------------------------------------------------------------- /csrc/bit_decode/src/test_batch_packdecode.cu: -------------------------------------------------------------------------------- 1 | #include "flash_api.h" 2 | #include 3 | #include 4 | 5 | torch::Tensor single_mha(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, int head_dim) { 6 | const float sm_scale = 1.f / std::sqrt(float(head_dim)); 7 | auto scaled_q = q * sm_scale; 8 | 9 | auto scores = torch::einsum("bthd,bshd->bhts", {scaled_q, k}); 10 | auto attention = 
torch::softmax(scores, -1).to(v.dtype()); 11 | auto output = torch::einsum("bhts,bshd->bthd", {attention, v}); 12 | return output; 13 | } 14 | 15 | template 16 | std::tuple _generate_block_kvcache( 17 | int seqlen_k, 18 | int paged_kv_block_size, 19 | int batch_size, 20 | int nheads_k, 21 | int d, 22 | int num_bits, 23 | const std::string& quant_mode, 24 | torch::Device device, 25 | torch::ScalarType dtype) { 26 | 27 | // Calculate number of blocks needed 28 | int num_blocks = std::ceil(float(seqlen_k) / paged_kv_block_size) * batch_size; 29 | 30 | int num_per_params = 16 / num_bits; 31 | 32 | // Generate random k/v blocks 33 | auto k_cache_paged = torch::randn( 34 | {num_blocks, paged_kv_block_size, nheads_k, d}, 35 | torch::TensorOptions().device(device).dtype(dtype) 36 | ); 37 | 38 | auto v_cache_paged = torch::randn( 39 | {num_blocks, paged_kv_block_size, nheads_k, d}, 40 | torch::TensorOptions().device(device).dtype(dtype) 41 | ); 42 | 43 | // Pack 44 | torch::Tensor k_cache_paged_pack, v_cache_paged_pack; 45 | if (quant_mode == "k-channel") { 46 | k_cache_paged_pack = torch::randn( 47 | {num_blocks, paged_kv_block_size / num_per_params, nheads_k, d}, 48 | torch::TensorOptions().device(device) 49 | ).to(torch::kUInt16); 50 | 51 | v_cache_paged_pack = torch::randn( 52 | {num_blocks, paged_kv_block_size, nheads_k, d / num_per_params}, 53 | torch::TensorOptions().device(device) 54 | ).to(torch::kUInt16); 55 | 56 | } else { 57 | k_cache_paged_pack = torch::randn( 58 | {num_blocks, paged_kv_block_size, nheads_k, d / num_per_params}, 59 | torch::TensorOptions().device(device) 60 | ).to(torch::kUInt16); 61 | 62 | v_cache_paged_pack = torch::randn( 63 | {num_blocks, paged_kv_block_size, nheads_k, d / num_per_params}, 64 | torch::TensorOptions().device(device) 65 | ).to(torch::kUInt16); 66 | } 67 | 68 | // Generate block_table: for each batch, create a permutation of blocks 69 | // First create a randperm of all blocks 70 | auto block_table = torch::randperm(num_blocks, 71 | torch::TensorOptions().device(device).dtype(torch::kInt32) 72 | ); 73 | 74 | // Reshape to (batch_size, num_blocks_per_batch) 75 | int nblocks_per_batch = num_blocks / batch_size; 76 | block_table = block_table.reshape({batch_size, nblocks_per_batch}); 77 | 78 | return std::make_tuple(k_cache_paged, v_cache_paged, k_cache_paged_pack, v_cache_paged_pack, block_table, num_blocks); 79 | } 80 | 81 | template 82 | void TestDecodingKernelCorrectness(int bs, int seqlen_kv, const std::string& quant_mode, const int group_size) { 83 | // Set the random seed for reproducibility 84 | torch::manual_seed(42); 85 | 86 | const int seqlen_q = 1; 87 | const int page_block_size = 256; 88 | 89 | torch::Tensor Q_host = torch::rand({bs, seqlen_q, num_heads, head_dim}, torch::dtype(torch::kHalf)); 90 | torch::Tensor K_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 91 | torch::Tensor V_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 92 | 93 | torch::Tensor Q_device = Q_host.to(torch::kCUDA); 94 | torch::Tensor K_device = K_host.to(torch::kCUDA); 95 | torch::Tensor V_device = V_host.to(torch::kCUDA); 96 | 97 | // Page 98 | auto [k_cache_paged, v_cache_paged, k_cache_paged_pack, v_cache_paged_pack, block_table, num_blocks] = _generate_block_kvcache( 99 | seqlen_kv, 100 | page_block_size, 101 | bs, 102 | num_heads_kv, 103 | head_dim, 104 | num_bits, 105 | quant_mode, 106 | torch::kCUDA, 107 | torch::kHalf 108 | ); 109 | 110 | at::Tensor k_params = quant_mode == 
"k-channel" 111 | ? torch::empty({bs, seqlen_kv / group_size, num_heads_kv, head_dim}, torch::dtype(torch::kFloat32)).to(torch::kCUDA) 112 | : torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 113 | at::Tensor v_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 114 | 115 | auto cu_seqlens_k = torch::arange(0, (bs + 1) * seqlen_kv, seqlen_kv, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); 116 | std::optional opt_block_table = std::make_optional(block_table); 117 | 118 | kvcache_qpack( 119 | k_cache_paged, k_cache_paged_pack, k_params, 120 | v_cache_paged, v_cache_paged_pack, v_params, 121 | opt_block_table, 122 | cu_seqlens_k, 123 | seqlen_kv, 124 | quant_mode, 125 | group_size 126 | ); 127 | 128 | // mha_fwd_kvcache 129 | const float sm_scale = 1 / std::sqrt(float(head_dim)); 130 | auto out = mha_fwd_kvcache(Q_device, 131 | k_cache_paged_pack, k_params, 132 | v_cache_paged_pack, v_params, 133 | opt_block_table, 134 | sm_scale, 135 | quant_mode, 136 | group_size); 137 | 138 | torch::Tensor out_cpu = out.to(torch::kCPU); 139 | 140 | // torch reference 141 | // Page 142 | int nblocks_per_batch = block_table.size(1); 143 | auto flat_block_table = block_table.flatten().to(torch::kInt64); 144 | auto k_cache = k_cache_paged.index_select(0, flat_block_table); 145 | auto v_cache = v_cache_paged.index_select(0, flat_block_table); 146 | 147 | k_cache = k_cache.reshape({bs, nblocks_per_batch * page_block_size, num_heads_kv, head_dim}) 148 | .slice(1, 0, seqlen_kv); 149 | v_cache = v_cache.reshape({bs, nblocks_per_batch * page_block_size, num_heads_kv, head_dim}) 150 | .slice(1, 0, seqlen_kv); 151 | 152 | torch::Tensor out_ref = single_mha(Q_device, k_cache, v_cache, head_dim); 153 | out_ref = out_ref.to(torch::kCPU); 154 | 155 | // Compute the difference 156 | torch::Tensor diff = out_cpu - out_ref; 157 | float mean_absolute_error = diff.abs().mean().item(); 158 | float mean_squared_error = diff.pow(2).mean().item(); 159 | 160 | printf("batch_size: %d num_heads_kv: %d seqlen_kv: %d head_dim: %d Quant_mode: %s\n", bs, num_heads_kv, seqlen_kv, head_dim, quant_mode.c_str()); 161 | if (mean_absolute_error < 1e-2 && mean_squared_error < 1e-2) { 162 | printf("test pass ! \n"); 163 | printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error); 164 | } else { 165 | printf("test fail ! 
\n"); 166 | printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error); 167 | } 168 | 169 | printf("\nFirst 10 elements of out_cpu:\n"); 170 | auto out_cpu_accessor = out_cpu.flatten().data_ptr(); 171 | for (int i = 0; i < 10; i++) { 172 | printf("%.6f ", static_cast(out_cpu_accessor[i])); 173 | } 174 | 175 | printf("\n\nFirst 10 elements of out_ref:\n"); 176 | auto out_ref_accessor = out_ref.flatten().data_ptr(); 177 | for (int i = 0; i < 10; i++) { 178 | printf("%.6f ", static_cast(out_ref_accessor[i])); 179 | } 180 | 181 | printf("\n\n"); 182 | } 183 | 184 | int main() { 185 | const int batch_size = 4; 186 | const int num_heads = 32; 187 | const int num_heads_kv = 32; 188 | const int head_dim = 128; 189 | 190 | const std::string quant_mode = "k-channel"; 191 | const int num_bits = 4; 192 | const int group_size = 128; 193 | 194 | int seqlen_kv = 1024; 195 | 196 | TestDecodingKernelCorrectness(batch_size, seqlen_kv, quant_mode, group_size); 197 | 198 | return 0; 199 | } -------------------------------------------------------------------------------- /csrc/bit_decode/src/test_single_packdecode.cu: -------------------------------------------------------------------------------- 1 | #include "flash_api.h" 2 | #include 3 | #include 4 | 5 | torch::Tensor single_mha(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, int head_dim) { 6 | const float sm_scale = 1.f / std::sqrt(float(head_dim)); 7 | auto scaled_q = q * sm_scale; 8 | 9 | auto scores = torch::einsum("bthd,bshd->bhts", {scaled_q, k}); 10 | auto attention = torch::softmax(scores, -1).to(v.dtype()); 11 | auto output = torch::einsum("bhts,bshd->bthd", {attention, v}); 12 | return output; 13 | } 14 | 15 | 16 | template 17 | void TestDecodingKernelCorrectness(int seqlen_kv, const std::string& quant_mode, const int group_size) { 18 | // Set the random seed for reproducibility 19 | torch::manual_seed(42); 20 | 21 | const int bs = 1; 22 | const int seqlen_q = 1; 23 | const int pack_nums = 16 / num_bits; 24 | 25 | torch::Tensor Q_host = torch::rand({bs, seqlen_q, num_heads, head_dim}, torch::dtype(torch::kHalf)); 26 | torch::Tensor K_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 27 | torch::Tensor V_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); 28 | 29 | torch::Tensor Q_device = Q_host.to(torch::kCUDA); 30 | torch::Tensor K_device = K_host.to(torch::kCUDA); 31 | torch::Tensor V_device = V_host.to(torch::kCUDA); 32 | 33 | at::Tensor k_pack, k_params, v_pack, v_params; 34 | if (quant_mode == "k-channel") { 35 | k_pack = torch::empty({bs, seqlen_kv / pack_nums, num_heads_kv, head_dim}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 36 | k_params = torch::empty({bs, seqlen_kv / group_size, num_heads_kv, head_dim}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 37 | } else { 38 | k_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 39 | k_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 40 | } 41 | 42 | v_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); 43 | v_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); 44 | 45 | // Convert K, V to unpadded format 46 | torch::Tensor K_unpad = K_device.reshape({bs * seqlen_kv, 
num_heads_kv, head_dim}); 47 | torch::Tensor V_unpad = V_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim}); 48 | 49 | auto cu_seqlens_k = torch::arange(0, (bs + 1) * seqlen_kv, seqlen_kv, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); 50 | std::optional opt_block_table = std::nullopt; 51 | 52 | kvcache_qpack( 53 | K_unpad, k_pack, k_params, 54 | V_unpad, v_pack, v_params, 55 | opt_block_table, 56 | cu_seqlens_k, 57 | seqlen_kv, 58 | quant_mode, 59 | group_size 60 | ); 61 | 62 | // mha_fwd_kvcache 63 | const float sm_scale = 1 / std::sqrt(float(head_dim)); 64 | auto out = mha_fwd_kvcache(Q_device, 65 | k_pack, k_params, 66 | v_pack, v_params, 67 | opt_block_table, 68 | sm_scale, 69 | quant_mode, 70 | group_size); 71 | 72 | torch::Tensor out_cpu = out.to(torch::kCPU); 73 | 74 | // CPU reference 75 | torch::Tensor out_ref = single_mha(Q_host, K_host, V_host, head_dim); 76 | 77 | // Compute the difference 78 | torch::Tensor diff = out_cpu - out_ref; 79 | float mean_absolute_error = diff.abs().mean().item(); 80 | float mean_squared_error = diff.pow(2).mean().item(); 81 | 82 | printf("\nnum_bits: %d num_heads_kv: %d seqlen_kv: %d head_dim: %d Quant_mode: %s, Group_size: %d\n", num_bits, num_heads_kv, seqlen_kv, head_dim, quant_mode.c_str(), group_size); 83 | if (mean_absolute_error < 1e-1 && mean_squared_error < 1e-1) { 84 | printf("test pass ! \n"); 85 | printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error); 86 | } else { 87 | printf("test fail ! \n"); 88 | printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error); 89 | } 90 | 91 | printf("\nFirst 10 elements of out_cpu:\n"); 92 | auto out_cpu_accessor = out_cpu.flatten().data_ptr(); 93 | for (int i = 0; i < 10; i++) { 94 | printf("%.6f ", static_cast(out_cpu_accessor[i])); 95 | } 96 | 97 | printf("\n\nFirst 10 elements of out_ref:\n"); 98 | auto out_ref_accessor = out_ref.flatten().data_ptr(); 99 | for (int i = 0; i < 10; i++) { 100 | printf("%.6f ", static_cast(out_ref_accessor[i])); 101 | } 102 | 103 | printf("\n\n"); 104 | } 105 | 106 | 107 | int main() { 108 | const int num_heads = 32; 109 | const int num_heads_kv = 32; 110 | const int head_dim = 128; 111 | 112 | const std::string quant_mode = "k-channel"; 113 | const int num_bits = 4; 114 | const int group_size = 128; 115 | 116 | int seqlen_kv = 1024; 117 | 118 | TestDecodingKernelCorrectness(seqlen_kv, quant_mode, group_size); 119 | 120 | return 0; 121 | } -------------------------------------------------------------------------------- /imgs/4090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DD-DuDa/BitDecoding/30f0f8181e299e1343a97cefb53439249ecd3012/imgs/4090.png -------------------------------------------------------------------------------- /imgs/a100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DD-DuDa/BitDecoding/30f0f8181e299e1343a97cefb53439249ecd3012/imgs/a100.png -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DD-DuDa/BitDecoding/30f0f8181e299e1343a97cefb53439249ecd3012/imgs/overview.png -------------------------------------------------------------------------------- /imgs/scheme.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DD-DuDa/BitDecoding/30f0f8181e299e1343a97cefb53439249ecd3012/imgs/scheme.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | matplotlib 4 | pandas 5 | packaging 6 | ninja 7 | flash-attn -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | import os 4 | import re 5 | import ast 6 | from pathlib import Path 7 | from packaging.version import parse, Version 8 | import platform 9 | 10 | from setuptools import setup, find_packages 11 | import subprocess 12 | 13 | import urllib.request 14 | import urllib.error 15 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 16 | 17 | import torch 18 | from torch.utils.cpp_extension import ( 19 | BuildExtension, 20 | CppExtension, 21 | CUDAExtension, 22 | CUDA_HOME, 23 | ) 24 | 25 | with open("README.md", "r", encoding="utf-8") as fh: 26 | long_description = fh.read() 27 | 28 | # ninja build does not work unless include_dirs are abs path 29 | this_dir = os.path.dirname(os.path.abspath(__file__)) 30 | 31 | PACKAGE_NAME = "bit_decode" 32 | 33 | BASE_WHEEL_URL = ( 34 | "TODO" 35 | ) 36 | 37 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 38 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 39 | FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" 40 | SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 41 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 42 | FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" 43 | 44 | def get_platform(): 45 | """ 46 | Returns the platform name as used in wheel filenames. 47 | """ 48 | if sys.platform.startswith("linux"): 49 | return "linux_x86_64" 50 | elif sys.platform == "darwin": 51 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 52 | return f"macosx_{mac_version}_x86_64" 53 | elif sys.platform == "win32": 54 | return "win_amd64" 55 | else: 56 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 57 | 58 | def get_cuda_bare_metal_version(cuda_dir): 59 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 60 | output = raw_output.split() 61 | release_idx = output.index("release") + 1 62 | bare_metal_version = parse(output[release_idx].split(",")[0]) 63 | 64 | return raw_output, bare_metal_version 65 | 66 | def check_if_cuda_home_none(global_option: str) -> None: 67 | if CUDA_HOME is not None: 68 | return 69 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 70 | # in that case. 71 | warnings.warn( 72 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 73 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 74 | "only images whose names contain 'devel' will provide nvcc." 
75 | ) 76 | 77 | def append_nvcc_threads(nvcc_extra_args): 78 | return nvcc_extra_args + ["--threads", "4"] 79 | 80 | cmdclass = {} 81 | ext_modules = [] 82 | 83 | # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp 84 | # files included in the source distribution, in case the user compiles from source. 85 | # subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) 86 | 87 | if not SKIP_CUDA_BUILD: 88 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 89 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 90 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 91 | 92 | # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h 93 | # See https://github.com/pytorch/pytorch/pull/70650 94 | generator_flag = [] 95 | torch_dir = torch.__path__[0] 96 | if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): 97 | generator_flag = ["-DOLD_GENERATOR_PATH"] 98 | 99 | check_if_cuda_home_none("bit_decode") 100 | 101 | # Check, if CUDA11 is installed for compute capability 8.0 102 | cc_flag = [] 103 | if CUDA_HOME is not None: 104 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 105 | if bare_metal_version < Version("11.6"): 106 | raise RuntimeError( 107 | "FlashAttention is only supported on CUDA 11.6 and above. " 108 | "Note: make sure nvcc has a supported version by running nvcc -V." 109 | ) 110 | 111 | cc_flag.append("-gencode") 112 | cc_flag.append("arch=compute_80,code=sm_80") 113 | 114 | if CUDA_HOME is not None: 115 | if bare_metal_version >= Version("11.8"): 116 | cc_flag.append("-gencode") 117 | cc_flag.append("arch=compute_90,code=sm_90") 118 | 119 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 120 | # torch._C._GLIBCXX_USE_CXX11_ABI 121 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 122 | 123 | if FORCE_CXX11_ABI: 124 | torch._C._GLIBCXX_USE_CXX11_ABI = True 125 | 126 | ext_modules.append( 127 | CUDAExtension( 128 | name="bit_decode_cuda", 129 | sources=[ 130 | "csrc/bit_decode/decode_api.cpp", 131 | "csrc/bit_decode/src/genfile/flash_fwd_hdim128_fp16_sm80.cu", 132 | "csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu", 133 | "csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu", 134 | "csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu", 135 | "csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu", 136 | ], 137 | extra_compile_args={ 138 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 139 | "nvcc": append_nvcc_threads( 140 | [ 141 | "-O3", 142 | "-std=c++17", 143 | "-U__CUDA_NO_HALF_OPERATORS__", 144 | "-U__CUDA_NO_HALF_CONVERSIONS__", 145 | "-U__CUDA_NO_HALF2_OPERATORS__", 146 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 147 | "--expt-relaxed-constexpr", 148 | "--expt-extended-lambda", 149 | "--use_fast_math", 150 | # "--ptxas-options=-v", 151 | # "--ptxas-options=-O2", 152 | # "-lineinfo", 153 | ] 154 | + generator_flag 155 | + cc_flag 156 | ), 157 | }, 158 | extra_link_args=['-Wl,-rpath,{}'.format(os.path.join(torch.__path__[0], 'lib'))], 159 | include_dirs=[ 160 | Path(this_dir) / "csrc" / "bit_decode", 161 | Path(this_dir) / "csrc" / "bit_decode" / "src", 162 | Path(this_dir) / "libs" / "cutlass" / "include", 163 | ], 164 | ) 165 | ) 166 | 167 | 168 | def get_package_version(): 169 | with open(Path(this_dir) / "bit_decode" / "__init__.py", "r") as f: 170 | 
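        # Read __version__ directly from the source file rather than importing bit_decode,
        # so the version can be resolved at build time before the CUDA extension is compiled.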
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 171 | public_version = ast.literal_eval(version_match.group(1)) 172 | local_version = os.environ.get("BIT_DECODE_LOCAL_VERSION") 173 | if local_version: 174 | return f"{public_version}+{local_version}" 175 | else: 176 | return str(public_version) 177 | 178 | def get_wheel_url(): 179 | # Determine the version numbers that will be used to determine the correct wheel 180 | # We're using the CUDA version used to build torch, not the one currently installed 181 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 182 | torch_cuda_version = parse(torch.version.cuda) 183 | torch_version_raw = parse(torch.__version__) 184 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 185 | # to save CI time. Minor versions should be compatible. 186 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 187 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 188 | platform_name = get_platform() 189 | flash_version = get_package_version() 190 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 191 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 192 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 193 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 194 | 195 | # Determine wheel URL based on CUDA version, torch version, python version and OS 196 | wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 197 | wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename) 198 | return wheel_url, wheel_filename 199 | 200 | class CachedWheelsCommand(_bdist_wheel): 201 | """ 202 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 203 | find an existing wheel (which is currently the case for all flash attention installs). We use 204 | the environment parameters to detect whether there is already a pre-built version of a compatible 205 | wheel available and short-circuits the standard full build pipeline. 206 | """ 207 | 208 | def run(self): 209 | super().run() 210 | # if FORCE_BUILD: 211 | # return super().run() 212 | 213 | # wheel_url, wheel_filename = get_wheel_url() 214 | # print("Guessing wheel URL: ", wheel_url) 215 | # try: 216 | # urllib.request.urlretrieve(wheel_url, wheel_filename) 217 | 218 | # # Make the archive 219 | # # Lifted from the root wheel processing command 220 | # # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 221 | # if not os.path.exists(self.dist_dir): 222 | # os.makedirs(self.dist_dir) 223 | 224 | # impl_tag, abi_tag, plat_tag = self.get_tag() 225 | # archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 226 | 227 | # wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 228 | # print("Raw wheel path", wheel_path) 229 | # os.rename(wheel_filename, wheel_path) 230 | # except urllib.error.HTTPError: 231 | # print("Precompiled wheel not found. 
Building from source...")
232 |         #     # If the wheel could not be downloaded, build from source
233 |         #     super().run()
234 | 
235 | 
236 | setup(
237 |     name=PACKAGE_NAME,
238 |     version=get_package_version(),
239 |     packages=find_packages(
240 |         exclude=(
241 |             "build",
242 |             "csrc",
243 |             "include",
244 |             "tests",
245 |             "dist",
246 |             "docs",
247 |             "benchmarks",
248 |             "bit_decode.egg-info",
249 |         )
250 |     ),
251 |     author="Dayou Du",
252 |     author_email="duda200054@gmail.com",
253 |     description="BitDecoding",
254 |     long_description=long_description,
255 |     long_description_content_type="text/markdown",
256 |     url="https://github.com/DD-DuDa/BitDecoding",
257 |     classifiers=[
258 |         "Programming Language :: Python :: 3",
259 |         "License :: OSI Approved :: MIT License",
260 |         "Operating System :: Unix",
261 |     ],
262 |     ext_modules=ext_modules,
263 |     cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
264 |     if ext_modules
265 |     else {
266 |         "bdist_wheel": CachedWheelsCommand,
267 |     },
268 |     python_requires=">=3.7",
269 |     setup_requires=["ninja"],
270 |     install_requires=[
271 |         "torch",
272 |         "einops",
273 |         "packaging"
274 |     ],
275 | )
--------------------------------------------------------------------------------