├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── deep_gemm ├── __init__.py ├── include │ └── deep_gemm │ │ ├── fp8_gemm.cuh │ │ ├── mma_utils.cuh │ │ ├── nvrtc_std.cuh │ │ ├── scheduler.cuh │ │ ├── tma_utils.cuh │ │ └── utils.cuh ├── jit │ ├── __init__.py │ ├── compiler.py │ ├── interleave_ffma.py │ └── runtime.py ├── jit_kernels │ ├── __init__.py │ ├── gemm.py │ ├── m_grouped_gemm.py │ ├── runtime.py │ ├── tuner.py │ └── utils.py └── utils.py ├── figures └── design.png ├── indexing └── main.cu ├── setup.py └── tests ├── test_core.py └── test_jit.py /.gitignore: -------------------------------------------------------------------------------- 1 | cmake-build-* 2 | .idea 3 | .DS_Store 4 | build 5 | dist 6 | *.egg-info 7 | *.pyc 8 | 9 | # Third-party links created by `setup.py develop` 10 | deep_gemm/include/cute 11 | deep_gemm/include/cutlass 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/cutlass"] 2 | path = third-party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT 2 | # TODO: add CUDA utils' library via CMake 3 | cmake_minimum_required(VERSION 3.10) 4 | project(deep_gemm LANGUAGES CXX CUDA) 5 | 6 | set(CMAKE_CXX_STANDARD 20) 7 | set(CMAKE_CUDA_STANDARD 20) 8 | set(CMAKE_VERBOSE_MAKEFILE ON) 9 | 10 | find_package(CUDAToolkit REQUIRED) 11 | find_package(pybind11 REQUIRED) 12 | 13 | file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }") 14 | execute_process( 15 | COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu 16 | RESULT_VARIABLE NVCC_RESULT 17 | OUTPUT_VARIABLE NVCC_OUTPUT 18 | ERROR_VARIABLE NVCC_ERROR_OUTPUT 19 | WORKING_DIRECTORY ${CMAKE_BINARY_DIR} 20 | ) 21 | 22 | if (NVCC_RESULT EQUAL "0") 23 | set(NVCC_SUPPORTS_SM90 TRUE) 24 | message(STATUS "NVCC supports SM90") 25 | else() 26 | message(STATUS "NVCC does not support SM90") 27 | endif() 28 | 29 | if (NVCC_SUPPORTS_SM90) 30 | set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE) 31 | list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a") 32 | endif() 33 | find_package(Torch REQUIRED) 34 | 35 | include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include) 36 | include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) 37 | link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib) 38 | 39 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") 40 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") 41 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG") 42 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10") 43 | 44 | cuda_add_library(example_gemm STATIC indexing/main.cu) 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DeepSeek 4 | 5 | 
Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGEMM 2 | 3 | DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. 4 | 5 | Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. 6 | 7 | Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. 8 | 9 | ## News 10 | 11 | - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). 12 | - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. 
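To try the NVRTC path mentioned in the news item above, the `DG_JIT_USE_NVRTC` switch (documented in the environment-variable list further below) can be set before the first kernel is built. A minimal sketch, assuming the variable is read when a kernel is first JIT-compiled, so it should be set before the first GEMM call:

```python
import os

# Opt into NVRTC instead of NVCC for JIT compilation (see the environment
# variables section below); set it before any kernel gets compiled
os.environ['DG_JIT_USE_NVRTC'] = '1'

import deep_gemm  # subsequent GEMM calls will JIT-compile via NVRTC
```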
13 | 14 | ## Roadmap 15 | 16 | - [x] More correctness tests for grouped-contiguous layout 17 | - [x] Shared memory swizzling for output 18 | - [ ] Larger block size on N (up to 256) 19 | - [x] MoE scheduler with TMA multicast compatibility 20 | - [x] Fix TMA multicast compatibility for indivisible shapes 21 | - [ ] Skip useless computation on M 22 | - [x] NVRTC as a faster compiler 23 | - [ ] Stolen JIT cache 24 | - [ ] Sanitizer for testing 25 | - [ ] Weight gradient kernels for dense models 26 | - [ ] Weight gradient kernels for MoE models 27 | - [ ] Utility kernels for MoE models (as a pre-built CUDA library) 28 | - [ ] CUDA PDL support 29 | - [ ] More scaling granularity support via templates 30 | - [ ] Larger TMA multicast size for some shapes 31 | - [x] MMA template refactor with CUTLASS 32 | - [ ] Optimizations for unaligned shapes 33 | - [ ] Optimizations for power efficiency 34 | - [ ] Remove shape limitations on N and K 35 | - [ ] BF16 kernels 36 | - [ ] Split/stream-k optimizations 37 | 38 | ## Quick start 39 | 40 | ### Requirements 41 | 42 | - Hopper architecture GPUs, `sm_90a` must be supported 43 | - Python 3.8 or above 44 | - CUDA 12.3 or above 45 | - **But we highly recommend 12.8 or above for the best performance** 46 | - PyTorch 2.1 or above 47 | - CUTLASS 3.6 or above (could be cloned by Git submodule) 48 | 49 | ### Development 50 | 51 | ```bash 52 | # Submodule must be cloned 53 | git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git 54 | 55 | # Make symbolic links for third-party (CUTLASS and CuTe) include directories 56 | python setup.py develop 57 | 58 | # Test JIT compilation 59 | python tests/test_jit.py 60 | 61 | # Test all GEMM implements (normal, contiguous-grouped and masked-grouped) 62 | python tests/test_core.py 63 | ``` 64 | 65 | ### Installation 66 | 67 | ```bash 68 | python setup.py install 69 | ``` 70 | 71 | Then, import `deep_gemm` in your Python project, and enjoy! 72 | 73 | ## Interfaces 74 | 75 | #### Notices 76 | 77 | This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. 78 | 79 | #### Normal dense GEMMs (non-grouped) 80 | 81 | To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation. 82 | 83 | #### Grouped GEMMs (contiguous layout) 84 | 85 | Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. 86 | 87 | For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`). 88 | 89 | For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation. 
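As an illustration of the alignment requirement above, here is a minimal host-side sketch of how a contiguous layout could be assembled. The per-expert token counts, the `m_indices` tensor and the `-1` padding marker are assumptions made for this example only; the exact argument layout of `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` should be taken from its Python documentation.

```python
import torch
import deep_gemm

# Hypothetical per-expert token counts; all experts share the same N and K
tokens_per_expert = [137, 256, 64, 300]

# Every expert segment in the concatenated LHS must be padded up to this value
align = deep_gemm.get_m_alignment_for_contiguous_layout()
padded = [deep_gemm.ceil_div(t, align) * align for t in tokens_per_expert]
m_total = sum(padded)

# Group index for each row of the concatenated tensor; padding rows are marked
# with -1 here (an assumption for this sketch) so they can be ignored later
m_indices = torch.full((m_total,), -1, dtype=torch.int32, device='cuda')
offset = 0
for expert_id, (num_tokens, padded_size) in enumerate(zip(tokens_per_expert, padded)):
    m_indices[offset:offset + num_tokens] = expert_id
    offset += padded_size

# The FP8 LHS/RHS tensors and their fine-grained scaling factors come from prior
# (user-implemented) casting kernels; the LHS scales must additionally be
# TMA-aligned and transposed, e.g. via `deep_gemm.get_col_major_tma_aligned_tensor`.
# With those tensors prepared, call
# `deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(...)` as documented.
```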
90 | 91 | #### Grouped GEMMs (masked layout) 92 | 93 | During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. 94 | 95 | Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. 96 | 97 | #### Utilities 98 | 99 | The library provides some utility functions besides the above kernels: 100 | 101 | - `deep_gemm.set_num_sms`: set the maximum SM count to use 102 | - `deep_gemm.get_num_sms`: get the current SM maximum count 103 | - `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout 104 | - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size 105 | - `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor 106 | 107 | The library also provides some environment variables, which may be useful: 108 | 109 | - General 110 | - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default 111 | - JIT cache related 112 | - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default 113 | - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default 114 | - NVCC/NVRTC selections 115 | - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default 116 | - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default 117 | - Compiler options 118 | - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default 119 | - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default 120 | - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default 121 | - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default 122 | - Post optimization 123 | - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default 124 | - Heuristic selection 125 | - `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default 126 | - Testing 127 | - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default 128 | 129 | For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. 130 | 131 | ## Optimizations 132 | 133 | We indicate the techniques excluded from CUTLASS with 🐳. 134 | 135 | #### Persistent warp-specialization 136 | 137 | Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below: 138 | 139 | ![design](figures/design.png) 140 | 141 | #### Hopper TMA features 142 | 143 | The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. 
Specifically, we utilize TMA for: 144 | 145 | - TMA load for LHS, LHS scaling factors, and RHS matrices 146 | - TMA store for the output matrix 147 | - TMA multicast (automatically decide LHS or RHS to broadcast) 148 | - TMA descriptor prefetching 149 | 150 | #### Common detail optimizations 151 | 152 | - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction 153 | - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups 154 | - Less bank conflicts via 3D TMA or swizzling 155 | - Larger block sizes (up to 256x128 🐳) 156 | - Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 157 | 158 | #### A unified and optimized block scheduler 159 | 160 | - [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels 161 | - [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse 162 | 163 | #### Fully JIT design 🐳 164 | 165 | DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages: 166 | 167 | - GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants 168 | - Saving registers 169 | - Compilers may do more optimizations 170 | - Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size 171 | - But without auto-tuning, the optimal one is deterministically selected 172 | - Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities 173 | - Very important for small shapes 174 | - Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details 175 | 176 | Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler. 177 | 178 | #### Unaligned block sizes 🐳 179 | 180 | For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains. 181 | 182 | #### FFMA SASS interleaving 🐳 183 | 184 | We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern. 
185 | After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/96a9f72baf00f40b9b299653fcef8d3e2b4a3d49/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work). 186 | 187 | To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions. 188 | 189 | ## Acknowledgement 190 | 191 | DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! 192 | 193 | ## License 194 | 195 | This code repository is released under [the MIT License](LICENSE). 196 | 197 | ## Citation 198 | 199 | ```bibtex 200 | @misc{deepgemm2025, 201 | title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling}, 202 | author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu}, 203 | year={2025}, 204 | publisher = {GitHub}, 205 | howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}}, 206 | } 207 | ``` 208 | -------------------------------------------------------------------------------- /deep_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import jit 4 | from .jit_kernels import ( 5 | gemm_fp8_fp8_bf16_nt, 6 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 7 | m_grouped_gemm_fp8_fp8_bf16_nt_masked, 8 | ceil_div, 9 | set_num_sms, get_num_sms, 10 | get_col_major_tma_aligned_tensor, 11 | get_m_alignment_for_contiguous_layout 12 | ) 13 | from .utils import bench, bench_kineto, calc_diff 14 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/fp8_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #pragma clang diagnostic push 3 | #pragma clang diagnostic ignored "-Wunknown-attributes" 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "mma_utils.cuh" 13 | #include "scheduler.cuh" 14 | #include "tma_utils.cuh" 15 | #include "utils.cuh" 16 | 17 | namespace deep_gemm { 18 | 19 | enum class Layout { 20 | RowMajor, 21 | ColMajor 22 | }; 23 | 24 | __device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { 25 | return block_m == 64 ? 
1 : 2; 26 | } 27 | 28 | template 29 | __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { 30 | DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); 31 | return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; 32 | } 33 | 34 | template 35 | __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { 36 | if (num_former_iters == kNumFormerIters) { 37 | inner_launch_k_iterations(func, cute::Int{}); 38 | return; 39 | } 40 | 41 | if constexpr (kNumFormerIters + kGap <= kEnd) 42 | outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); 43 | } 44 | 45 | template 53 | __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) 54 | fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, 55 | uint32_t shape_m, 56 | const __grid_constant__ CUtensorMap tensor_map_a, 57 | const __grid_constant__ CUtensorMap tensor_map_b, 58 | const __grid_constant__ CUtensorMap tensor_map_scales_a, 59 | const __grid_constant__ CUtensorMap tensor_map_d) { 60 | #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) 61 | // Scaling checks 62 | DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); 63 | DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); 64 | 65 | // Types 66 | using WGMMA = typename FP8MMASelector::type; 67 | using Barrier = cutlass::arch::ClusterTransactionBarrier; 68 | DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); 69 | 70 | // Shared memory 71 | static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); 72 | static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16); 73 | static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); 74 | static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); 75 | static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); 76 | static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); 77 | static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); 78 | 79 | // Configs 80 | constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; 81 | constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); 82 | constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; 83 | constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); 84 | const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); 85 | const uint32_t lane_idx = get_lane_id(); 86 | 87 | // Prefetch TMA descriptors at the very beginning 88 | if (threadIdx.x == kNumMathThreads) { 89 | // NOTES: `reinterpret_cast` must be here, or NVRTC will fail 90 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); 91 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); 92 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); 93 | 94 | // `tensor_map_d` is only used in swizzling mode 95 | // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode 96 | if constexpr (kSwizzleDMode > 0) 97 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); 98 | } 99 | __syncwarp(); 100 | 101 | // Align to 1024 bytes for swizzle-128B 102 | extern __shared__ __align__(1024) uint8_t smem_buffer[]; 103 | DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); 104 | 105 | // Data on shared memory 106 | auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); 107 | __nv_fp8_e4m3* smem_a[kNumStages]; 108 | __nv_fp8_e4m3* smem_b[kNumStages]; 109 | float* smem_scales_a[kNumStages]; 110 | float* smem_scales_b; 111 | 112 | // TMA Barrier for both divisible and non-divisible cases 113 | Barrier* full_barriers[kNumStages]; 114 | Barrier* empty_barriers[kNumStages]; 115 | 116 | // Fill shared memory pointers 117 | #pragma unroll 118 | for (int i = 0; i < kNumStages; ++ i) { 119 | smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); 120 | smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); 121 | smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); 122 | } 123 | smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); 124 | 125 | // Fill barriers 126 | auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); 127 | #pragma unroll 128 | for (int i = 0; i < kNumStages; ++ i) { 129 | full_barriers[i] = barrier_start_ptr + i; 130 | empty_barriers[i] = barrier_start_ptr + kNumStages + i; 131 | } 132 | 133 | // Initialize barriers 134 | DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); 135 | if (threadIdx.x == kNumMathThreads) { 136 | // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, 137 | // even with TMA multicast disabled, we want to make the behavior aligned 138 | #pragma unroll 139 | for (int i = 0; i < kNumStages; ++ i) { 140 | full_barriers[i]->init(1); 141 | empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); 142 | } 143 | 144 | // Make initialized barrier visible in async proxy 145 | cutlass::arch::fence_view_async_shared(); 146 | (kNumTMAMulticast > 1) ? 
cutlass::arch::fence_barrier_init() : void(); 147 | } 148 | 149 | // Synchronize all threads to make barrier visible in normal memory model 150 | (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); 151 | 152 | // For pipeline unrolling 153 | struct DivisibleK {}; 154 | struct NotDivisibleK {}; 155 | auto launch_k_iterations = [](const auto& func, int num_former_iters) { 156 | constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; 157 | constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; 158 | constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; 159 | 160 | // NOTES: for too-many branches (> 5), we disable this optimization 161 | // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value 162 | outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) { 163 | if constexpr (SHAPE_K % kFullKOfAllStages == 0) { 164 | for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) 165 | func(k_iter, DivisibleK{}, num_former_iters_type); 166 | } else { 167 | for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) 168 | func(k_iter, DivisibleK{}, num_former_iters_type); 169 | func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type); 170 | } 171 | }, func, kShouldOptimize ? num_former_iters : 0); 172 | }; 173 | 174 | // Register reconfigurations 175 | constexpr int kNumTMARegisters = 40; 176 | constexpr int kNumMathRegisters = 232; 177 | 178 | // Block scheduler 179 | uint32_t m_block_idx, n_block_idx; 180 | auto scheduler = Scheduler(shape_m, grouped_layout); 181 | 182 | if (threadIdx.x >= kNumMathThreads) { 183 | // TMA warp-group for loading data 184 | cutlass::arch::warpgroup_reg_dealloc(); 185 | 186 | // NOTES: only one thread (or warp) will be used 187 | if (threadIdx.x == kNumMathThreads) { 188 | // Persistently schedule over blocks 189 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 190 | launch_k_iterations([&](int k_iter, auto type, auto _) { 191 | constexpr bool kHasDivisibleStages = std::is_same_v; 192 | constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; 193 | DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); 194 | 195 | // Assign TMA multicast number into A and B 196 | // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. 197 | const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); 198 | const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 199 | const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 200 | DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); 201 | 202 | // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all 203 | // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant 204 | #pragma unroll 205 | for (uint32_t s = 0; s < kNumInnerStages; ++ s) { 206 | // Wait consumer release 207 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); 208 | 209 | // Issue TMA A 210 | auto& full_barrier = *full_barriers[s]; 211 | int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; 212 | tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), 213 | smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), 214 | num_tma_multicast_a); 215 | tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), 216 | smem_scales_a[s], m_block_idx * BLOCK_M, 217 | scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K), 218 | num_tma_multicast_a); 219 | 220 | // Issue TMA B 221 | tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), 222 | smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), 223 | num_tma_multicast_b); 224 | full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); 225 | } 226 | 227 | // Wait unaligned cases 228 | #pragma unroll 229 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 230 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); 231 | full_barriers[s]->arrive(); 232 | } 233 | }, 0); 234 | } 235 | 236 | // To safely deconstruct distributed shared barriers, we need another round of empty waits 237 | if constexpr (kNumTMAMulticast > 1) { 238 | #pragma unroll 239 | for (uint32_t s = 0; s < kNumStages; ++ s) 240 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); 241 | } 242 | } 243 | } else { 244 | // Math warp-groups for WGMMA 245 | cutlass::arch::warpgroup_reg_alloc(); 246 | 247 | // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers 248 | const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); 249 | const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; 250 | 251 | // Persistently schedule over blocks 252 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 253 | // Decide the number of scales B to load 254 | DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); 255 | uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; 256 | if constexpr (not kMustUseUniformedScaleB) { 257 | num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; 258 | num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; 259 | } 260 | uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 
1 : 2); 261 | 262 | // Load B scales with math warp-groups 263 | // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks 264 | if (threadIdx.x >= 32) { 265 | auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); 266 | auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; 267 | #pragma unroll 268 | for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) 269 | st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); 270 | } 271 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 272 | 273 | // Accumulation for WGMMA or CUDA promotion 274 | constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); 275 | DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); 276 | float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; 277 | 278 | // Empty barrier arrival 279 | auto empty_barrier_arrive = [&](int s) { 280 | if constexpr (kNumTMAMulticast == 1) { 281 | lane_idx == 0 ? empty_barriers[s]->arrive() : void(); 282 | } else { 283 | auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); 284 | lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); 285 | } 286 | }; 287 | 288 | // Launch MMAs 289 | launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { 290 | constexpr bool kHasDivisibleStages = std::is_same_v; 291 | constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; 292 | DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); 293 | 294 | #pragma unroll 295 | for (int s = 0; s < kNumInnerStages; ++ s) { 296 | // Read B scales 297 | float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; 298 | // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks 299 | if constexpr (not kMustUseUniformedScaleB) 300 | scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); 301 | 302 | // Wait TMA arrivals 303 | full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); 304 | 305 | // TODO: remove some useless computation for unaligned Ms 306 | #pragma unroll 307 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 308 | auto m_offset = local_idx * WAVE_BLOCK_M; 309 | 310 | // Read A scales 311 | // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results 312 | auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); 313 | auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); 314 | 315 | // Commit WGMMA instructions 316 | #pragma unroll 317 | for (int i = 0; i < WGMMA::kNumAccum; ++ i) 318 | warpgroup_fence_operand(accum[i]); 319 | warpgroup_arrive(); 320 | #pragma unroll 321 | for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { 322 | auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); 323 | auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); 324 | WGMMA::wgmma(desc_a, desc_b, accum, k); 325 | } 326 | warpgroup_commit_batch(); 327 | #pragma unroll 328 | for (int i = 0; i < WGMMA::kNumAccum; ++ i) 329 | warpgroup_fence_operand(accum[i]); 330 | warpgroup_wait<0>(); 331 | 332 | // Notify barrier arrival at the last warpgroup wave 333 | if (local_idx 
== BLOCK_M / WAVE_BLOCK_M - 1) 334 | empty_barrier_arrive(s); 335 | 336 | // Promote with scales 337 | // NOTES: making it as predicates is very important for performance, comparing to two loops 338 | float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; 339 | float scale_0_1, scale_1_1; 340 | if constexpr (not kMustUseUniformedScaleB) 341 | scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; 342 | 343 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 344 | #pragma unroll 345 | for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 346 | // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant 347 | bool predicate = kMustUseUniformedScaleB or i < num_former_iters; 348 | shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; 349 | shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; 350 | shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; 351 | shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; 352 | } 353 | } 354 | } 355 | 356 | // Wait unaligned cases 357 | #pragma unroll 358 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 359 | full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); 360 | empty_barrier_arrive(s); 361 | } 362 | }, num_former_iters); 363 | 364 | // TMA checks 365 | constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); 366 | constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); 367 | constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; 368 | DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); 369 | DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, 370 | "Unaligned TMA store or too many TMA store instructions"); 371 | DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); 372 | DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, 373 | "Swizzling and padding are not compatible"); 374 | 375 | // Wait last TMA store to be finished 376 | if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) 377 | cute::tma_store_wait<0>(); 378 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 379 | 380 | // Write back to shared memory using STSM and issue TMA stores 381 | DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); 382 | #pragma unroll 383 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 384 | auto m_offset = local_idx * WAVE_BLOCK_M; 385 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 386 | #pragma unroll 387 | for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 388 | // Swizzle or padding into the correct address 389 | uint8_t* smem_ptr = nullptr; 390 | if constexpr (kSwizzleDMode > 0) { 391 | // Calculate the swizzling atom offset and in-atom offset 392 | constexpr int kNumBankGroupBytes = 16; 393 | auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); 394 | 395 | // Calculate the index of the bank group to be written in the atom 396 | auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); 397 | 398 | // Reshape the atom in another view and swizzle 399 | // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` 400 | // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` 401 | constexpr bool kHasShortcut = (kSwizzleDMode / 
kNumBankGroupBytes) == 8; 402 | auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); 403 | auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); 404 | col ^= row % (kSwizzleDMode / 16); 405 | 406 | // Add back into the base pointer 407 | // NOTES: think twice before modifying this, as changes may affect the number of instructions 408 | smem_ptr = reinterpret_cast(smem_d) + // Base pointer 409 | warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset 410 | m_offset * kSwizzleDMode + // Wave offset 411 | atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) 412 | row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset 413 | } else { 414 | // No swizzling, just padding 415 | // NOTES: padding must be zero for BF16 output 416 | DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); 417 | smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); 418 | } 419 | 420 | // NOTES: only 16 lanes' addresses are used 421 | SM90_U32x2_STSM_N::copy( 422 | __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), 423 | __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), 424 | smem_ptr 425 | ); 426 | } 427 | } 428 | cute::tma_store_fence(); 429 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 430 | 431 | // Use TMA store to write back to global memory 432 | // TODO: compatible with FP32 output 433 | DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); 434 | if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { 435 | auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; 436 | auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; 437 | cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, 438 | n_block_idx * BLOCK_N + in_block_n_offset, 439 | scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); 440 | cute::tma_store_arrive(); 441 | } 442 | __syncwarp(); 443 | } 444 | } 445 | #else 446 | if (blockIdx.x == 0 and threadIdx.x == 0) 447 | DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); 448 | #endif 449 | } 450 | 451 | }; // namespace deep_gemm 452 | 453 | #pragma clang diagnostic pop 454 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/mma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef __CUDACC_RTC__ 4 | #include 5 | #endif 6 | 7 | #include 8 | #include 9 | 10 | #include "utils.cuh" 11 | 12 | namespace deep_gemm { 13 | 14 | template 15 | struct SM90_U32x2_STSM_N { 16 | __device__ __forceinline__ static void 17 | copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { 18 | const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; 19 | asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" 20 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); 21 | } 22 | }; 23 | 24 | template 25 | struct SM90_U32x4_STSM_N { 26 | __device__ __forceinline__ static void 27 | copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { 28 | const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), 29 | *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; 30 | asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" 31 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); 32 | } 33 | }; 34 | 35 | 
__forceinline__ __device__ void warpgroup_arrive() { 36 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 37 | } 38 | 39 | __forceinline__ __device__ void warpgroup_commit_batch() { 40 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 41 | } 42 | 43 | __forceinline__ __device__ void warpgroup_fence_operand(float& reg) { 44 | asm volatile("" : "+f"(reg) :: "memory"); 45 | } 46 | 47 | __forceinline__ __device__ uint32_t get_lane_id() { 48 | uint32_t lane_id; 49 | asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); 50 | return lane_id; 51 | } 52 | 53 | __device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { 54 | uint32_t ret; 55 | asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 56 | return ret; 57 | } 58 | 59 | __device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { 60 | int4 ret; 61 | asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); 62 | return ret; 63 | } 64 | 65 | __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { 66 | float ret; 67 | asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); 68 | return ret; 69 | } 70 | 71 | __device__ __forceinline__ void st_shared(const float* ptr, float val) { 72 | asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); 73 | } 74 | 75 | __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { 76 | asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); 77 | } 78 | 79 | template 80 | __device__ void warpgroup_wait() { 81 | DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); 82 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); 83 | } 84 | 85 | union GmmaDescriptor { 86 | __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} 87 | 88 | __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} 89 | 90 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} 91 | 92 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} 93 | 94 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { 95 | desc_ = t.desc_; 96 | return *this; 97 | } 98 | 99 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { 100 | desc_ = t.desc_; 101 | return *this; 102 | } 103 | 104 | uint64_t desc_; 105 | uint32_t reg32_[2]; 106 | uint16_t reg16_[4]; 107 | 108 | struct { 109 | uint16_t start_address_: 14, : 2; 110 | uint16_t leading_byte_offset_: 14, : 2; 111 | uint16_t stride_byte_offset_: 14, : 2; 112 | uint8_t : 1, base_offset_: 3, : 4; 113 | uint8_t : 6, layout_type_: 2; 114 | } bitfield; 115 | 116 | // Decay to an `uint64_t` 117 | __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } 118 | }; 119 | 120 | template 121 | __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, 122 | int leading_byte_offset = 0, 123 | int stride_byte_offset = 1024) { 124 | GmmaDescriptor desc; 125 | auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 126 | desc.bitfield.start_address_ = uint_ptr >> 4; 127 | desc.bitfield.layout_type_ = layout_type; 128 | desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; 129 | desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; 130 | desc.bitfield.base_offset_ = 0; 131 | return desc; 132 | } 133 | 134 | template 135 | struct 
FP8MMA { 136 | 137 | template 138 | __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { 139 | using namespace cute::SM90::GMMA; 140 | MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); 141 | } 142 | 143 | __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 144 | call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); 145 | } 146 | 147 | static constexpr int M = 64; 148 | static constexpr int N = N_; 149 | static constexpr int K = 32; 150 | static constexpr int kNumAccum = M * N / 128; 151 | }; 152 | 153 | template 154 | struct FP8MMASelector { 155 | 156 | static constexpr auto select_mma() { 157 | using namespace cute::SM90::GMMA; 158 | if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); 159 | if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); 160 | if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); 161 | if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); 162 | if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); 163 | if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); 164 | if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); 165 | if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); 166 | if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); 167 | if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); 168 | if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); 169 | if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); 170 | if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); 171 | if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); 172 | if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); 173 | if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); 174 | if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); 175 | if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); 176 | if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); 177 | } 178 | 179 | static constexpr auto select_type() { 180 | return FP8MMA(); 181 | } 182 | 183 | using type = decltype(select_type()); 184 | }; 185 | 186 | } // namespace deep_gemm 187 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/nvrtc_std.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 3 | * All rights reserved. SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #ifdef __CUDACC_RTC__ 21 | 22 | using int8_t = signed char; 23 | using uint8_t = unsigned char; 24 | using int16_t = signed short; 25 | using uint16_t = unsigned short; 26 | using int32_t = signed int; 27 | using uint32_t = unsigned int; 28 | using int64_t = signed long long; 29 | using uint64_t = unsigned long long; 30 | using cuuint64_t = unsigned long long; 31 | 32 | #ifndef CU_TENSOR_MAP_NUM_QWORDS 33 | #define CU_TENSOR_MAP_NUM_QWORDS 16 34 | 35 | struct CUtensorMap_st { 36 | #if defined(__cplusplus) && (__cplusplus >= 201103L) 37 | alignas(64) 38 | #elif __STDC_VERSION__ >= 201112L 39 | _Alignas(64) 40 | #endif 41 | cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; 42 | }; 43 | 44 | using CUtensorMap = CUtensorMap_st; 45 | #endif 46 | 47 | namespace std { 48 | 49 | template struct integral_constant { 50 | static constexpr T value = v; 51 | 52 | using value_type = T; 53 | using type = integral_constant; 54 | 55 | __device__ constexpr operator value_type() const noexcept { return value; } 56 | 57 | __device__ constexpr value_type operator()() const noexcept { return value; } 58 | }; 59 | 60 | using false_type = integral_constant; 61 | using true_type = integral_constant; 62 | 63 | template struct is_same : false_type {}; 64 | 65 | template struct is_same : true_type {}; 66 | 67 | template 68 | inline constexpr bool is_same_v = is_same::value; 69 | 70 | namespace index_sequence_impl { 71 | 72 | // Based on https://stackoverflow.com/a/32223343/11717224 73 | template struct index_sequence { 74 | using type = index_sequence; 75 | using value_type = size_t; 76 | static constexpr size_t size() noexcept { return sizeof...(Ints); } 77 | }; 78 | 79 | template struct _merge_and_renumber; 80 | 81 | template 82 | struct _merge_and_renumber, index_sequence> 83 | : index_sequence {}; 84 | 85 | template 86 | struct make_index_sequence 87 | : _merge_and_renumber::type, 88 | typename make_index_sequence::type> {}; 89 | 90 | template <> struct make_index_sequence<0> : index_sequence<> {}; 91 | template <> struct make_index_sequence<1> : index_sequence<0> {}; 92 | 93 | } // namespace index_sequence_impl 94 | 95 | template 96 | using index_sequence = index_sequence_impl::index_sequence; 97 | 98 | template 99 | using make_index_sequence = index_sequence_impl::make_index_sequence; 100 | 101 | } // namespace std 102 | 103 | #endif 104 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/scheduler.cuh: -------------------------------------------------------------------------------- 1 | #include "utils.cuh" 2 | 3 | namespace deep_gemm { 4 | 5 | enum class GemmType { 6 | Normal, 7 | GroupedContiguous, 8 | GroupedMasked 9 | }; 10 | 11 | #pragma clang diagnostic push 12 | #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" 13 | template 19 | struct Scheduler { 20 | int current_iter = -1; 21 | uint32_t num_aligned_m_blocks; 22 | 23 | // For normal GEMM 24 | // Maybe not used in the masked grouped GEMM 25 | uint32_t num_blocks; 26 | uint32_t num_blocks_in_group; 27 | bool is_peer_cta_alive = true; 28 | 29 | // For grouped GEMM 30 | int* grouped_layout; 31 | 32 | // Only used for masked layout 33 | uint32_t curr_group_idx, curr_cumsum; 34 | 35 | __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m, 36 | int* grouped_layout = nullptr) { 37 | num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); 38 | if constexpr (kGemmType == GemmType::Normal) { 39 | num_blocks = num_aligned_m_blocks * 
kNumNBlocks; 40 | } else if (kGemmType == GemmType::GroupedContiguous) { 41 | num_blocks = num_aligned_m_blocks * kNumNBlocks; 42 | this->grouped_layout = grouped_layout; 43 | } else if (kGemmType == GemmType::GroupedMasked) { 44 | curr_group_idx = curr_cumsum = 0; 45 | this->grouped_layout = grouped_layout; 46 | } 47 | } 48 | 49 | __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { 50 | if (num_blocks_in_group == 1) 51 | return false; 52 | if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { 53 | return true; 54 | } else { 55 | DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type"); 56 | if constexpr (kIsTMAMulticastOnA) { 57 | return true; 58 | } else { 59 | auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); 60 | auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); 61 | return group_idx == peer_group_idx; 62 | } 63 | } 64 | } 65 | 66 | __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, 67 | uint32_t& m_block_idx, uint32_t& n_block_idx) { 68 | DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); 69 | 70 | // Swizzle for better L2 usages 71 | auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks; 72 | auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks; 73 | auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; 74 | auto group_idx = block_idx / num_blocks_per_group; 75 | auto first_block_idx = group_idx * kNum1DBlocksPerGroup; 76 | auto in_group_idx = block_idx % num_blocks_per_group; 77 | num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); 78 | 79 | // Fix unaligned TMA multicast 80 | if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { 81 | if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { 82 | num_blocks_in_group = num_blocks_in_group ^ 1; 83 | } else { 84 | in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; 85 | first_block_idx += num_blocks_in_group ^ 1; 86 | num_blocks_in_group = 1; 87 | } 88 | } 89 | 90 | // Convert to final M/N block indices 91 | if constexpr (kIsTMAMulticastOnA) { 92 | m_block_idx = in_group_idx / num_blocks_in_group; 93 | n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; 94 | } else { 95 | m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; 96 | n_block_idx = in_group_idx / num_blocks_in_group; 97 | } 98 | } 99 | 100 | template 101 | __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, 102 | const uint32_t& block_idx, const uint32_t& m_block_idx=0) { 103 | if constexpr (kGemmType == GemmType::Normal) { 104 | return block_idx * block_size; 105 | } else if constexpr (kGemmType == GemmType::GroupedContiguous) { 106 | auto offset = kIgnoreGroupedForGroupedContiguous ? 
0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); 107 | return offset * shape_dim + block_idx * block_size; 108 | } else if constexpr (kGemmType == GemmType::GroupedMasked) { 109 | return curr_group_idx * shape_dim + block_idx * block_size; 110 | } 111 | } 112 | 113 | __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { 114 | const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; 115 | 116 | if constexpr (kGemmType == GemmType::GroupedMasked) { 117 | uint32_t num_m_blocks; 118 | while (true) { 119 | // End of the task 120 | if (curr_group_idx == kNumGroups) 121 | return false; 122 | 123 | // Within the current group 124 | num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); 125 | auto current_m_block_cumsum = curr_cumsum + num_m_blocks; 126 | if (next_block_idx < current_m_block_cumsum * kNumNBlocks) 127 | break; 128 | 129 | // Move to check the next group 130 | curr_group_idx ++, curr_cumsum = current_m_block_cumsum; 131 | } 132 | 133 | get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); 134 | } else { 135 | if (next_block_idx >= num_blocks) 136 | return false; 137 | 138 | // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned 139 | is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) 140 | num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) 141 | (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound 142 | get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); 143 | } 144 | return true; 145 | } 146 | }; 147 | 148 | #pragma clang diagnostic pop 149 | 150 | } // namespace deep_gemm 151 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/tma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils.cuh" 4 | 5 | namespace deep_gemm { 6 | 7 | // TODO: move this function to other files 8 | __device__ __forceinline__ void 9 | tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, 10 | int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { 11 | constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); 12 | if (num_tma_multicast == 1) { 13 | cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); 14 | } else if (cute::block_rank_in_cluster() == 0) { 15 | cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); 16 | } 17 | } 18 | 19 | } // namespace deep_gemm 20 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __CLION_IDE__ 4 | 5 | __host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) 
{ 6 | asm volatile("trap;"); 7 | } 8 | 9 | #define printf host_device_printf 10 | #endif 11 | 12 | #ifndef DG_DEVICE_ASSERT 13 | #define DG_DEVICE_ASSERT(cond) \ 14 | do { \ 15 | if (not (cond)) { \ 16 | printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ 17 | asm("trap;"); \ 18 | } \ 19 | } while (0) 20 | #endif 21 | 22 | #ifndef DG_STATIC_ASSERT 23 | #define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) 24 | #endif 25 | 26 | template 27 | __device__ __host__ constexpr T ceil_div(T a, T b) { 28 | return (a + b - 1) / b; 29 | } 30 | 31 | template 32 | __device__ __host__ constexpr T constexpr_gcd(T a, T b) { 33 | return b == 0 ? a : constexpr_gcd(b, a % b); 34 | } 35 | -------------------------------------------------------------------------------- /deep_gemm/jit/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler 2 | from .runtime import Runtime 3 | -------------------------------------------------------------------------------- /deep_gemm/jit/compiler.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | import os 4 | import re 5 | import subprocess 6 | import time 7 | import uuid 8 | from typing import List, Tuple, Type 9 | 10 | import cuda.bindings 11 | import cuda.bindings.nvrtc as nvrtc 12 | from torch.utils.cpp_extension import CUDA_HOME 13 | 14 | from . import interleave_ffma 15 | from .runtime import Runtime, RuntimeCache 16 | 17 | runtime_cache = RuntimeCache() 18 | 19 | 20 | def hash_to_hex(s: str) -> str: 21 | md5 = hashlib.md5() 22 | md5.update(s.encode('utf-8')) 23 | return md5.hexdigest()[0:12] 24 | 25 | 26 | @functools.lru_cache(maxsize=None) 27 | def get_jit_include_dir() -> str: 28 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') 29 | 30 | 31 | @functools.lru_cache(maxsize=None) 32 | def get_deep_gemm_version() -> str: 33 | md5 = hashlib.md5() 34 | 35 | # Update include directories 36 | include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') 37 | assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' 38 | for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): 39 | with open(os.path.join(include_dir, filename), 'rb') as f: 40 | md5.update(f.read()) 41 | 42 | # Update `interleave_ffma.py` 43 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: 44 | md5.update(f.read()) 45 | return md5.hexdigest()[0:12] 46 | 47 | 48 | @functools.lru_cache(maxsize=None) 49 | def get_nvcc_compiler() -> Tuple[str, str]: 50 | paths = [] 51 | if os.getenv('DG_JIT_NVCC_COMPILER'): 52 | paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) 53 | paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) 54 | 55 | # Try to find the first available NVCC compiler 56 | least_version_required = '12.3' 57 | version_pattern = re.compile(r'release (\d+\.\d+)') 58 | for path in paths: 59 | if os.path.exists(path): 60 | command = [path, '--version'] 61 | result = subprocess.run(command, stdout=subprocess.PIPE, 62 | stderr=subprocess.PIPE, text=True) 63 | match = version_pattern.search(result.stdout) 64 | version = match.group(1) 65 | assert match, f'Cannot get the version of NVCC compiler {path}' 66 | assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' 67 | return path, 
version 68 | raise RuntimeError('Cannot find any available NVCC compiler') 69 | 70 | 71 | @functools.lru_cache(maxsize=None) 72 | def get_default_user_dir(): 73 | if 'DG_JIT_CACHE_DIR' in os.environ: 74 | path = os.getenv('DG_JIT_CACHE_DIR') 75 | os.makedirs(path, exist_ok=True) 76 | return path 77 | return os.path.join(os.path.expanduser('~'), '.deep_gemm') 78 | 79 | 80 | @functools.lru_cache(maxsize=None) 81 | def get_tmp_dir(): 82 | return os.path.join(get_default_user_dir(), 'tmp') 83 | 84 | 85 | @functools.lru_cache(maxsize=None) 86 | def get_cache_dir(): 87 | return os.path.join(get_default_user_dir(), 'cache') 88 | 89 | 90 | def make_tmp_dir(): 91 | tmp_dir = get_tmp_dir() 92 | os.makedirs(tmp_dir, exist_ok=True) 93 | return tmp_dir 94 | 95 | 96 | def put(path, data): 97 | # Write and do POSIX atomic replace 98 | tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') 99 | with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: 100 | f.write(data) 101 | os.replace(tmp_file_path, path) 102 | 103 | 104 | class Compiler: 105 | @classmethod 106 | def signature(cls) -> str: 107 | pass 108 | 109 | @staticmethod 110 | def __version__() -> Tuple[int, int]: 111 | pass 112 | 113 | @classmethod 114 | def compile(cls, name: str, code: str, target_path: str) -> None: 115 | pass 116 | 117 | @staticmethod 118 | def flags() -> List[str]: 119 | cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) 120 | return [f'-std=c++{cpp_standard}', 121 | '--ptxas-options=--register-usage-level=10' + 122 | (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), 123 | # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases 124 | '--diag-suppress=39,161,174,177,940'] 125 | 126 | @staticmethod 127 | def include_dirs() -> List[str]: 128 | return [get_jit_include_dir()] 129 | 130 | @classmethod 131 | def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: 132 | # Compiler flags 133 | flags = cls.flags() 134 | 135 | # Build signature 136 | enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) 137 | signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' 138 | name = f'kernel.{name}.{hash_to_hex(signature)}' 139 | path = os.path.join(get_cache_dir(), name) 140 | 141 | # Check runtime cache or file system hit 142 | global runtime_cache 143 | cached_runtime = runtime_cache.get(path, runtime_cls) 144 | if cached_runtime is not None: 145 | if int(os.getenv('DG_JIT_DEBUG', 0)): 146 | print(f'Using cached JIT runtime {name} during build') 147 | return cached_runtime 148 | 149 | # Compile into a temporary CU file 150 | os.makedirs(path, exist_ok=True) 151 | cubin_path = os.path.join(path, 'kernel.cubin') 152 | tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') 153 | 154 | start_time = time.time() 155 | cls.compile(name, code, tmp_cubin_path) 156 | end_time = time.time() 157 | elapsed_time = end_time - start_time 158 | if int(os.getenv('DG_JIT_DEBUG', 0)): 159 | print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') 160 | 161 | # Interleave FFMA reuse 162 | if enable_sass_opt: 163 | interleave_ffma.process(tmp_cubin_path) 164 | 165 | # Atomic replace files 166 | os.replace(tmp_cubin_path, cubin_path) 167 | 168 | # Put cache and return 169 | runtime = runtime_cls(path) 170 | runtime_cache[path] = runtime 
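        # NOTES: the runtime is cached by its kernel directory path; a later `build` call with the same
        # code and flags recomputes the same path and hits this in-process cache (or reloads the on-disk
        # cubin via `RuntimeCache.get`) instead of recompiling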
171 | return runtime 172 | 173 | 174 | class NVCCCompiler(Compiler): 175 | @staticmethod 176 | def __version__() -> Tuple[int, int]: 177 | _, version = get_nvcc_compiler() 178 | major, minor = map(int, version.split('.')) 179 | return major, minor 180 | 181 | @classmethod 182 | def signature(cls) -> str: 183 | return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' 184 | 185 | @classmethod 186 | def flags(cls) -> List[str]: 187 | cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] 188 | return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], 189 | '-gencode=arch=compute_90a,code=sm_90a', 190 | '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', 191 | f'--compiler-options={",".join(cxx_flags)}'] 192 | 193 | @classmethod 194 | def compile(cls, name: str, code: str, target_path: str) -> None: 195 | # Write the code 196 | path = os.path.join(get_cache_dir(), name) 197 | src_path = os.path.join(path, 'kernel.cu') 198 | put(src_path, code) 199 | command = [get_nvcc_compiler()[0], 200 | src_path, '-o', target_path, 201 | *cls.flags()] 202 | if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): 203 | print(f'Compiling JIT runtime {name} with command {command}') 204 | 205 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 206 | if result.returncode != 0: 207 | print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') 208 | assert False, f'Failed to compile {src_path}' 209 | 210 | 211 | class NVRTCCompiler(Compiler): 212 | @staticmethod 213 | def __version__() -> Tuple[int, int]: 214 | res, major, minor = nvrtc.nvrtcVersion() 215 | if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: 216 | # Failed to get the actual NVRTC version, use cuda-bindings version instead 217 | major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) 218 | return major, minor 219 | 220 | @classmethod 221 | def signature(cls) -> str: 222 | return f'nvrtc+{cls.__version__()}' 223 | 224 | @staticmethod 225 | def include_dirs() -> List[str]: 226 | if CUDA_HOME is None: 227 | raise RuntimeError('CUDA_HOME is required for NVRTC compilation') 228 | return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] 229 | 230 | @classmethod 231 | def flags(cls) -> List[str]: 232 | flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], 233 | '--gpu-architecture=sm_90a', '-default-device'] 234 | # NOTES: PCH is vital for compilation speed 235 | if cls.__version__() >= (12, 8): 236 | flags += ['--pch'] 237 | if int(os.getenv('DG_JIT_DEBUG', 0)): 238 | flags += ['--pch-verbose=true'] 239 | return flags 240 | 241 | @classmethod 242 | def compile(cls, name: str, code: str, target_path: str) -> None: 243 | # Create program 244 | code_bytes = bytes(code, 'utf-8') 245 | result, program = nvrtc.nvrtcCreateProgram( 246 | code_bytes, bytes(name, 'utf-8'), 0, [], []) 247 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' 248 | 249 | # Compile 250 | options = [bytes(flag, 'utf-8') for flag in cls.flags()] 251 | if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): 252 | print(f'Compiling JIT runtime {name} with options: {options}') 253 | compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] 254 | 255 | # Print compiler log 256 | if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: 257 | result, log_size = 
nvrtc.nvrtcGetProgramLogSize(program) 258 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' 259 | 260 | log_bytes = bytes(log_size) 261 | result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] 262 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' 263 | print(f'Compiler log: {log_bytes.decode("utf-8")}') 264 | 265 | # Exit if failed 266 | assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' 267 | 268 | # Create CUBIN 269 | result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) 270 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' 271 | cubin_bytes = bytes(cubin_size) 272 | result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] 273 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' 274 | 275 | # Write into the file system 276 | put(target_path, cubin_bytes) 277 | 278 | # Destroy handler 279 | assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' 280 | 281 | 282 | def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: 283 | compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler 284 | return compiler_cls.build(name, code, runtime_cls=runtime_cls) 285 | -------------------------------------------------------------------------------- /deep_gemm/jit/interleave_ffma.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mmap 3 | import os 4 | import re 5 | import subprocess 6 | from torch.utils.cpp_extension import CUDA_HOME 7 | 8 | 9 | def run_cuobjdump(file_path): 10 | command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path] 11 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 12 | assert result.returncode == 0 13 | return result.stdout 14 | 15 | 16 | def extract_ffma(sass): 17 | lines = sass.splitlines() 18 | collected = [] 19 | current = [] 20 | 21 | arch_name, func_name = 'N/A', 'N/A' 22 | skip_next_line = False 23 | for line in lines: 24 | if 'code for' in line: 25 | arch_name = line.lstrip().lstrip('code for ').rstrip() 26 | elif 'Function :' in line: 27 | func_name = line.lstrip().lstrip('Function :').rstrip() 28 | elif 'FFMA' in line: 29 | current.append(line) 30 | skip_next_line = True 31 | elif skip_next_line: 32 | current.append(line) 33 | skip_next_line = False 34 | else: 35 | if len(current) >= 16: 36 | assert len(current) % 2 == 0 37 | collected.append((f'{arch_name}::{func_name}', current)) 38 | current = [] 39 | 40 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 41 | print(f'Found {len(collected)} FFMA segments') 42 | return collected 43 | 44 | 45 | def extract_hex_from_line(line): 46 | match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line) 47 | assert match 48 | return int(match.group(1), 16) 49 | 50 | 51 | def validate(m, offset, le_bytes, num_lines): 52 | assert len(le_bytes) == num_lines // 2 53 | assert m[offset:offset + 16] == le_bytes[0] 54 | for i in range(1, num_lines // 2): 55 | if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]: 56 | return False 57 | return True 58 | 59 | 60 | def parse_registers(line): 61 | line = re.sub(r'/\*.*?\*/', '', line) 62 | line = line.replace(';', '') 63 | tokens = line.strip().split(',') 64 | registers = [] 65 | for token in tokens: 66 | token = token.strip() 67 | words = token.split() 68 | for word 
in words: 69 | if word.startswith('R'): 70 | reg = word.split('.')[0] 71 | registers.append(reg) 72 | return registers 73 | 74 | 75 | def modify_segment(m, name, ffma_lines): 76 | num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2 77 | assert num_lines % 2 == 0 78 | 79 | le_bytes, new_le_bytes = [], [] 80 | reused_list = [] 81 | dst_reg_set = set() 82 | last_reused, last_dst_reg = False, '' 83 | num_changed = 0 84 | for i in range(num_lines // 2): 85 | dst_reg = parse_registers(ffma_lines[i * 2])[-2] 86 | low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] 87 | low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) 88 | le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 89 | reused = (high_hex & 0x0800000000000000) != 0 90 | if reused: 91 | is_first_occurred = dst_reg not in dst_reg_set 92 | if is_first_occurred or (last_reused and dst_reg == last_dst_reg): 93 | # Modify the `reuse` and `yield` bits 94 | assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' 95 | high_hex ^= 0x0800200000000000 96 | reused = False 97 | num_changed += 1 98 | else: 99 | reused_list.append(i) 100 | dst_reg_set.add(dst_reg) 101 | new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 102 | last_reused, last_dst_reg = reused, dst_reg 103 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 104 | print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') 105 | 106 | # Find the offset 107 | offsets = [] 108 | offset = m.find(le_bytes[0]) 109 | while offset != -1: 110 | offsets.append(offset) 111 | offset = m.find(le_bytes[0], offset + 1) 112 | offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) 113 | 114 | # Replace with `new_le_bytes` 115 | for offset in offsets: 116 | for i in range(num_lines // 2): 117 | m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i] 118 | 119 | 120 | def process(path): 121 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 122 | print(f'Processing {path}') 123 | output = run_cuobjdump(path) 124 | segments = extract_ffma(output) 125 | with open(path, 'r+b') as f: 126 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) 127 | for segment in segments: 128 | modify_segment(mm, *segment) 129 | mm.close() 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') 134 | parser.add_argument('--so', help='Path to the SO file') 135 | args = parser.parse_args() 136 | 137 | process(args.so) 138 | -------------------------------------------------------------------------------- /deep_gemm/jit/runtime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | import cuda.bindings.driver as cbd 5 | 6 | from typing import List, Optional, Type 7 | from torch.utils.cpp_extension import CUDA_HOME 8 | 9 | 10 | class Runtime: 11 | def __init__(self, path: str, args: List[str] = None) -> None: 12 | self.path = path 13 | self.lib = None 14 | self.kernel = None 15 | self.args = args 16 | assert self.is_path_valid(self.path) 17 | 18 | @staticmethod 19 | def is_path_valid(path: str) -> bool: 20 | # Exists and is a directory 21 | if not os.path.exists(path) or not os.path.isdir(path): 22 | return False 23 | 24 | # Contains all necessary files 25 | files = ['kernel.cubin'] 26 | return all(os.path.exists(os.path.join(path, file)) for file in files) 27 | 28 | @staticmethod 29 | def generate(**kwargs) -> str: 30 | 
raise NotImplemented 31 | 32 | @staticmethod 33 | def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: 34 | raise NotImplemented 35 | 36 | def __call__(self, **kwargs) -> cbd.CUresult: 37 | # Load CUBIN 38 | if self.kernel is None: 39 | start_time = time.time_ns() 40 | 41 | # Load CUBIN 42 | path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') 43 | result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) 44 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' 45 | 46 | # Extract the kernel name 47 | # TODO: use `cuda-bindings` API to do this (requires at least 12.8) 48 | command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] 49 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 50 | assert result.returncode == 0 51 | kernel_names = [line.split()[-1] for line in result.stdout.splitlines() 52 | if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line] 53 | assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' 54 | 55 | # Load kernel from the library 56 | result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) 57 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' 58 | 59 | end_time = time.time_ns() 60 | elapsed_time = (end_time - start_time) / 1e6 61 | if int(os.getenv('DG_JIT_DEBUG', 0)): 62 | print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') 63 | 64 | # noinspection PyArgumentList 65 | return self.launch(self.kernel, *[kwargs[arg] for arg in self.args]) 66 | 67 | def __del__(self) -> None: 68 | if self.lib is not None: 69 | res = cbd.cuLibraryUnload(self.lib)[0] 70 | if res != cbd.CUresult.CUDA_SUCCESS: 71 | raise Exception(f'Failed to unload library {self.path}: {res}') 72 | 73 | 74 | class RuntimeCache: 75 | def __init__(self) -> None: 76 | self.cache = {} 77 | 78 | def __setitem__(self, path: str, runtime: Runtime) -> None: 79 | self.cache[path] = runtime 80 | 81 | def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: 82 | # In Python runtime 83 | if path in self.cache: 84 | return self.cache[path] 85 | 86 | # Already compiled 87 | if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): 88 | runtime = runtime_cls(path) 89 | self.cache[path] = runtime 90 | return runtime 91 | return None 92 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemm import gemm_fp8_fp8_bf16_nt 2 | from .m_grouped_gemm import ( 3 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 4 | m_grouped_gemm_fp8_fp8_bf16_nt_masked 5 | ) 6 | from .utils import ( 7 | ceil_div, set_num_sms, get_num_sms, 8 | get_col_major_tma_aligned_tensor, 9 | get_m_alignment_for_contiguous_layout 10 | ) 11 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/gemm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import lru_cache 4 | from typing import Tuple 5 | 6 | from .runtime import ( 7 | FP8GemmRuntime, GemmType, 8 | make_2d_tma_a_desc, make_2d_tma_b_desc, 9 | make_2d_tma_d_desc, make_2d_tma_scales_a_desc) 10 | from .tuner import jit_tuner 11 | from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, 
get_m_alignment_for_contiguous_layout 12 | 13 | 14 | def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, 15 | require_divisible: bool = False) -> bool: 16 | divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible 17 | return divisible and num_sms % num_tma_multicast == 0 18 | 19 | 20 | def get_swizzle_mode(block_n: int) -> int: 21 | # TODO: remove some candidates if slow 22 | elem_size = 2 23 | for mode_bytes in (128, 64, 32): 24 | if (block_n * elem_size) % mode_bytes == 0: 25 | return mode_bytes 26 | return 0 27 | 28 | 29 | def get_block_n_padding_for_smem_d(block_n: int) -> int: 30 | # NOTES: padding is for solving bank conflicts, but wastes shared memory space 31 | elem_size, requirement = 2, (4, 8) 32 | bank_stride = (block_n * elem_size) // 4 33 | padding = (requirement[0] - bank_stride) % requirement[1] 34 | return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size 35 | 36 | 37 | def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: 38 | # Try swizzle first, as it does not waste shared memory 39 | swizzle_mode = get_swizzle_mode(block_n) 40 | block_n_padding = get_block_n_padding_for_smem_d( 41 | block_n) if swizzle_mode == 0 else 0 42 | 43 | smem_d = block_m * (block_n + block_n_padding) * 2 44 | smem_a_per_stage = block_m * block_k 45 | smem_scales_a_per_stage = block_m * 4 46 | smem_b_per_stage = block_n * block_k 47 | smem_scales_b = ceil_div(k, block_k) * 4 48 | smem_barrier = num_stages * 8 * 2 49 | 50 | smem_size = 0 51 | smem_size += smem_d 52 | smem_size += num_stages * smem_a_per_stage 53 | smem_size += num_stages * smem_scales_a_per_stage 54 | smem_size += num_stages * smem_b_per_stage 55 | smem_size += ceil_div(smem_scales_b * (1 if block_k % 56 | block_n == 0 else 2), 8) * 8 57 | smem_size += smem_barrier 58 | 59 | # Swizzle and padding are not compatible 60 | assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 61 | 62 | return smem_size, swizzle_mode, block_n_padding 63 | 64 | 65 | @lru_cache(maxsize=None) 66 | def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, 67 | is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ 68 | Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: 69 | if not is_grouped_contiguous: 70 | block_ms = (64, 128, 256) 71 | else: 72 | block_ms = (get_m_alignment_for_contiguous_layout(), ) 73 | block_ns = tuple(range(16, 129, 8)) + (144, 160, ) 74 | 75 | fix_wave_saturate = lambda x: num_sms if x == 0 else x 76 | get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) 77 | get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) 78 | 79 | # Decide block sizes by waves 80 | best_block_m, best_block_n = None, None 81 | for block_m in block_ms: 82 | # NOTES: the block sizes cannot be too large, so at least one dim less than 128 83 | for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): 84 | success = False 85 | num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) 86 | if best_block_m is None or best_block_n is None: 87 | success = True 88 | elif num_waves < best_num_waves: 89 | success = True 90 | elif num_waves == best_num_waves: 91 | # Check last wave utilization 92 | util = get_last_wave_util(block_m, block_n) 93 | best_util = 
get_last_wave_util(best_block_m, best_block_n) 94 | success = util > best_util 95 | if util == best_util: 96 | # Case 1: same `block_m`, smaller `block_n` (wasted) 97 | success |= block_m == best_block_m and block_n < best_block_n 98 | # Case 2: same `block_n`, smaller `block_m` (wasted) 99 | success |= block_n == best_block_n and block_m < best_block_m 100 | # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better 101 | success |= block_m != best_block_m and block_n > best_block_n 102 | best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) 103 | assert best_block_m is not None and best_block_n is not None 104 | 105 | # Always pick the longest one 106 | # NOTES: for double B scales, the best number of stages may be reduced 107 | best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 108 | stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))) 109 | if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: 110 | # Unrolling both stages and `num_former_iters` will cause large code size 111 | stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) 112 | for num_stages in stage_candidates: 113 | best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) 114 | if best_smem_config[0] <= sm90_capacity: 115 | best_num_stages = num_stages 116 | break 117 | assert best_smem_config is not None 118 | assert best_num_stages is not None 119 | 120 | # Decide the number of TMA multicasts and whether broadcast on A 121 | best_tma_multicast_config = (1, True) 122 | 123 | # Try to multicast on the larger block side first 124 | # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even 125 | is_multicast_legal = { 126 | 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), 127 | 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, 128 | } 129 | for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): 130 | if m >= 512 and is_multicast_legal[i]: 131 | best_tma_multicast_config = (2, i == 'A') 132 | break 133 | 134 | # Recompute the minimal number of SMs required 135 | # NOTES: less L2 cache usage and less GPU frequency drop 136 | num_waves = get_num_waves(best_block_m, best_block_n) 137 | num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) 138 | num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] 139 | assert num_min_sms <= num_sms 140 | 141 | return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config 142 | 143 | 144 | def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 145 | rhs: Tuple[torch.Tensor, torch.Tensor], 146 | out: torch.Tensor) -> None: 147 | """ 148 | Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 149 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 150 | RHS and RHS scaling factors are required to be transposed. 151 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 152 | this function will do a transposing with a set of slow PyTorch operations. 
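
    Example (an illustrative sketch only; the shapes below are arbitrary choices that satisfy the
    `n % 64 == 0` and `k % 128 == 0` requirements asserted by this function):
        m, n, k = 256, 4096, 7168
        lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
        lhs_scales = torch.ones(m, (k + 127) // 128, dtype=torch.float32, device='cuda')
        rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
        rhs_scales = torch.ones((n + 127) // 128, (k + 127) // 128, dtype=torch.float32, device='cuda')
        out = torch.empty(m, n, dtype=torch.bfloat16, device='cuda')
        gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)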
153 | 154 | Arguments: 155 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, 156 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. 157 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, 158 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. 159 | out: the BF16 output tensor of shape `[m, n]`, representing the result. 160 | """ 161 | lhs, lhs_scales = lhs 162 | rhs, rhs_scales = rhs 163 | m, k = lhs.shape 164 | n, k_ = rhs.shape 165 | m_, n_ = out.shape 166 | 167 | assert n % 64 == 0 and k % 128 == 0 168 | 169 | # Type and shape checks 170 | assert m == m_ and n == n_ and k == k_ 171 | assert n > 0 and k > 0 172 | assert lhs_scales.shape == (m, (k + 127) // 128) 173 | assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) 174 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 175 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 176 | assert out.dtype == torch.bfloat16 177 | assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() 178 | 179 | # LHS scales must be transposed for TMA loads, but not for RHS scales 180 | # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels 181 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 182 | assert rhs_scales.is_contiguous() 183 | 184 | # Do nothing if `m` is zero 185 | if m == 0: 186 | return 187 | 188 | # Auto-tuning with compilation 189 | num_sms = get_num_sms() 190 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 191 | m, n, k, 1, num_sms) 192 | block_k = 128 193 | num_tma_threads = 128 194 | num_math_threads_per_group = 128 195 | 196 | tensor_map_a = make_2d_tma_a_desc( 197 | GemmType.Normal, lhs, m, k, block_m, block_k, 1) 198 | tensor_map_b = make_2d_tma_b_desc( 199 | GemmType.Normal, rhs, k, n, block_k, block_n, 1) 200 | tensor_map_d = make_2d_tma_d_desc( 201 | GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1]) 202 | tensor_map_scales_a = make_2d_tma_scales_a_desc( 203 | GemmType.Normal, lhs_scales, m, k, block_m, block_k) 204 | 205 | kwargs = { 206 | 'GEMM_TYPE': GemmType.Normal, 207 | 'NUM_TMA_THREADS': num_tma_threads, 208 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 209 | 'M': m, 210 | 'NUM_GROUPS': 1, 211 | 'BLOCK_K': block_k, 212 | 'GMEM_D': out, 213 | 'SCALES_B': rhs_scales, 214 | 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), 215 | 'NUM_SMS': num_sms, 216 | 'SMEM_SIZE': smem_config[0], 217 | 'TENSOR_MAP_A': tensor_map_a, 218 | 'TENSOR_MAP_B': tensor_map_b, 219 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 220 | 'TENSOR_MAP_D': tensor_map_d, 221 | 'STREAM': torch.cuda.current_stream().cuda_stream, 222 | } 223 | 224 | runtime, best_keys = jit_tuner.compile_and_tune( 225 | name='gemm_fp8_fp8_bf16_nt', 226 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 227 | 'SWIZZLE_D_MODE': smem_config[1], 228 | 'BLOCK_N_PADDING': smem_config[2], 229 | 'NUM_STAGES': num_stages, 230 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 231 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, 232 | space=(), 233 | kwargs=kwargs, 234 | runtime_cls=FP8GemmRuntime, 235 | ) 236 | 237 | # Run the kernel 238 | runtime(**best_keys, **kwargs) 239 | -------------------------------------------------------------------------------- 
/deep_gemm/jit_kernels/m_grouped_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from .gemm import get_best_configs 5 | from .runtime import ( 6 | FP8GemmRuntime, GemmType, 7 | make_2d_tma_a_desc, make_2d_tma_b_desc, 8 | make_2d_tma_d_desc, make_2d_tma_scales_a_desc) 9 | from .tuner import jit_tuner 10 | from .utils import get_col_major_tma_aligned_tensor, get_num_sms 11 | 12 | 13 | def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], 14 | rhs: Tuple[torch.Tensor, torch.Tensor], 15 | out: torch.Tensor, m_indices: torch.Tensor) -> None: 16 | """ 17 | Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 18 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 19 | RHS and RHS scaling factors are required to be transposed. 20 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 21 | this function will do a transposing with a set of slow PyTorch operations. 22 | On the M axis, inputs are grouped into several batches, of which batch sizes aligned to 23 | `get_m_alignment_for_contiguous_layout()` (128). 24 | 25 | Arguments: 26 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, 27 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. 28 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, 29 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 30 | out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. 31 | m_indices: a tensor of shape `[m_sum]` with type `torch.int`. 32 | `m_indices[i]` records the group which the i-th row of the LHS belongs to, 33 | which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. 34 | Values of `m_indices` in every-m-alignment-block must also be the same. 
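
    Example (an illustrative sketch only; the group sizes below are arbitrary multiples of the
    128-row alignment returned by `get_m_alignment_for_contiguous_layout()`):
        num_groups, n, k = 4, 4096, 7168
        group_ms = [128, 256, 128, 512]
        m_sum = sum(group_ms)
        lhs = torch.randn(m_sum, k, device='cuda').to(torch.float8_e4m3fn)
        lhs_scales = torch.ones(m_sum, (k + 127) // 128, dtype=torch.float32, device='cuda')
        rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
        rhs_scales = torch.ones(num_groups, (n + 127) // 128, (k + 127) // 128, dtype=torch.float32, device='cuda')
        out = torch.empty(m_sum, n, dtype=torch.bfloat16, device='cuda')
        m_indices = torch.cat([torch.full((rows,), group, dtype=torch.int32, device='cuda')
                               for group, rows in enumerate(group_ms)])
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scales), (rhs, rhs_scales), out, m_indices)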
35 | """ 36 | lhs, lhs_scales = lhs 37 | rhs, rhs_scales = rhs 38 | m, k = lhs.shape 39 | num_groups, n, k_ = rhs.shape 40 | m_, n_ = out.shape 41 | m__ = m_indices.numel() 42 | 43 | # Type and shape checks 44 | assert m == m_ == m__ and k == k_ and n == n_ 45 | assert lhs_scales.shape == (m, (k + 127) // 128) 46 | assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) 47 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 48 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 49 | assert out.dtype == torch.bfloat16 50 | assert m_indices.dtype == torch.int32 51 | assert lhs.is_contiguous() and rhs.is_contiguous() 52 | assert out.is_contiguous() and m_indices.is_contiguous() 53 | 54 | # LHS scales must be transposed for TMA load, but not for RHS scales 55 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 56 | assert rhs_scales.is_contiguous() 57 | 58 | # Do nothing if `m` is zero 59 | if m == 0: 60 | return 61 | 62 | # Auto-tuning with compilation 63 | num_sms = get_num_sms() 64 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 65 | m, n, k, 1, num_sms, is_grouped_contiguous=True) 66 | block_k = 128 67 | num_tma_threads = 128 68 | num_math_threads_per_group = 128 69 | 70 | tensor_map_a = make_2d_tma_a_desc( 71 | GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) 72 | tensor_map_b = make_2d_tma_b_desc( 73 | GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) 74 | tensor_map_d = make_2d_tma_d_desc( 75 | GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1]) 76 | tensor_map_scales_a = make_2d_tma_scales_a_desc( 77 | GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) 78 | 79 | kwargs = { 80 | 'NUM_TMA_THREADS': num_tma_threads, 81 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 82 | 'M': m, 83 | 'BLOCK_K': block_k, 84 | 'GMEM_D': out, 85 | 'SCALES_B': rhs_scales, 86 | 'GROUPED_LAYOUT': m_indices, 87 | 'NUM_SMS': num_sms, 88 | 'SMEM_SIZE': smem_config[0], 89 | 'TENSOR_MAP_A': tensor_map_a, 90 | 'TENSOR_MAP_B': tensor_map_b, 91 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 92 | 'TENSOR_MAP_D': tensor_map_d, 93 | 'STREAM': torch.cuda.current_stream().cuda_stream, 94 | } 95 | 96 | runtime, best_keys = jit_tuner.compile_and_tune( 97 | name='m_grouped_gemm_fp8_fp8_bf16_nt', 98 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 99 | 'SWIZZLE_D_MODE': smem_config[1], 100 | 'BLOCK_N_PADDING': smem_config[2], 101 | 'NUM_GROUPS': num_groups, 102 | 'NUM_STAGES': num_stages, 103 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 104 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 105 | 'GEMM_TYPE': GemmType.GroupedContiguous}, 106 | space=(), 107 | kwargs=kwargs, 108 | runtime_cls=FP8GemmRuntime, 109 | ) 110 | 111 | # Run the kernel 112 | runtime(**best_keys, **kwargs) 113 | 114 | 115 | def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], 116 | rhs: Tuple[torch.Tensor, torch.Tensor], 117 | out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: 118 | """ 119 | Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 120 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 121 | RHS and RHS scaling factors are required to be transposed. 
122 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 123 | this function will do a transposing with a set of slow PyTorch operations. 124 | Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch 125 | should be separately transposed. 126 | 127 | Arguments: 128 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, 129 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. 130 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. 131 | The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 132 | out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. 133 | masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute 134 | in the i-th group. 135 | expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, 136 | correctly setting this value may lead to better performance. 137 | """ 138 | lhs, lhs_scales = lhs 139 | rhs, rhs_scales = rhs 140 | num_groups, m, k = lhs.shape 141 | num_groups_, n, k_ = rhs.shape 142 | num_groups__, m_, n_ = out.shape 143 | num_groups___ = masked_m.numel() 144 | 145 | # Type and shape checks 146 | assert num_groups == num_groups_ == num_groups__ == num_groups___ 147 | assert m == m_ and n == n_ and k == k_ 148 | assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 149 | assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) 150 | assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) 151 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 152 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 153 | assert out.dtype == torch.bfloat16 154 | assert masked_m.dtype == torch.int32 155 | assert lhs.is_contiguous() and rhs.is_contiguous() 156 | assert out.is_contiguous() and masked_m.is_contiguous() 157 | 158 | # LHS scales must be transposed for TMA load, but not for RHS scales 159 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 160 | assert rhs_scales.is_contiguous() 161 | 162 | # Auto-tuning with compilation 163 | num_sms = get_num_sms() 164 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 165 | expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) 166 | 167 | # Extra checks for TMA store 168 | if num_groups > 1 and m > block_m: 169 | assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' 170 | 171 | block_k = 128 172 | num_tma_threads = 128 173 | num_math_threads_per_group = 128 174 | 175 | tensor_map_a = make_2d_tma_a_desc( 176 | GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) 177 | tensor_map_b = make_2d_tma_b_desc( 178 | GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) 179 | tensor_map_d = make_2d_tma_d_desc( 180 | GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1]) 181 | tensor_map_scales_a = make_2d_tma_scales_a_desc( 182 | GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) 183 | 184 | kwargs = { 185 | 'NUM_TMA_THREADS': num_tma_threads, 186 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 187 | 'M': m, 188 
| 'BLOCK_K': block_k, 189 | 'GMEM_D': out, 190 | 'SCALES_B': rhs_scales, 191 | 'GROUPED_LAYOUT': masked_m, 192 | 'NUM_SMS': num_sms, 193 | 'SMEM_SIZE': smem_config[0], 194 | 'TENSOR_MAP_A': tensor_map_a, 195 | 'TENSOR_MAP_B': tensor_map_b, 196 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 197 | 'TENSOR_MAP_D': tensor_map_d, 198 | 'STREAM': torch.cuda.current_stream().cuda_stream, 199 | } 200 | 201 | runtime, best_keys = jit_tuner.compile_and_tune( 202 | name='m_grouped_gemm_fp8_fp8_bf16_nt', 203 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 204 | 'SWIZZLE_D_MODE': smem_config[1], 205 | 'BLOCK_N_PADDING': smem_config[2], 206 | 'NUM_GROUPS': num_groups, 207 | 'NUM_STAGES': num_stages, 208 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 209 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 210 | 'GEMM_TYPE': GemmType.GroupedMasked}, 211 | space=(), 212 | kwargs=kwargs, 213 | runtime_cls=FP8GemmRuntime, 214 | ) 215 | 216 | # Run the kernel 217 | runtime(**best_keys, **kwargs) 218 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/runtime.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | import enum 4 | import torch 5 | import cuda.bindings.driver as cbd 6 | from typing import Any, Dict, Tuple 7 | 8 | from ..jit.runtime import Runtime 9 | 10 | 11 | class Layout(enum.Enum): 12 | RowMajor = 0 13 | ColMajor = 1 14 | 15 | 16 | class GemmType(enum.Enum): 17 | Normal = 0 18 | GroupedContiguous = 1 19 | GroupedMasked = 2 20 | 21 | def __str__(self) -> str: 22 | return { 23 | 0: 'Normal', 24 | 1: 'GroupedContiguous', 25 | 2: 'GroupedMasked', 26 | }[self.value] 27 | 28 | 29 | tmap_type_map: Dict[Any, str] = { 30 | torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 31 | torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, 32 | torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, 33 | torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, 34 | torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 35 | torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, 36 | torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, 37 | torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, 38 | torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 39 | torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 40 | torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 41 | torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 42 | torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 43 | torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 44 | torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 45 | } 46 | 47 | swizzle_type_map = { 48 | 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, 49 | 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, 50 | 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, 51 | 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, 52 | } 53 | 54 | 55 | def get_num_math_warpgroups(block_m: int) -> int: 56 | return 1 if block_m == 64 else 2 57 | 58 | 59 | def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: 60 | assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' 61 | return 
get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads 62 | 63 | 64 | def make_2d_tma_copy_desc(global_address: torch.Tensor, 65 | gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], 66 | stride_in_bytes: cbd.cuuint64_t, 67 | smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], 68 | swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: 69 | tensor_dtype = tmap_type_map[global_address.dtype] 70 | res, tensor_map = cbd.cuTensorMapEncodeTiled( 71 | tensor_dtype, 72 | 2, 73 | global_address.data_ptr(), 74 | gmem_dim, 75 | (stride_in_bytes, ), 76 | smem_dim, 77 | (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), 78 | cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, 79 | swizzle_type, 80 | cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, 81 | cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, 82 | ) 83 | 84 | if res != cbd.CUresult.CUDA_SUCCESS: 85 | raise Exception(f'Failed to encode tensor map: {res}') 86 | return tensor_map 87 | 88 | 89 | def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, 90 | gmem_rows: int, gmem_cols: int, 91 | smem_rows: int, smem_cols: int, 92 | swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: 93 | if layout == Layout.RowMajor: 94 | gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows)) 95 | smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows)) 96 | return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) 97 | else: 98 | gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols)) 99 | smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols)) 100 | return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) 101 | 102 | 103 | def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, 104 | shape_m: int, shape_k: int, 105 | block_m: int, block_k: int, 106 | num_groups: int) -> cbd.CUtensorMap: 107 | return make_2d_tma_desc(global_address, Layout.RowMajor, 108 | shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, 109 | block_m, block_k) 110 | 111 | 112 | def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, 113 | shape_k: int, shape_n: int, 114 | block_k: int, block_n: int, 115 | num_groups: int) -> cbd.CUtensorMap: 116 | return make_2d_tma_desc(global_address, Layout.ColMajor, 117 | shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), 118 | block_k, block_n) 119 | 120 | 121 | def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, 122 | shape_m: int, shape_n: int, 123 | block_m: int, block_n: int, 124 | num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap: 125 | # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` 126 | # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required 127 | return make_2d_tma_desc(global_address, Layout.RowMajor, 128 | shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, 129 | block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), 130 | swizzle_type_map[swizzle_mode]) 131 | 132 | 133 | def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: 134 | # Make TMA aligned to 16 bytes 135 | tma_alignment = 16 / 
global_address.element_size() 136 | shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment 137 | 138 | return make_2d_tma_desc(global_address, Layout.ColMajor, 139 | shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), 140 | block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) 141 | 142 | 143 | class FP8GemmRuntime(Runtime): 144 | def __init__(self, path: str) -> None: 145 | super().__init__(path, [ 146 | 'NUM_TMA_MULTICAST', 147 | 'M', 148 | 'BLOCK_M', 149 | 'GMEM_D', 150 | 'SCALES_B', 151 | 'GROUPED_LAYOUT', 152 | 'NUM_SMS', 153 | 'SMEM_SIZE', 154 | 'TENSOR_MAP_A', 155 | 'TENSOR_MAP_B', 156 | 'TENSOR_MAP_SCALES_A', 157 | 'TENSOR_MAP_D', 158 | 'STREAM', 159 | ]) 160 | 161 | @staticmethod 162 | def generate(**kwargs) -> str: 163 | code = f''' 164 | #ifdef __CUDACC_RTC__ 165 | #include 166 | #else 167 | #include 168 | #include 169 | #endif 170 | 171 | #include 172 | #include 173 | 174 | #include 175 | 176 | using namespace deep_gemm; 177 | 178 | static void __instantiate_kernel() {{ 179 | auto ptr = reinterpret_cast(&fp8_gemm_kernel< 180 | {kwargs['N']}, 181 | {kwargs['K']}, 182 | {kwargs['BLOCK_M']}, 183 | {kwargs['BLOCK_N']}, 184 | {kwargs['BLOCK_K']}, 185 | {kwargs['BLOCK_N_PADDING']}, 186 | {kwargs['SWIZZLE_D_MODE']}, 187 | {kwargs['NUM_GROUPS']}, 188 | {kwargs['NUM_STAGES']}, 189 | {kwargs['NUM_TMA_THREADS']}, 190 | {kwargs['NUM_MATH_THREADS_PER_GROUP']}, 191 | {kwargs['NUM_TMA_MULTICAST']}, 192 | {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, 193 | GemmType::{kwargs['GEMM_TYPE']} 194 | >); 195 | }}; 196 | ''' 197 | if int(os.getenv('DG_JIT_DEBUG', 0)): 198 | print(f'Generated FP8 GEMM code:\n{code}') 199 | return code 200 | 201 | # noinspection PyMethodOverriding 202 | @staticmethod 203 | def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int, 204 | block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, 205 | grouped_layout: torch.Tensor, num_sms: int, smem_size: int, 206 | tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap, 207 | tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap, 208 | stream: cbd.CUstream) -> cbd.CUresult: 209 | num_tma_threads = 128 210 | num_math_threads_per_group = 128 211 | 212 | res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] 213 | if res != cbd.CUresult.CUDA_SUCCESS: 214 | raise Exception(f'Failed to set max dynamic shared memory size: {res}') 215 | 216 | attr_val = cbd.CUlaunchAttributeValue() 217 | attr_val.clusterDim.x = num_tma_multicast 218 | attr_val.clusterDim.y = 1 219 | attr_val.clusterDim.z = 1 220 | attr = cbd.CUlaunchAttribute() 221 | attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION 222 | attr.value = attr_val 223 | 224 | config = cbd.CUlaunchConfig() 225 | config.numAttrs = 1 226 | config.attrs = [attr] 227 | config.gridDimX = num_sms 228 | config.gridDimY = 1 229 | config.gridDimZ = 1 230 | config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) 231 | config.blockDimY = 1 232 | config.blockDimZ = 1 233 | config.sharedMemBytes = smem_size 234 | config.hStream = stream 235 | 236 | arg_values = ( 237 | gmem_d.data_ptr(), 238 | scales_b.data_ptr(), 239 | grouped_layout.data_ptr(), 240 | shape_m, 241 | tensor_map_a, 242 | tensor_map_b, 243 | tensor_map_scales_a, 244 | tensor_map_d, 245 | ) 246 | arg_types = ( 247 | ctypes.c_void_p, 248 | 
ctypes.c_void_p, 249 | ctypes.c_void_p, 250 | ctypes.c_uint32, 251 | None, 252 | None, 253 | None, 254 | None, 255 | ) 256 | return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) 257 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/tuner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import torch 4 | import cuda.bindings.driver as cbd 5 | from typing import Any, Callable, Dict, Type, Tuple 6 | 7 | from ..jit import build, Runtime 8 | 9 | 10 | class JITTuner: 11 | def __init__(self) -> None: 12 | self.tuned = {} 13 | 14 | def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, 15 | kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]: 16 | # NOTES: we always assume the space, template and GPU devices will not change 17 | # NOTES: the function must have no accumulated side effects 18 | keys = {k: keys[k] for k in sorted(keys.keys())} 19 | signature = (name, f'{keys}') 20 | if signature in self.tuned: 21 | if int(os.getenv('DG_JIT_DEBUG', 0)): 22 | print(f'Using cached JIT kernel {name} with keys {keys}') 23 | return self.tuned[signature] 24 | 25 | if int(os.getenv('DG_JIT_DEBUG', 0)): 26 | print(f'Auto-tuning JIT kernel {name} with keys {keys}') 27 | 28 | assert signature not in self.tuned 29 | assert kwargs is not None 30 | space = (dict(), ) if len(space) == 0 else space 31 | 32 | kernels = [] 33 | for tuned_keys in space: 34 | assert isinstance(tuned_keys, dict) 35 | full_keys = copy.deepcopy(keys) 36 | full_keys.update(tuned_keys) 37 | code = runtime_cls.generate(**kwargs, **full_keys) 38 | kernels.append((build(name, code, runtime_cls), full_keys)) 39 | 40 | # TODO: fix tuning with space > 1 41 | best_runtime, best_time, best_keys = None, None, None 42 | for runtime, tuned_keys in kernels: 43 | if len(space) > 1: 44 | # Check kernel validity 45 | return_code = runtime(**tuned_keys, **kwargs) 46 | if return_code != cbd.CUresult.CUDA_SUCCESS: 47 | # Pass illegal kernels, e.g., insufficient shared memory capacity 48 | if int(os.getenv('DG_JIT_DEBUG', 0)): 49 | print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') 50 | continue 51 | 52 | # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels 53 | start_event = torch.cuda.Event(enable_timing=True) 54 | end_event = torch.cuda.Event(enable_timing=True) 55 | torch.empty(int(256e6 // 4), dtype=torch.int, 56 | device='cuda').zero_() 57 | torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn( 58 | (8192, 8192), dtype=torch.float, device='cuda') 59 | start_event.record() 60 | for i in range(20): 61 | assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS 62 | end_event.record() 63 | end_event.synchronize() 64 | elapsed_time = start_event.elapsed_time(end_event) 65 | else: 66 | elapsed_time = 0 67 | 68 | # Compare if better 69 | if best_time is None or elapsed_time < best_time: 70 | best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys 71 | if int(os.getenv('DG_JIT_DEBUG', 0)): 72 | print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') 73 | assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' 74 | 75 | # Cache the best runtime and return 76 | if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)): 77 | 
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') 78 | self.tuned[signature] = (best_runtime, best_keys) 79 | return best_runtime, best_keys 80 | 81 | 82 | jit_tuner = JITTuner() 83 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | _num_sms = None 4 | 5 | 6 | def set_num_sms(num_sms: int) -> None: 7 | """ 8 | Set the maximum SM count for all GEMM kernels to use. 9 | 10 | Arguments: 11 | num_sms: the desired maximum SM count for all GEMM kernels to use. 12 | """ 13 | global _num_sms 14 | assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count 15 | _num_sms = num_sms 16 | 17 | 18 | def get_num_sms() -> int: 19 | """ 20 | Get the current maximum limit of SM count for all GEMM kernels to use. 21 | If the count is never specified, the function will return the number of device SMs. 22 | 23 | Returns: 24 | Current maximum limit of SM count for all GEMM kernels to use. 25 | """ 26 | global _num_sms 27 | if _num_sms is None: 28 | _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count 29 | return _num_sms 30 | 31 | 32 | def ceil_div(x: int, y: int) -> int: 33 | """ 34 | Perform ceiling division of two integers. 35 | 36 | Args: 37 | x: the dividend. 38 | y: the divisor. 39 | 40 | Returns: 41 | The result of the ceiling division. 42 | """ 43 | return (x + y - 1) // y 44 | 45 | 46 | def get_m_alignment_for_contiguous_layout(): 47 | """ 48 | When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. 49 | Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well 50 | with GEMM block shape. 51 | 52 | Returns: 53 | Group-level alignment requirement for grouped contiguous layout, which is always 128. 54 | """ 55 | return 128 56 | 57 | 58 | def get_tma_aligned_size(x: int, element_size: int) -> int: 59 | """ 60 | Global memory address of TMA must be 16-byte aligned. 61 | Since we use column-major layout for the LHS scaling tensor, 62 | the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. 63 | 64 | Arguments: 65 | x: original M-axis shape of the LHS scaling tensor. 66 | element_size: element size of the LHS scaling tensor. 67 | 68 | Returns: 69 | M-axis shape of the LHS scaling tensor after padding. 70 | """ 71 | tma_alignment_bytes = 16 72 | assert tma_alignment_bytes % element_size == 0 73 | alignment = tma_alignment_bytes // element_size 74 | return ceil_div(x, alignment) * alignment 75 | 76 | 77 | def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: 78 | """ 79 | Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. 80 | If the input tensor is already column-major layout and 16-byte aligned along the M axis 81 | (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. 82 | 83 | Arguments: 84 | x: usually the LHS scaling tensor in GEMM. 85 | 86 | Returns: 87 | The LHS scaling tensor of TMA-aligned transposed format. 
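
    Example (an illustrative sketch; shapes are arbitrary):
        lhs_scales = torch.randn(1024, 56, dtype=torch.float32, device='cuda')   # [m, ⌈k / 128⌉]
        lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)   # now in the column-major, TMA-aligned layout the kernels expect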
88 | """ 89 | # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA 90 | assert x.dim() in (2, 3) 91 | remove_dim = False 92 | m, n = x.shape[-2], x.shape[-1] 93 | aligned_m = get_tma_aligned_size(m, x.element_size()) 94 | if x.dim() == 2: 95 | if x.stride(0) == 1 and x.stride(1) == aligned_m: 96 | return x 97 | x, remove_dim = x.unsqueeze(0), True 98 | 99 | b = x.shape[0] 100 | 101 | # The last kernel gives a column-major TMA aligned layout 102 | if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: 103 | return x.squeeze(0) if remove_dim else x 104 | 105 | # Normal layout requires transposing 106 | aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) 107 | aligned_x[:, :m, :] = x 108 | aligned_x = aligned_x[:, :m, :] 109 | return aligned_x.squeeze(0) if remove_dim else aligned_x 110 | -------------------------------------------------------------------------------- /deep_gemm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def bench(fn, num_warmups: int = 5, num_tests: int = 10, 9 | high_precision: bool = False): 10 | # Flush L2 cache with 256 MB data 11 | torch.cuda.synchronize() 12 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') 13 | cache.zero_() 14 | 15 | # Warmup 16 | for _ in range(num_warmups): 17 | fn() 18 | 19 | # Add a large kernel to eliminate the CPU launch overhead 20 | if high_precision: 21 | x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 22 | y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 23 | x @ y 24 | 25 | # Testing 26 | start_event = torch.cuda.Event(enable_timing=True) 27 | end_event = torch.cuda.Event(enable_timing=True) 28 | start_event.record() 29 | for i in range(num_tests): 30 | fn() 31 | end_event.record() 32 | torch.cuda.synchronize() 33 | 34 | return start_event.elapsed_time(end_event) / num_tests 35 | 36 | 37 | class empty_suppress: 38 | def __enter__(self): 39 | return self 40 | 41 | def __exit__(self, *_): 42 | pass 43 | 44 | 45 | class suppress_stdout_stderr: 46 | def __enter__(self): 47 | self.outnull_file = open(os.devnull, 'w') 48 | self.errnull_file = open(os.devnull, 'w') 49 | 50 | self.old_stdout_fileno_undup = sys.stdout.fileno() 51 | self.old_stderr_fileno_undup = sys.stderr.fileno() 52 | 53 | self.old_stdout_fileno = os.dup(sys.stdout.fileno()) 54 | self.old_stderr_fileno = os.dup(sys.stderr.fileno()) 55 | 56 | self.old_stdout = sys.stdout 57 | self.old_stderr = sys.stderr 58 | 59 | os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) 60 | os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) 61 | 62 | sys.stdout = self.outnull_file 63 | sys.stderr = self.errnull_file 64 | return self 65 | 66 | def __exit__(self, *_): 67 | sys.stdout = self.old_stdout 68 | sys.stderr = self.old_stderr 69 | 70 | os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) 71 | os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 72 | 73 | os.close(self.old_stdout_fileno) 74 | os.close(self.old_stderr_fileno) 75 | 76 | self.outnull_file.close() 77 | self.errnull_file.close() 78 | 79 | 80 | def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, 81 | trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): 82 | # Conflict with Nsight Systems 83 | using_nsys = 
int(os.environ.get('DG_NSYS_PROFILING', 0)) 84 | 85 | # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle 86 | flush_l2_size = int(8e9 // 4) 87 | 88 | # For some auto-tuning kernels with prints 89 | fn() 90 | 91 | # Profile 92 | suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress 93 | with suppress(): 94 | schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None 95 | profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() 96 | with profiler: 97 | for i in range(2): 98 | # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead 99 | if barrier_comm_profiling: 100 | lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 101 | rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 102 | lhs @ rhs 103 | dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) 104 | for _ in range(num_tests): 105 | if flush_l2: 106 | torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() 107 | fn() 108 | 109 | if not using_nsys: 110 | profiler.step() 111 | 112 | # Return 1 if using Nsight Systems 113 | if using_nsys: 114 | return 1 115 | 116 | # Parse the profiling table 117 | assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) 118 | is_tupled = isinstance(kernel_names, tuple) 119 | prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') 120 | kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names 121 | assert all([isinstance(name, str) for name in kernel_names]) 122 | for name in kernel_names: 123 | assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' 124 | 125 | # Save chrome traces 126 | if trace_path is not None: 127 | profiler.export_chrome_trace(trace_path) 128 | 129 | # Return average kernel times 130 | units = {'ms': 1e3, 'us': 1e6} 131 | kernel_times = [] 132 | for name in kernel_names: 133 | for line in prof_lines: 134 | if name in line: 135 | time_str = line.split()[-2] 136 | for unit, scale in units.items(): 137 | if unit in time_str: 138 | kernel_times.append(float(time_str.replace(unit, '')) / scale) 139 | break 140 | break 141 | return tuple(kernel_times) if is_tupled else kernel_times[0] 142 | 143 | 144 | def calc_diff(x, y): 145 | x, y = x.double(), y.double() 146 | denominator = (x * x + y * y).sum() 147 | sim = 2 * (x * y).sum() / denominator 148 | return 1 - sim 149 | 150 | 151 | def count_bytes(tensors): 152 | total = 0 153 | for t in tensors: 154 | if isinstance(t, tuple): 155 | total += count_bytes(t) 156 | else: 157 | total += t.numel() * t.element_size() 158 | return total 159 | -------------------------------------------------------------------------------- /figures/design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgl-project/DeepGEMM/d75b218b7b8f4a5dd5406ac87905039ead3ae42f/figures/design.png -------------------------------------------------------------------------------- /indexing/main.cu: -------------------------------------------------------------------------------- 1 | #include "deep_gemm/fp8_gemm.cuh" 2 | 3 | using namespace deep_gemm; 4 | 5 | int main() { 6 | int m = 128; 7 | constexpr int N = 4096; 8 | constexpr int K = 7168; 9 | 10 | 
constexpr int BLOCK_M = 128; 11 | constexpr int BLOCK_N = 128; 12 | constexpr int BLOCK_K = 128; 13 | constexpr int BLOCK_N_PADDING = 0; 14 | constexpr int kSwizzleDMode = 0; 15 | constexpr int kNumGroups = 1; 16 | constexpr int kNumStages = 5; 17 | constexpr int kNumTMAMulticast = 1; 18 | constexpr bool kIsTMAMulticastOnA = false; 19 | 20 | using gemm_t = Gemm; 21 | auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m); 22 | auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0)); 23 | auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast(0), m); 24 | auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast(0), m); 25 | gemm_t::run(nullptr, nullptr, nullptr, 26 | m, 27 | tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, 28 | nullptr, 132, 0); 29 | return 0; 30 | } 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | import shutil 4 | import subprocess 5 | from setuptools.command.build_py import build_py 6 | from setuptools.command.develop import develop 7 | 8 | current_dir = os.path.dirname(os.path.realpath(__file__)) 9 | jit_include_dirs = ('deep_gemm/include/deep_gemm', ) 10 | third_party_include_dirs = ( 11 | 'third-party/cutlass/include/cute', 12 | 'third-party/cutlass/include/cutlass', 13 | ) 14 | 15 | 16 | class PostDevelopCommand(develop): 17 | def run(self): 18 | develop.run(self) 19 | self.make_jit_include_symlinks() 20 | 21 | @staticmethod 22 | def make_jit_include_symlinks(): 23 | # Make symbolic links of third-party include directories 24 | for d in third_party_include_dirs: 25 | dirname = d.split('/')[-1] 26 | src_dir = f'{current_dir}/{d}' 27 | dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' 28 | assert os.path.exists(src_dir) 29 | if os.path.exists(dst_dir): 30 | assert os.path.islink(dst_dir) 31 | os.unlink(dst_dir) 32 | os.symlink(src_dir, dst_dir, target_is_directory=True) 33 | 34 | 35 | class CustomBuildPy(build_py): 36 | def run(self): 37 | # First, prepare the include directories 38 | self.prepare_includes() 39 | 40 | # Then run the regular build 41 | build_py.run(self) 42 | 43 | def prepare_includes(self): 44 | # Create temporary build directory instead of modifying package directory 45 | build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') 46 | os.makedirs(build_include_dir, exist_ok=True) 47 | 48 | # Copy third-party includes to the build directory 49 | for d in third_party_include_dirs: 50 | dirname = d.split('/')[-1] 51 | src_dir = os.path.join(current_dir, d) 52 | dst_dir = os.path.join(build_include_dir, dirname) 53 | 54 | # Remove existing directory if it exists 55 | if os.path.exists(dst_dir): 56 | shutil.rmtree(dst_dir) 57 | 58 | # Copy the directory 59 | shutil.copytree(src_dir, dst_dir) 60 | 61 | 62 | if __name__ == '__main__': 63 | # noinspection PyBroadException 64 | try: 65 | cmd = ['git', 'rev-parse', '--short', 'HEAD'] 66 | revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() 67 | except: 68 | revision = '' 69 | 70 | setuptools.setup( 71 | name='deep_gemm', 72 | version='1.0.0' + revision, 73 | packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], 74 | package_data={ 75 | 'deep_gemm': [ 76 | 'include/deep_gemm/*', 77 | 'include/cute/**/*', 78 | 'include/cutlass/**/*', 79 | ] 80 | }, 81 | cmdclass={ 82 | 'develop': PostDevelopCommand, 83 | 
'build_py': CustomBuildPy, 84 | }, 85 | ) 86 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | # PyTorch has its own NVRTC, which may have a lower version than the system 2 | # So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch 3 | import cuda.bindings.nvrtc as nvrtc 4 | print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') 5 | 6 | import random 7 | import torch 8 | from typing import Tuple 9 | 10 | import deep_gemm 11 | from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor 12 | from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout 13 | 14 | 15 | def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 16 | assert x.dim() == 2 and x.size(1) % 128 == 0 17 | m, n = x.shape 18 | x_view = x.view(m, -1, 128) 19 | x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) 20 | return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) 21 | 22 | 23 | def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 24 | assert x.dim() == 2 25 | m, n = x.shape 26 | x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) 27 | x_padded[:m, :n] = x 28 | x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) 29 | x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) 30 | x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) 31 | return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) 32 | 33 | 34 | def construct(m: int, k: int, n: int) -> \ 35 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 36 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 37 | y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) 38 | out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) 39 | ref_out = x @ y.t() 40 | 41 | x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) 42 | # Transpose earlier so that the testing will not trigger transposing kernels 43 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 44 | return x_fp8, y_fp8, out, ref_out 45 | 46 | 47 | def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ 48 | Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: 49 | m = 0 50 | m_aligned = get_m_alignment_for_contiguous_layout() 51 | group_m_list = [] 52 | for i in range(num_groups): 53 | group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned) 54 | m += group_m 55 | group_m_list.append(group_m) 56 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 57 | y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) 58 | m_indices = torch.empty(m, device='cuda', dtype=torch.int32) 59 | out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) 60 | ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) 61 | 62 | start = 0 63 | for i, group_m in enumerate(group_m_list): 64 | end = start + group_m 65 | m_indices[start:end] = i 66 | ref_out[start:end] = x[start:end] @ y[i].t() 67 | start = end 68 | 69 | assert m % 4 == 0, f'TMA alignment error: {m}' 70 
| x_fp8 = per_token_cast_to_fp8(x) 71 | y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) 72 | for i in range(num_groups): 73 | y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) 74 | 75 | return m, x_fp8, y_fp8, m_indices, out, ref_out 76 | 77 | 78 | def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ 79 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 80 | x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) 81 | y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) 82 | out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) 83 | ref_out = torch.einsum('gmk,gnk->gmn', x, y) 84 | 85 | assert m % 4 == 0, f'TMA alignment error: {m}' 86 | x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) 87 | y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) 88 | for i in range(num_groups): 89 | x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) 90 | y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) 91 | 92 | # Transpose earlier so that the testing will not trigger transposing kernels 93 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 94 | return x_fp8, y_fp8, out, ref_out 95 | 96 | 97 | def test_gemm() -> None: 98 | print('Testing GEMM:') 99 | for m in (64, 128, 4096): 100 | for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: 101 | x_fp8, y_fp8, out, ref_out = construct(m, k, n) 102 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 103 | diff = calc_diff(out, ref_out) 104 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' 105 | 106 | # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) 107 | x_fp8, y_fp8, out, ref_out = construct(m, k, n) 108 | 109 | # noinspection PyShadowingNames 110 | def test_func(): 111 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 112 | 113 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 114 | print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' 115 | f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' 116 | f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') 117 | print() 118 | 119 | 120 | def test_m_grouped_gemm_contiguous() -> None: 121 | print('Testing grouped contiguous GEMM:') 122 | 123 | for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): 124 | # TODO: make a stronger test 125 | m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) 126 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 127 | diff = calc_diff(out, ref_out) 128 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' 129 | 130 | # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) 131 | m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) 132 | 133 | # noinspection PyShadowingNames 134 | def test_func(): 135 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 136 | 137 | t = bench_kineto(test_func, 'fp8_gemm', 
suppress_kineto_output=True) 138 | print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 139 | f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' 140 | f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') 141 | print() 142 | 143 | 144 | def test_m_grouped_gemm_masked() -> None: 145 | print('Testing grouped masked GEMM:') 146 | 147 | for num_groups, m in ((1, 1024), (2, 512), (4, 256)): 148 | for k, n in ((7168, 4096), (2048, 7168), ): 149 | # Test correctness 150 | masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) 151 | for i in range(10): 152 | x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) 153 | masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) 154 | for j in range(num_groups): 155 | masked_m[j] = random.choice(masked_m_candidates) 156 | expected_m = min(int(masked_m.float().mean()) + 1, m) 157 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) 158 | for j in range(num_groups): 159 | diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) 160 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' 161 | 162 | # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) 163 | x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) 164 | masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m 165 | 166 | # noinspection PyShadowingNames 167 | def test_func(): 168 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m) 169 | 170 | # Test performance with fixed shapes 171 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 172 | print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 173 | f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' 174 | f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') 175 | print() 176 | 177 | 178 | if __name__ == '__main__': 179 | torch.backends.cuda.matmul.allow_tf32 = True 180 | torch.backends.cudnn.allow_tf32 = True 181 | torch.manual_seed(0) 182 | random.seed(0) 183 | 184 | print('Library path:') 185 | print(f' > {deep_gemm.__path__}\n') 186 | 187 | test_gemm() 188 | test_m_grouped_gemm_contiguous() 189 | test_m_grouped_gemm_masked() 190 | -------------------------------------------------------------------------------- /tests/test_jit.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | import torch 4 | import cuda.bindings.driver as cbd 5 | 6 | from deep_gemm import jit 7 | 8 | # Essential debugging staffs 9 | os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') 10 | os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') 11 | 12 | 13 | class VectorAddRuntime(jit.Runtime): 14 | def __init__(self, path: str) -> None: 15 | super().__init__(path, [ 16 | 'A', 17 | 'B', 18 | 'C', 19 | 'STREAM', 20 | ]) 21 | 22 | @staticmethod 23 | def generate(**kwargs) -> str: 24 | return f""" 25 | #ifdef __CUDACC_RTC__ 26 | #include 27 | #else 28 | #include 29 | #endif 30 | 31 | #include 32 | #include 33 | 34 | template 35 | __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ 36 | uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; 37 | if (i < n) {{ 38 | c[i] = a[i] + b[i]; 39 | }} 40 | }} 41 | 42 | static void 
__instantiate_kernel() {{ 43 | auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>); 44 | }} 45 | """ 46 | 47 | # noinspection PyShadowingNames,PyMethodOverriding 48 | @staticmethod 49 | def launch(kernel: cbd.CUkernel, 50 | a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, 51 | stream: cbd.CUstream) -> cbd.CUresult: 52 | assert a.shape == b.shape == c.shape 53 | assert a.device == b.device == c.device 54 | assert a.dim() == 1 55 | 56 | config = cbd.CUlaunchConfig() 57 | config.gridDimX = (a.numel() + 127) // 128 58 | config.gridDimY = 1 59 | config.gridDimZ = 1 60 | config.blockDimX = 128 61 | config.blockDimY = 1 62 | config.blockDimZ = 1 63 | config.hStream = stream 64 | 65 | arg_values = ( 66 | a.data_ptr(), 67 | b.data_ptr(), 68 | c.data_ptr(), 69 | a.numel(), 70 | ) 71 | arg_types = ( 72 | ctypes.c_void_p, 73 | ctypes.c_void_p, 74 | ctypes.c_void_p, 75 | ctypes.c_uint32, 76 | ) 77 | 78 | return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0] 79 | 80 | 81 | if __name__ == '__main__': 82 | print('Generated code:') 83 | code = VectorAddRuntime.generate(T='float') 84 | print(code) 85 | print() 86 | 87 | for compiler_name in ('NVCC', 'NVRTC'): 88 | # Get compiler 89 | compiler_cls = getattr(jit, f'{compiler_name}Compiler') 90 | print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}') 91 | 92 | # Build 93 | print('Building ...') 94 | func = compiler_cls.build('test_func', code, VectorAddRuntime) 95 | 96 | # Run and check 97 | a = torch.randn((1024, ), dtype=torch.float32, device='cuda') 98 | b = torch.randn((1024, ), dtype=torch.float32, device='cuda') 99 | c = torch.empty_like(a) 100 | ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) 101 | assert ret == cbd.CUresult.CUDA_SUCCESS, ret 102 | torch.testing.assert_close(c, a + b) 103 | print(f'JIT test for {compiler_name} passed\n') 104 | --------------------------------------------------------------------------------
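The test files above double as usage documentation. As a complement, here is a minimal, self-contained sketch (not taken from the repository) of a plain FP8 GEMM call. It assumes an SM90 (Hopper) GPU, an installed `deep_gemm` package, and matrix shapes that are multiples of 128, so the per-block quantizer below skips the padding that `tests/test_core.py` performs; both quantization helpers are simplified copies of the ones in that test, and all variable names are illustrative only.

import torch
import deep_gemm
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor


def per_token_cast_to_fp8(x: torch.Tensor):
    # LHS quantization: one FP32 scale per token per 128-wide block along K
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0)


def per_block_cast_to_fp8(x: torch.Tensor):
    # RHS quantization: one FP32 scale per 128x128 block (shapes assumed to be multiples of 128)
    m, n = x.shape
    x_view = x.view(m // 128, 128, n // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    return (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m // 128, n // 128)


m, k, n = 128, 7168, 4096
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)   # LHS activations
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)   # RHS weights (NT layout: both operands row-major)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

x_fp8, x_scales = per_token_cast_to_fp8(x)
y_fp8, y_scales = per_block_cast_to_fp8(y)
# The LHS scaling tensor must be column-major and TMA-aligned along M
x_scales = get_col_major_tma_aligned_tensor(x_scales)

deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scales), (y_fp8, y_scales), out)
print(f'Relative difference vs. the BF16 reference: {float(calc_diff(out, x @ y.t())):.6f}')

The same call pattern appears in test_gemm() above; the grouped variants (m_grouped_gemm_fp8_fp8_bf16_nt_contiguous and m_grouped_gemm_fp8_fp8_bf16_nt_masked) follow the same per-token / per-block quantization conventions, as shown in the grouped test constructors.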