├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── deep_gemm ├── __init__.py ├── include │ └── deep_gemm │ │ ├── fp8_gemm.cuh │ │ ├── fp8_wgrad_gemm.cuh │ │ ├── mma_utils.cuh │ │ ├── nvrtc_std.cuh │ │ ├── scheduler.cuh │ │ ├── tma_utils.cuh │ │ └── utils.cuh ├── jit │ ├── __init__.py │ ├── compiler.py │ ├── interleave_ffma.py │ └── runtime.py ├── jit_kernels │ ├── __init__.py │ ├── gemm.py │ ├── m_grouped_gemm.py │ ├── runtime.py │ ├── utils.py │ └── wgrad_gemm.py └── utils.py ├── figures └── design.png ├── indexing └── main.cu ├── setup.py └── tests ├── test_core.py └── test_jit.py /.gitignore: -------------------------------------------------------------------------------- 1 | cmake-build-* 2 | .idea 3 | .DS_Store 4 | build 5 | dist 6 | *.egg-info 7 | *.pyc 8 | 9 | # Third-party links created by `setup.py develop` 10 | deep_gemm/include/cute 11 | deep_gemm/include/cutlass 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/cutlass"] 2 | path = third-party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT 2 | # TODO: add CUDA utils' library via CMake 3 | cmake_minimum_required(VERSION 3.10) 4 | project(deep_gemm LANGUAGES CXX CUDA) 5 | 6 | set(CMAKE_CXX_STANDARD 20) 7 | set(CMAKE_CUDA_STANDARD 20) 8 | set(CMAKE_VERBOSE_MAKEFILE ON) 9 | 10 | find_package(CUDAToolkit REQUIRED) 11 | find_package(pybind11 REQUIRED) 12 | 13 | file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }") 14 | execute_process( 15 | COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu 16 | RESULT_VARIABLE NVCC_RESULT 17 | OUTPUT_VARIABLE NVCC_OUTPUT 18 | ERROR_VARIABLE NVCC_ERROR_OUTPUT 19 | WORKING_DIRECTORY ${CMAKE_BINARY_DIR} 20 | ) 21 | 22 | if (NVCC_RESULT EQUAL "0") 23 | set(NVCC_SUPPORTS_SM90 TRUE) 24 | message(STATUS "NVCC supports SM90") 25 | else() 26 | message(STATUS "NVCC does not support SM90") 27 | endif() 28 | 29 | if (NVCC_SUPPORTS_SM90) 30 | set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE) 31 | list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a") 32 | endif() 33 | find_package(Torch REQUIRED) 34 | 35 | include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include) 36 | include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) 37 | link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib) 38 | 39 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") 40 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") 41 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG") 42 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10") 43 | 44 | cuda_add_library(example_gemm STATIC indexing/main.cu) 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGEMM 2 | 3 | DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. 4 | 5 | Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. 6 | 7 | Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. 8 | 9 | ## News 10 | 11 | - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. 12 | - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). 13 | - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. 14 | 15 | ## Roadmap 16 | 17 | - [x] More correctness tests for grouped-contiguous layout 18 | - [x] Shared memory swizzling for output 19 | - [ ] Larger block size on N (up to 256) 20 | - [x] MoE scheduler with TMA multicast compatibility 21 | - [x] Fix TMA multicast compatibility for indivisible shapes 22 | - [x] Skip useless computation on M 23 | - [x] NVRTC as a faster compiler 24 | - [ ] Stolen JIT cache 25 | - [ ] Sanitizer for testing 26 | - [x] Weight gradient kernels for dense models 27 | - [x] Weight gradient kernels for MoE models 28 | - [ ] Better `get_best_configs` modeling 29 | - [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang)) 30 | - [ ] CUDA PDL support 31 | - [ ] More scaling granularity support via templates 32 | - [ ] Larger TMA multicast size for some shapes 33 | - [x] MMA template refactor with CUTLASS 34 | - [ ] Optimizations for power efficiency 35 | - [x] Remove shape limitations on N and K 36 | - [ ] BF16 kernels 37 | - [ ] Split/stream-k optimizations 38 | 39 | ## Quick start 40 | 41 | ### Requirements 42 | 43 | - Hopper architecture GPUs, `sm_90a` must be supported 44 | - Python 3.8 or above 45 | - CUDA 12.3 or above 46 | - **But we highly recommend 12.8 or above for the best performance** 47 | - PyTorch 2.1 or above 48 | - CUTLASS 3.6 or above (could be cloned by Git submodule) 49 | 50 | ### Development 51 | 52 | ```bash 53 | # Submodule must be cloned 54 | git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git 55 | 56 | # Make symbolic links for third-party (CUTLASS and CuTe) include directories 57 | python setup.py develop 58 | 59 | # Test JIT compilation 60 | python tests/test_jit.py 61 | 62 | # Test all GEMM implements (normal, contiguous-grouped and masked-grouped) 63 | python tests/test_core.py 64 | ``` 65 | 66 | ### Installation 67 | 68 | ```bash 69 | python setup.py install 70 | ``` 71 | 72 | Then, import `deep_gemm` in your Python project, and enjoy! 73 | 74 | ## Interfaces 75 | 76 | #### Notices 77 | 78 | This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. 79 | 80 | #### Normal dense GEMMs (non-grouped) 81 | 82 | To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation. 83 | 84 | #### Grouped GEMMs (contiguous layout) 85 | 86 | Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. 87 | 88 | For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`). 89 | 90 | For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation. 91 | 92 | #### Grouped GEMMs (masked layout) 93 | 94 | During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. 95 | 96 | Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. 97 | 98 | #### Utilities 99 | 100 | The library provides some utility functions besides the above kernels: 101 | 102 | - `deep_gemm.set_num_sms`: set the maximum SM count to use 103 | - `deep_gemm.get_num_sms`: get the current SM maximum count 104 | - `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout 105 | - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size 106 | - `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor 107 | 108 | The library also provides some environment variables, which may be useful: 109 | 110 | - General 111 | - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default 112 | - JIT cache related 113 | - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default 114 | - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default 115 | - NVCC/NVRTC selections 116 | - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default 117 | - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default 118 | - Compiler options 119 | - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default 120 | - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default 121 | - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default 122 | - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default 123 | - Post optimization 124 | - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default 125 | - Heuristic selection 126 | - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default 127 | - Testing 128 | - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default 129 | 130 | For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. 131 | 132 | ## Optimizations 133 | 134 | We indicate the techniques excluded from CUTLASS with 🐳. 135 | 136 | #### Persistent warp-specialization 137 | 138 | Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below: 139 | 140 | ![design](figures/design.png) 141 | 142 | #### Hopper TMA features 143 | 144 | The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for: 145 | 146 | - TMA load for LHS, LHS scaling factors, and RHS matrices 147 | - TMA store for the output matrix 148 | - TMA multicast (automatically decide LHS or RHS to broadcast) 149 | - TMA descriptor prefetching 150 | 151 | #### Common detail optimizations 152 | 153 | - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction 154 | - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups 155 | - Less bank conflicts via 3D TMA or swizzling 156 | - Larger block sizes (up to 256x128 🐳) 157 | - Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 158 | 159 | #### A unified and optimized block scheduler 160 | 161 | - [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels 162 | - [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse 163 | 164 | #### Fully JIT design 🐳 165 | 166 | DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages: 167 | 168 | - GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants 169 | - Saving registers 170 | - Compilers may do more optimizations 171 | - Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size 172 | - But without auto-tuning, the optimal one is deterministically selected 173 | - Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities 174 | - Very important for small shapes 175 | - Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details 176 | 177 | Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler. 178 | 179 | #### Unaligned block sizes 🐳 180 | 181 | For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains. 182 | 183 | #### FFMA SASS interleaving 🐳 184 | 185 | We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern. 186 | After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/96a9f72baf00f40b9b299653fcef8d3e2b4a3d49/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work). 187 | 188 | To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions. 189 | 190 | ## Acknowledgement 191 | 192 | DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! 193 | 194 | ## License 195 | 196 | This code repository is released under [the MIT License](LICENSE). 197 | 198 | ## Citation 199 | 200 | ```bibtex 201 | @misc{deepgemm2025, 202 | title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling}, 203 | author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu}, 204 | year={2025}, 205 | publisher = {GitHub}, 206 | howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}}, 207 | } 208 | ``` 209 | -------------------------------------------------------------------------------- /deep_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import jit 4 | from .jit_kernels import ( 5 | gemm_fp8_fp8_bf16_nt, 6 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 7 | m_grouped_gemm_fp8_fp8_bf16_nt_masked, 8 | wgrad_gemm_fp8_fp8_fp32_nt, 9 | k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, 10 | ceil_div, 11 | set_num_sms, get_num_sms, 12 | get_col_major_tma_aligned_tensor, 13 | get_m_alignment_for_contiguous_layout 14 | ) 15 | from .utils import bench, bench_kineto, calc_diff 16 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/fp8_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #pragma clang diagnostic push 4 | #pragma clang diagnostic ignored "-Wunknown-attributes" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "mma_utils.cuh" 14 | #include "scheduler.cuh" 15 | #include "tma_utils.cuh" 16 | #include "utils.cuh" 17 | 18 | namespace deep_gemm { 19 | 20 | template 21 | __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { 22 | if (num_former_iters == kNumFormerIters) { 23 | inner_launch_k_iterations(func, cute::Int{}); 24 | return; 25 | } 26 | 27 | if constexpr (kNumFormerIters + kGap <= kEnd) 28 | outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); 29 | } 30 | 31 | template 39 | __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) 40 | fp8_gemm_kernel(float* scales_b, int* grouped_layout, 41 | uint32_t shape_m, 42 | const __grid_constant__ CUtensorMap tensor_map_a, 43 | const __grid_constant__ CUtensorMap tensor_map_b, 44 | const __grid_constant__ CUtensorMap tensor_map_scales_a, 45 | const __grid_constant__ CUtensorMap tensor_map_d) { 46 | #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) 47 | // Scaling checks 48 | DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); 49 | DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); 50 | 51 | // Types 52 | using WGMMA = typename FP8MMASelector::type; 53 | using Barrier = cutlass::arch::ClusterTransactionBarrier; 54 | DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); 55 | 56 | // Shared memory 57 | static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); 58 | static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16); 59 | static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); 60 | static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); 61 | static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); 62 | static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); 63 | static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); 64 | 65 | // Configs 66 | constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; 67 | constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); 68 | constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; 69 | constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); 70 | const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); 71 | const uint32_t lane_idx = get_lane_id(); 72 | 73 | // Prefetch TMA descriptors at the very beginning 74 | if (threadIdx.x == kNumMathThreads) { 75 | // NOTES: `reinterpret_cast` must be here, or NVRTC will fail 76 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); 77 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); 78 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); 79 | 80 | // `tensor_map_d` is only used in swizzling mode 81 | // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode 82 | if constexpr (kSwizzleDMode > 0) 83 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); 84 | } 85 | __syncwarp(); 86 | 87 | // Align to 1024 bytes for swizzle-128B 88 | extern __shared__ __align__(1024) uint8_t smem_buffer[]; 89 | DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); 90 | 91 | // Data on shared memory 92 | auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); 93 | __nv_fp8_e4m3* smem_a[kNumStages]; 94 | __nv_fp8_e4m3* smem_b[kNumStages]; 95 | float* smem_scales_a[kNumStages]; 96 | float* smem_scales_b; 97 | 98 | // TMA Barrier for both divisible and non-divisible cases 99 | Barrier* full_barriers[kNumStages]; 100 | Barrier* empty_barriers[kNumStages]; 101 | 102 | // Fill shared memory pointers 103 | #pragma unroll 104 | for (uint32_t i = 0; i < kNumStages; ++ i) { 105 | smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); 106 | smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); 107 | smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); 108 | } 109 | smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); 110 | 111 | // Fill barriers 112 | auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); 113 | #pragma unroll 114 | for (uint32_t i = 0; i < kNumStages; ++ i) { 115 | full_barriers[i] = barrier_start_ptr + i; 116 | empty_barriers[i] = barrier_start_ptr + kNumStages + i; 117 | } 118 | 119 | // Initialize barriers 120 | DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); 121 | if (threadIdx.x == kNumMathThreads) { 122 | // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, 123 | // even with TMA multicast disabled, we want to make the behavior aligned 124 | #pragma unroll 125 | for (uint32_t i = 0; i < kNumStages; ++ i) { 126 | full_barriers[i]->init(1); 127 | empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); 128 | } 129 | 130 | // Make initialized barrier visible in async proxy 131 | cutlass::arch::fence_view_async_shared(); 132 | (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); 133 | } 134 | 135 | // Synchronize all threads to make barrier visible in normal memory model 136 | (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); 137 | 138 | // For pipeline unrolling 139 | struct DivisibleK {}; 140 | struct NotDivisibleK {}; 141 | struct SkipComputation {}; 142 | struct NotSkipComputation {}; 143 | auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) { 144 | constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; 145 | constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; 146 | constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; 147 | 148 | // NOTES: for too-many branches (> 5), we disable this optimization 149 | // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value 150 | outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { 151 | if (skip_computation) { 152 | for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) 153 | func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); 154 | } else if (SHAPE_K % kFullKOfAllStages == 0) { 155 | for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) 156 | func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); 157 | } else { 158 | for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) 159 | func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); 160 | func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); 161 | } 162 | }, func, kShouldOptimize ? num_former_iters : 0); 163 | }; 164 | 165 | // Register reconfigurations 166 | constexpr uint32_t kNumTMARegisters = 40; 167 | constexpr uint32_t kNumMathRegisters = 232; 168 | 169 | // Block scheduler 170 | uint32_t m_block_idx, n_block_idx; 171 | auto scheduler = Scheduler(shape_m, grouped_layout); 172 | 173 | if (threadIdx.x >= kNumMathThreads) { 174 | // TMA warp-group for loading data 175 | cutlass::arch::warpgroup_reg_dealloc(); 176 | 177 | // NOTES: only one thread (or warp) will be used 178 | if (threadIdx.x == kNumMathThreads) { 179 | // Persistently schedule over blocks 180 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 181 | launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { 182 | constexpr bool kHasDivisibleStages = std::is_same_v; 183 | constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; 184 | 185 | // Assign TMA multicast number into A and B 186 | // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. 187 | const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); 188 | const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 189 | const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 190 | DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); 191 | 192 | // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all 193 | // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant 194 | #pragma unroll 195 | for (uint32_t s = 0; s < kNumInnerStages; ++ s) { 196 | // Wait consumer release 197 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); 198 | 199 | // Issue TMA A 200 | auto& full_barrier = *full_barriers[s]; 201 | uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; 202 | tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), 203 | smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), 204 | num_tma_multicast_a); 205 | tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), 206 | smem_scales_a[s], m_block_idx * BLOCK_M, 207 | scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K), 208 | num_tma_multicast_a); 209 | 210 | // Issue TMA B 211 | tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), 212 | smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), 213 | num_tma_multicast_b); 214 | full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); 215 | } 216 | 217 | // Wait unaligned cases 218 | #pragma unroll 219 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 220 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); 221 | full_barriers[s]->arrive(); 222 | } 223 | }, false, 0); 224 | } 225 | 226 | // To safely deconstruct distributed shared barriers, we need another round of empty waits 227 | if constexpr (kNumTMAMulticast > 1) { 228 | #pragma unroll 229 | for (uint32_t s = 0; s < kNumStages; ++ s) 230 | empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); 231 | } 232 | } 233 | } else { 234 | // Math warp-groups for WGMMA 235 | cutlass::arch::warpgroup_reg_alloc(); 236 | 237 | // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers 238 | const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); 239 | const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; 240 | 241 | // Persistently schedule over blocks 242 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 243 | // Decide the number of scales B to load 244 | DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); 245 | uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; 246 | if constexpr (not kMustUseUniformedScaleB) { 247 | num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; 248 | num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; 249 | } 250 | uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); 251 | 252 | // Load B scales with math warp-groups 253 | // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks 254 | if (threadIdx.x >= 32) { 255 | auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); 256 | auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; 257 | #pragma unroll 258 | for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) 259 | st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); 260 | } 261 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 262 | 263 | // Accumulation for WGMMA or CUDA promotion 264 | constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); 265 | DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); 266 | float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; 267 | 268 | // Empty barrier arrival 269 | auto empty_barrier_arrive = [&](uint32_t s) { 270 | if constexpr (kNumTMAMulticast == 1) { 271 | lane_idx == 0 ? empty_barriers[s]->arrive() : void(); 272 | } else { 273 | auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); 274 | lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); 275 | } 276 | }; 277 | 278 | // Launch MMAs 279 | launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { 280 | constexpr bool kSkipComputation = std::is_same_v; 281 | constexpr bool kHasDivisibleStages = std::is_same_v; 282 | constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : 283 | (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K); 284 | 285 | #pragma unroll 286 | for (uint32_t s = 0; s < kNumInnerStages; ++ s) { 287 | // Read B scales 288 | float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; 289 | // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks 290 | if constexpr (not kMustUseUniformedScaleB) 291 | scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); 292 | 293 | // Wait TMA arrivals 294 | full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); 295 | 296 | // TODO: remove some useless computation for unaligned Ms 297 | #pragma unroll 298 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 299 | auto m_offset = local_idx * WAVE_BLOCK_M; 300 | 301 | // Read A scales 302 | // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results 303 | auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); 304 | auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); 305 | 306 | // Commit WGMMA instructions 307 | #pragma unroll 308 | for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) 309 | warpgroup_fence_operand(accum[i]); 310 | warpgroup_arrive(); 311 | #pragma unroll 312 | for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { 313 | auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); 314 | auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); 315 | WGMMA::wgmma(desc_a, desc_b, accum, k); 316 | } 317 | warpgroup_commit_batch(); 318 | #pragma unroll 319 | for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) 320 | warpgroup_fence_operand(accum[i]); 321 | warpgroup_wait<0>(); 322 | 323 | // Notify barrier arrival at the last warpgroup wave 324 | if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) 325 | empty_barrier_arrive(s); 326 | 327 | // Promote with scales 328 | // NOTES: making it as predicates is very important for performance, comparing to two loops 329 | float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; 330 | float scale_0_1, scale_1_1; 331 | if constexpr (not kMustUseUniformedScaleB) 332 | scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; 333 | 334 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 335 | #pragma unroll 336 | for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 337 | // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant 338 | bool predicate = kMustUseUniformedScaleB or i < num_former_iters; 339 | shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; 340 | shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; 341 | shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; 342 | shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; 343 | } 344 | } 345 | } 346 | 347 | // Wait unaligned cases 348 | #pragma unroll 349 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 350 | full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); 351 | empty_barrier_arrive(s); 352 | } 353 | }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); 354 | 355 | // TMA checks 356 | constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); 357 | constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); 358 | constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; 359 | DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); 360 | DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, 361 | "Unaligned TMA store or too many TMA store instructions"); 362 | DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); 363 | DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, 364 | "Swizzling and padding are not compatible"); 365 | 366 | // Wait last TMA store to be finished 367 | if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) 368 | cute::tma_store_wait<0>(); 369 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 370 | 371 | // Write back to shared memory using STSM and issue TMA stores 372 | DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); 373 | #pragma unroll 374 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 375 | auto m_offset = local_idx * WAVE_BLOCK_M; 376 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 377 | #pragma unroll 378 | for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 379 | // Swizzle or padding into the correct address 380 | uint8_t* smem_ptr = nullptr; 381 | if constexpr (kSwizzleDMode > 0) { 382 | // Calculate the swizzling atom offset and in-atom offset 383 | constexpr uint32_t kNumBankGroupBytes = 16; 384 | auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); 385 | 386 | // Calculate the index of the bank group to be written in the atom 387 | auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); 388 | 389 | // Reshape the atom in another view and swizzle 390 | // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` 391 | // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` 392 | constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; 393 | auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); 394 | auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); 395 | col ^= row % (kSwizzleDMode / 16); 396 | 397 | // Add back into the base pointer 398 | // NOTES: think twice before modifying this, as changes may affect the number of instructions 399 | smem_ptr = reinterpret_cast(smem_d) + // Base pointer 400 | warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset 401 | m_offset * kSwizzleDMode + // Wave offset 402 | atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) 403 | row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset 404 | } else { 405 | // No swizzling, just padding 406 | // NOTES: padding must be zero for BF16 output 407 | DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); 408 | smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); 409 | } 410 | 411 | // NOTES: only 16 lanes' addresses are used 412 | SM90_U32x2_STSM_N::copy( 413 | __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), 414 | __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), 415 | smem_ptr 416 | ); 417 | } 418 | } 419 | cute::tma_store_fence(); 420 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 421 | 422 | // Use TMA store to write back to global memory 423 | // TODO: compatible with FP32 output 424 | DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); 425 | if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { 426 | auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; 427 | auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; 428 | cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, 429 | n_block_idx * BLOCK_N + in_block_n_offset, 430 | scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); 431 | cute::tma_store_arrive(); 432 | } 433 | __syncwarp(); 434 | } 435 | } 436 | #else 437 | if (blockIdx.x == 0 and threadIdx.x == 0) 438 | DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); 439 | #endif 440 | } 441 | 442 | }; // namespace deep_gemm 443 | 444 | #pragma clang diagnostic pop -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #pragma clang diagnostic push 4 | #pragma clang diagnostic ignored "-Wunknown-attributes" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "mma_utils.cuh" 14 | #include "scheduler.cuh" 15 | #include "tma_utils.cuh" 16 | #include "utils.cuh" 17 | 18 | namespace deep_gemm { 19 | 20 | template 25 | __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) 26 | fp8_wgrad_gemm_kernel(uint32_t shape_k, 27 | const __grid_constant__ CUtensorMap tensor_map_a, 28 | const __grid_constant__ CUtensorMap tensor_map_b, 29 | const __grid_constant__ CUtensorMap tensor_map_scales_a, 30 | const __grid_constant__ CUtensorMap tensor_map_scales_b, 31 | const __grid_constant__ CUtensorMap tensor_map_d) { 32 | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__) 33 | // Scaling checks 34 | DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); 35 | 36 | // Types 37 | using WGMMA = typename FP8MMASelector::type; 38 | using Barrier = cutlass::arch::ClusterTransactionBarrier; 39 | DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); 40 | 41 | // Shared memory 42 | static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); 43 | static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); 44 | static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); 45 | static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); 46 | static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); 47 | static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U; 48 | 49 | // Configs 50 | constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; 51 | constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); 52 | constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; 53 | 54 | const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K); 55 | const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); 56 | const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); 57 | const uint32_t lane_idx = get_lane_id(); 58 | 59 | // Prefetch TMA descriptors at the very beginning 60 | if (threadIdx.x == kNumMathThreads) { 61 | // NOTES: `reinterpret_cast` must be here, or NVRTC will fail 62 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); 63 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); 64 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); 65 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); 66 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); 67 | } 68 | __syncwarp(); 69 | 70 | // Align to 1024 bytes for swizzle-128B 71 | extern __shared__ __align__(1024) uint8_t smem_buffer[]; 72 | DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); 73 | 74 | // Data on shared memory 75 | auto smem_d = reinterpret_cast(smem_buffer); 76 | __nv_fp8_e4m3* smem_a[kNumStages]; 77 | __nv_fp8_e4m3* smem_b[kNumStages]; 78 | float* smem_scales_a[kNumStages]; 79 | float* smem_scales_b[kNumStages]; 80 | 81 | // TMA Barrier for both divisible and non-divisible cases 82 | Barrier* full_barriers[kNumStages + 1]; 83 | Barrier* empty_barriers[kNumStages + 1]; 84 | 85 | // Fill shared memory pointers 86 | #pragma unroll 87 | for (int i = 0; i < kNumStages; ++ i) { 88 | smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); 89 | smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); 90 | smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) 91 | + i * SMEM_SCALES_A_SIZE_PER_STAGE); 92 | smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) 93 | + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE); 94 | } 95 | 96 | // Fill barriers 97 | DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); 98 | auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages 99 | * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE)); 100 | #pragma unroll 101 | for (int i = 0; i < kNumStages + 1; ++ i) { 102 | full_barriers[i] = barrier_start_ptr + i; 103 | empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i; 104 | } 105 | 106 | // Initialize barriers 107 | DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast"); 108 | if (threadIdx.x == kNumMathThreads) { 109 | // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, 110 | // even with TMA multicast disabled, we want to make the behavior aligned 111 | #pragma unroll 112 | for (int i = 0; i < kNumStages; ++ i) { 113 | full_barriers[i]->init(1); 114 | empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); 115 | } 116 | full_barriers[kNumStages]->init(1); 117 | empty_barriers[kNumStages]->init(1); 118 | 119 | // Make initialized barrier visible in async proxy 120 | cutlass::arch::fence_view_async_shared(); 121 | (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); 122 | } 123 | 124 | // Synchronize all threads to make barrier visible in normal memory model 125 | (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); 126 | 127 | // For pipeline unrolling 128 | struct DivisibleK {}; 129 | struct NotDivisibleK {}; 130 | auto launch_k_iterations = [&](const auto& func) { 131 | if constexpr (kNumLastStages == 0) { 132 | for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) 133 | func(k_iter, DivisibleK{}); 134 | } else { 135 | for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) 136 | func(k_iter, DivisibleK{}); 137 | func(num_iterations - 1, NotDivisibleK{}); 138 | } 139 | }; 140 | 141 | // Register reconfigurations 142 | constexpr int kNumTMARegisters = 40; 143 | constexpr int kNumMathRegisters = 232; 144 | 145 | // Block scheduler 146 | uint32_t m_block_idx, n_block_idx; 147 | auto scheduler = Scheduler(SHAPE_M); 148 | 149 | if (threadIdx.x >= kNumMathThreads) { 150 | // TMA warp-group for loading data 151 | cutlass::arch::warpgroup_reg_dealloc(); 152 | 153 | // NOTES: only one thread (or warp) will be used 154 | if (threadIdx.x == kNumMathThreads) { 155 | // Persistently schedule over blocks 156 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 157 | launch_k_iterations([&](int k_iter, auto type) { 158 | constexpr bool kHasDivisibleStages = std::is_same_v; 159 | constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; 160 | DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); 161 | 162 | // Assign TMA multicast number into A and B 163 | // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. 164 | const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); 165 | const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 166 | const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; 167 | DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); 168 | 169 | #pragma unroll 170 | for (uint32_t s = 0; s < kNumInnerStages; ++ s) { 171 | // Wait consumer release 172 | empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); 173 | 174 | // Issue TMA A 175 | auto& full_barrier = *full_barriers[s]; 176 | int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; 177 | tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), 178 | smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a); 179 | tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), 180 | smem_scales_a[s], m_block_idx * BLOCK_M, 181 | k_idx / BLOCK_K, num_tma_multicast_a); 182 | 183 | // Issue TMA B 184 | tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), 185 | smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b); 186 | tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), 187 | smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b); 188 | 189 | full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); 190 | } 191 | 192 | // Wait unaligned cases 193 | #pragma unroll 194 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 195 | empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); 196 | full_barriers[s]->arrive(); 197 | } 198 | }); 199 | 200 | // Issue TMA D 201 | empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1); 202 | auto& full_barrier = *full_barriers[kNumStages]; 203 | tma_copy(&tensor_map_d, reinterpret_cast(&full_barrier), 204 | smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1); 205 | full_barrier.arrive_and_expect_tx(SMEM_D_SIZE); 206 | } 207 | 208 | // To safely deconstruct distributed shared barriers, we need another round of empty waits 209 | if constexpr (kNumTMAMulticast > 1) { 210 | #pragma unroll 211 | for (uint32_t s = 0; s < kNumStages; ++ s) 212 | empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); 213 | } 214 | } 215 | } else { 216 | // Math warp-groups for WGMMA 217 | cutlass::arch::warpgroup_reg_alloc(); 218 | 219 | // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers 220 | const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); 221 | const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; 222 | const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; 223 | 224 | // Empty barrier arrival 225 | auto empty_barrier_arrive = [&](int s) { 226 | if constexpr (kNumTMAMulticast == 1) { 227 | lane_idx == 0 ? empty_barriers[s]->arrive() : void(); 228 | } else { 229 | auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); 230 | lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); 231 | } 232 | }; 233 | 234 | // Persistently schedule over blocks 235 | while (scheduler.get_next_block(m_block_idx, n_block_idx)) { 236 | // Decide the number of scales B to load 237 | DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); 238 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 239 | 240 | // Accumulation for WGMMA or CUDA promotion 241 | constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); 242 | float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; 243 | float2 scales_b[WGMMA::kNumAccum / 4]; 244 | 245 | // Launch MMAs 246 | launch_k_iterations([&](int k_iter, auto type) { 247 | constexpr bool kHasDivisibleStages = std::is_same_v; 248 | constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; 249 | DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); 250 | 251 | #pragma unroll 252 | for (int s = 0; s < kNumInnerStages; ++ s) { 253 | // Wait TMA arrivals 254 | full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); 255 | 256 | #pragma unroll 257 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 258 | auto m_offset = local_idx * WAVE_BLOCK_M; 259 | 260 | // Read A scales 261 | auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); 262 | auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); 263 | 264 | // Commit WGMMA instructions 265 | #pragma unroll 266 | for (int i = 0; i < WGMMA::kNumAccum; ++ i) 267 | warpgroup_fence_operand(accum[i]); 268 | warpgroup_arrive(); 269 | #pragma unroll 270 | for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { 271 | auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); 272 | auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); 273 | WGMMA::wgmma(desc_a, desc_b, accum, k); 274 | } 275 | warpgroup_commit_batch(); 276 | 277 | // Read B scales at the first warpgroup wave 278 | if (local_idx == 0) { 279 | #pragma unroll 280 | for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) 281 | scales_b[i] = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + col_idx * 2)); 282 | __syncwarp(); 283 | } 284 | 285 | #pragma unroll 286 | for (int i = 0; i < WGMMA::kNumAccum; ++ i) 287 | warpgroup_fence_operand(accum[i]); 288 | warpgroup_wait<0>(); 289 | 290 | // Notify barrier arrival at the last warpgroup wave 291 | if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) 292 | empty_barrier_arrive(s); 293 | 294 | // Promote with scales 295 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 296 | #pragma unroll 297 | for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 298 | const float &scale_b_0 = scales_b[i].x; 299 | const float &scale_b_1 = scales_b[i].y; 300 | shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; 301 | shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; 302 | shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; 303 | shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; 304 | } 305 | } 306 | } 307 | 308 | // Wait last TMA store to be finished 309 | if (k_iter == 0 and scheduler.current_iter > 0) { 310 | if (threadIdx.x == 0) { 311 | cute::tma_store_wait<0>(); 312 | empty_barriers[kNumStages]->arrive(); 313 | } 314 | __syncwarp(); 315 | } 316 | 317 | // Wait unaligned cases 318 | #pragma unroll 319 | for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { 320 | full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); 321 | empty_barrier_arrive(s); 322 | } 323 | }); 324 | 325 | // Wait TMA D arrivals 326 | full_barriers[kNumStages]->wait(scheduler.current_iter & 1); 327 | 328 | // Accumulate to D shared memory 329 | #pragma unroll 330 | for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { 331 | auto m_offset = local_idx * WAVE_BLOCK_M; 332 | auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; 333 | auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2); 334 | auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2); 335 | #pragma unroll 336 | for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 337 | float2 d_0 = ld_shared(smem_d_0 + i * 4); 338 | st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]}); 339 | float2 d_1 = ld_shared(smem_d_1 + i * 4); 340 | st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]}); 341 | } 342 | } 343 | 344 | cute::tma_store_fence(); 345 | cutlass::arch::NamedBarrier(kNumMathThreads).sync(); 346 | 347 | // Use TMA store to write back to global memory 348 | if (threadIdx.x == 0) { 349 | cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M); 350 | cute::tma_store_arrive(); 351 | } 352 | __syncwarp(); 353 | } 354 | } 355 | #else 356 | if (blockIdx.x == 0 and threadIdx.x == 0) 357 | DG_DEVICE_ASSERT(false && "This kernel only support sm_90a"); 358 | #endif 359 | } 360 | 361 | }; // namespace deep_gemm 362 | 363 | #pragma clang diagnostic pop 364 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/mma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef __CUDACC_RTC__ 4 | #include 5 | #endif 6 | 7 | #include 8 | #include 9 | 10 | #include "utils.cuh" 11 | 12 | namespace deep_gemm { 13 | 14 | template 15 | struct SM90_U32x2_STSM_N { 16 | __device__ __forceinline__ static void 17 | copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { 18 | const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; 19 | asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" 20 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); 21 | } 22 | }; 23 | 24 | template 25 | struct SM90_U32x4_STSM_N { 26 | __device__ __forceinline__ static void 27 | copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { 28 | const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), 29 | *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; 30 | asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" 31 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); 32 | } 33 | }; 34 | 35 | __forceinline__ __device__ void warpgroup_arrive() { 36 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 37 | } 38 | 39 | __forceinline__ __device__ void warpgroup_commit_batch() { 40 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 41 | } 42 | 43 | __forceinline__ __device__ void warpgroup_fence_operand(float& reg) { 44 | asm volatile("" : "+f"(reg) :: "memory"); 45 | } 46 | 47 | __forceinline__ __device__ uint32_t get_lane_id() { 48 | uint32_t lane_id; 49 | asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); 50 | return lane_id; 51 | } 52 | 53 | __device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { 54 | uint32_t ret; 55 | asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 56 | return ret; 57 | } 58 | 59 | __device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { 60 | int4 ret; 61 | asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); 62 | return ret; 63 | } 64 | 65 | __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { 66 | float ret; 67 | asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); 68 | return ret; 69 | } 70 | 71 | __device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) { 72 | float2 ret; 73 | asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); 74 | return ret; 75 | } 76 | 77 | __device__ __forceinline__ void st_shared(const float* ptr, float val) { 78 | asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); 79 | } 80 | 81 | __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { 82 | asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); 83 | } 84 | 85 | __device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { 86 | asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); 87 | } 88 | 89 | template 90 | __device__ void warpgroup_wait() { 91 | DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); 92 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); 93 | } 94 | 95 | union GmmaDescriptor { 96 | __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} 97 | 98 | __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} 99 | 100 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} 101 | 102 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} 103 | 104 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { 105 | desc_ = t.desc_; 106 | return *this; 107 | } 108 | 109 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { 110 | desc_ = t.desc_; 111 | return *this; 112 | } 113 | 114 | uint64_t desc_; 115 | uint32_t reg32_[2]; 116 | uint16_t reg16_[4]; 117 | 118 | struct { 119 | uint16_t start_address_: 14, : 2; 120 | uint16_t leading_byte_offset_: 14, : 2; 121 | uint16_t stride_byte_offset_: 14, : 2; 122 | uint8_t : 1, base_offset_: 3, : 4; 123 | uint8_t : 6, layout_type_: 2; 124 | } bitfield; 125 | 126 | // Decay to an `uint64_t` 127 | __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } 128 | }; 129 | 130 | template 131 | __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, 132 | int leading_byte_offset = 0, 133 | int stride_byte_offset = 1024) { 134 | GmmaDescriptor desc; 135 | auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 136 | desc.bitfield.start_address_ = uint_ptr >> 4; 137 | desc.bitfield.layout_type_ = layout_type; 138 | desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; 139 | desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; 140 | desc.bitfield.base_offset_ = 0; 141 | return desc; 142 | } 143 | 144 | template 145 | struct FP8MMA { 146 | 147 | template 148 | __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { 149 | using namespace cute::SM90::GMMA; 150 | MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); 151 | } 152 | 153 | __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 154 | call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); 155 | } 156 | 157 | static constexpr int M = 64; 158 | static constexpr int N = N_; 159 | static constexpr int K = 32; 160 | static constexpr int kNumAccum = M * N / 128; 161 | }; 162 | 163 | template 164 | struct FP8MMASelector { 165 | 166 | static constexpr auto select_mma() { 167 | using namespace cute::SM90::GMMA; 168 | if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); 169 | if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); 170 | if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); 171 | if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); 172 | if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); 173 | if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); 174 | if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); 175 | if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); 176 | if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); 177 | if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); 178 | if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); 179 | if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); 180 | if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); 181 | if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); 182 | if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); 183 | if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); 184 | if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); 185 | if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); 186 | if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); 187 | if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); 188 | } 189 | 190 | static constexpr auto select_type() { 191 | return FP8MMA(); 192 | } 193 | 194 | using type = decltype(select_type()); 195 | }; 196 | 197 | enum class Layout { 198 | RowMajor, 199 | ColMajor 200 | }; 201 | 202 | __device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { 203 | return block_m == 64 ? 1 : 2; 204 | } 205 | 206 | template 207 | __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { 208 | DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); 209 | return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; 210 | } 211 | 212 | } // namespace deep_gemm 213 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/nvrtc_std.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 3 | * All rights reserved. SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #ifdef __CUDACC_RTC__ 21 | 22 | using int8_t = signed char; 23 | using uint8_t = unsigned char; 24 | using int16_t = signed short; 25 | using uint16_t = unsigned short; 26 | using int32_t = signed int; 27 | using uint32_t = unsigned int; 28 | using int64_t = signed long long; 29 | using uint64_t = unsigned long long; 30 | using cuuint64_t = unsigned long long; 31 | 32 | #ifndef CU_TENSOR_MAP_NUM_QWORDS 33 | #define CU_TENSOR_MAP_NUM_QWORDS 16 34 | 35 | struct CUtensorMap_st { 36 | #if defined(__cplusplus) && (__cplusplus >= 201103L) 37 | alignas(64) 38 | #elif __STDC_VERSION__ >= 201112L 39 | _Alignas(64) 40 | #endif 41 | cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; 42 | }; 43 | 44 | using CUtensorMap = CUtensorMap_st; 45 | #endif 46 | 47 | namespace std { 48 | 49 | template struct integral_constant { 50 | static constexpr T value = v; 51 | 52 | using value_type = T; 53 | using type = integral_constant; 54 | 55 | __device__ constexpr operator value_type() const noexcept { return value; } 56 | 57 | __device__ constexpr value_type operator()() const noexcept { return value; } 58 | }; 59 | 60 | using false_type = integral_constant; 61 | using true_type = integral_constant; 62 | 63 | template struct is_same : false_type {}; 64 | 65 | template struct is_same : true_type {}; 66 | 67 | template 68 | inline constexpr bool is_same_v = is_same::value; 69 | 70 | namespace index_sequence_impl { 71 | 72 | // Based on https://stackoverflow.com/a/32223343/11717224 73 | template struct index_sequence { 74 | using type = index_sequence; 75 | using value_type = size_t; 76 | static constexpr size_t size() noexcept { return sizeof...(Ints); } 77 | }; 78 | 79 | template struct _merge_and_renumber; 80 | 81 | template 82 | struct _merge_and_renumber, index_sequence> 83 | : index_sequence {}; 84 | 85 | template 86 | struct make_index_sequence 87 | : _merge_and_renumber::type, 88 | typename make_index_sequence::type> {}; 89 | 90 | template <> struct make_index_sequence<0> : index_sequence<> {}; 91 | template <> struct make_index_sequence<1> : index_sequence<0> {}; 92 | 93 | } // namespace index_sequence_impl 94 | 95 | template 96 | using index_sequence = index_sequence_impl::index_sequence; 97 | 98 | template 99 | using make_index_sequence = index_sequence_impl::make_index_sequence; 100 | 101 | } // namespace std 102 | 103 | #endif 104 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/scheduler.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils.cuh" 4 | 5 | namespace deep_gemm { 6 | 7 | enum class GemmType { 8 | Normal, 9 | GroupedContiguous, 10 | GroupedMasked 11 | }; 12 | 13 | #pragma clang diagnostic push 14 | #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" 15 | template 21 | struct Scheduler { 22 | int current_iter = -1; 23 | uint32_t num_aligned_m_blocks; 24 | 25 | // For normal GEMM 26 | // Maybe not used in the masked grouped GEMM 27 | uint32_t num_blocks; 28 | uint32_t num_blocks_in_group; 29 | bool is_peer_cta_alive = true; 30 | 31 | // For grouped GEMM 32 | int* grouped_layout; 33 | 34 | // Only used for masked layout 35 | uint32_t curr_group_idx, curr_cumsum; 36 | 37 | __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, 38 | int* grouped_layout = nullptr) { 39 | num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); 40 | if constexpr (kGemmType == GemmType::Normal) { 41 | num_blocks = num_aligned_m_blocks * kNumNBlocks; 42 | } else if (kGemmType == GemmType::GroupedContiguous) { 43 | num_blocks = num_aligned_m_blocks * kNumNBlocks; 44 | this->grouped_layout = grouped_layout; 45 | } else if (kGemmType == GemmType::GroupedMasked) { 46 | curr_group_idx = curr_cumsum = 0; 47 | this->grouped_layout = grouped_layout; 48 | } 49 | } 50 | 51 | // ReSharper disable once CppNotAllPathsReturnValue 52 | __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { 53 | if constexpr (kGemmType == GemmType::Normal) { 54 | return true; 55 | } else if constexpr (kGemmType == GemmType::GroupedContiguous) { 56 | return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; 57 | } else if constexpr (kGemmType == GemmType::GroupedMasked) { 58 | return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx); 59 | } 60 | } 61 | 62 | __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { 63 | if (num_blocks_in_group == 1) 64 | return false; 65 | if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { 66 | return true; 67 | } else { 68 | DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type"); 69 | if constexpr (kIsTMAMulticastOnA) { 70 | return true; 71 | } else { 72 | auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); 73 | auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); 74 | return group_idx == peer_group_idx; 75 | } 76 | } 77 | } 78 | 79 | __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx, 80 | uint32_t& m_block_idx, uint32_t& n_block_idx) { 81 | DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); 82 | 83 | // Swizzle for better L2 usages 84 | auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks; 85 | auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks; 86 | auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; 87 | auto group_idx = block_idx / num_blocks_per_group; 88 | auto first_block_idx = group_idx * kNum1DBlocksPerGroup; 89 | auto in_group_idx = block_idx % num_blocks_per_group; 90 | num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); 91 | 92 | // Fix unaligned TMA multicast 93 | if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { 94 | if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { 95 | num_blocks_in_group = num_blocks_in_group ^ 1; 96 | } else { 97 | in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; 98 | first_block_idx += num_blocks_in_group ^ 1; 99 | num_blocks_in_group = 1; 100 | } 101 | } 102 | 103 | // Convert to final M/N block indices 104 | if constexpr (kIsTMAMulticastOnA) { 105 | m_block_idx = in_group_idx / num_blocks_in_group; 106 | n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; 107 | } else { 108 | m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; 109 | n_block_idx = in_group_idx / num_blocks_in_group; 110 | } 111 | } 112 | 113 | template 114 | __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size, 115 | const uint32_t& block_idx, const uint32_t& m_block_idx=0) { 116 | if constexpr (kGemmType == GemmType::Normal) { 117 | return block_idx * block_size; 118 | } else if constexpr (kGemmType == GemmType::GroupedContiguous) { 119 | auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); 120 | return offset * shape_dim + block_idx * block_size; 121 | } else if constexpr (kGemmType == GemmType::GroupedMasked) { 122 | return curr_group_idx * shape_dim + block_idx * block_size; 123 | } 124 | } 125 | 126 | __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { 127 | const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; 128 | 129 | if constexpr (kGemmType == GemmType::GroupedMasked) { 130 | uint32_t num_m_blocks; 131 | while (true) { 132 | // End of the task 133 | if (curr_group_idx == kNumGroups) 134 | return false; 135 | 136 | // Within the current group 137 | num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); 138 | auto current_m_block_cumsum = curr_cumsum + num_m_blocks; 139 | if (next_block_idx < current_m_block_cumsum * kNumNBlocks) 140 | break; 141 | 142 | // Move to check the next group 143 | curr_group_idx ++, curr_cumsum = current_m_block_cumsum; 144 | } 145 | 146 | get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); 147 | } else { 148 | if (next_block_idx >= num_blocks) 149 | return false; 150 | 151 | // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned 152 | is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) 153 | num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) 154 | (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound 155 | get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); 156 | } 157 | return true; 158 | } 159 | }; 160 | 161 | #pragma clang diagnostic pop 162 | 163 | } // namespace deep_gemm 164 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/tma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils.cuh" 4 | 5 | namespace deep_gemm { 6 | 7 | // TODO: move this function to other files 8 | __device__ __forceinline__ void 9 | tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, 10 | int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { 11 | constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); 12 | if (num_tma_multicast == 1) { 13 | cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); 14 | } else if (cute::block_rank_in_cluster() == 0) { 15 | cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); 16 | } 17 | } 18 | 19 | } // namespace deep_gemm 20 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __CLION_IDE__ 4 | 5 | __host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { 6 | asm volatile("trap;"); 7 | } 8 | 9 | #define printf host_device_printf 10 | #endif 11 | 12 | #ifndef DG_DEVICE_ASSERT 13 | #define DG_DEVICE_ASSERT(cond) \ 14 | do { \ 15 | if (not (cond)) { \ 16 | printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ 17 | asm("trap;"); \ 18 | } \ 19 | } while (0) 20 | #endif 21 | 22 | #ifndef DG_STATIC_ASSERT 23 | #define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) 24 | #endif 25 | 26 | template 27 | __device__ __host__ constexpr T ceil_div(T a, T b) { 28 | return (a + b - 1) / b; 29 | } 30 | 31 | template 32 | __device__ __host__ constexpr T constexpr_gcd(T a, T b) { 33 | return b == 0 ? a : constexpr_gcd(b, a % b); 34 | } 35 | -------------------------------------------------------------------------------- /deep_gemm/jit/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler 2 | from .runtime import Runtime 3 | -------------------------------------------------------------------------------- /deep_gemm/jit/compiler.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | import os 4 | import re 5 | import subprocess 6 | import time 7 | import uuid 8 | from typing import Any, Dict, List, Tuple, Type 9 | 10 | import cuda.bindings 11 | import cuda.bindings.nvrtc as nvrtc 12 | from torch.utils.cpp_extension import CUDA_HOME 13 | 14 | from . import interleave_ffma 15 | from .runtime import Runtime, RuntimeCache 16 | 17 | runtime_cache = RuntimeCache() 18 | 19 | 20 | def hash_to_hex(s: str) -> str: 21 | md5 = hashlib.md5() 22 | md5.update(s.encode('utf-8')) 23 | return md5.hexdigest()[0:12] 24 | 25 | 26 | @functools.lru_cache(maxsize=None) 27 | def get_jit_include_dir() -> str: 28 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') 29 | 30 | 31 | @functools.lru_cache(maxsize=None) 32 | def get_deep_gemm_version() -> str: 33 | md5 = hashlib.md5() 34 | 35 | # Update include directories 36 | include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') 37 | assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' 38 | for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): 39 | with open(os.path.join(include_dir, filename), 'rb') as f: 40 | md5.update(f.read()) 41 | 42 | # Update `interleave_ffma.py` 43 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: 44 | md5.update(f.read()) 45 | return md5.hexdigest()[0:12] 46 | 47 | 48 | @functools.lru_cache(maxsize=None) 49 | def get_nvcc_compiler() -> Tuple[str, str]: 50 | paths = [] 51 | if os.getenv('DG_JIT_NVCC_COMPILER'): 52 | paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) 53 | paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) 54 | 55 | # Try to find the first available NVCC compiler 56 | least_version_required = '12.3' 57 | version_pattern = re.compile(r'release (\d+\.\d+)') 58 | for path in paths: 59 | if os.path.exists(path): 60 | command = [path, '--version'] 61 | result = subprocess.run(command, stdout=subprocess.PIPE, 62 | stderr=subprocess.PIPE, text=True) 63 | match = version_pattern.search(result.stdout) 64 | version = match.group(1) 65 | assert match, f'Cannot get the version of NVCC compiler {path}' 66 | assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' 67 | return path, version 68 | raise RuntimeError('Cannot find any available NVCC compiler') 69 | 70 | 71 | @functools.lru_cache(maxsize=None) 72 | def get_default_user_dir(): 73 | if 'DG_JIT_CACHE_DIR' in os.environ: 74 | path = os.getenv('DG_JIT_CACHE_DIR') 75 | os.makedirs(path, exist_ok=True) 76 | return path 77 | return os.path.join(os.path.expanduser('~'), '.deep_gemm') 78 | 79 | 80 | @functools.lru_cache(maxsize=None) 81 | def get_tmp_dir(): 82 | return os.path.join(get_default_user_dir(), 'tmp') 83 | 84 | 85 | @functools.lru_cache(maxsize=None) 86 | def get_cache_dir(): 87 | return os.path.join(get_default_user_dir(), 'cache') 88 | 89 | 90 | def make_tmp_dir(): 91 | tmp_dir = get_tmp_dir() 92 | os.makedirs(tmp_dir, exist_ok=True) 93 | return tmp_dir 94 | 95 | 96 | def put(path, data): 97 | # Write and do POSIX atomic replace 98 | tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') 99 | with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: 100 | f.write(data) 101 | os.replace(tmp_file_path, path) 102 | 103 | 104 | class Compiler: 105 | @classmethod 106 | def signature(cls) -> str: 107 | pass 108 | 109 | @staticmethod 110 | def __version__() -> Tuple[int, int]: 111 | pass 112 | 113 | @classmethod 114 | def compile(cls, name: str, code: str, target_path: str) -> None: 115 | pass 116 | 117 | @staticmethod 118 | def flags() -> List[str]: 119 | cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) 120 | return [f'-std=c++{cpp_standard}', 121 | '--ptxas-options=--register-usage-level=10' + 122 | (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), 123 | # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases 124 | '--diag-suppress=39,161,174,177,186,940'] 125 | 126 | @staticmethod 127 | def include_dirs() -> List[str]: 128 | return [get_jit_include_dir()] 129 | 130 | @classmethod 131 | def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: 132 | # Compiler flags 133 | flags = cls.flags() 134 | 135 | # Build signature 136 | enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) 137 | signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' 138 | name = f'kernel.{name}.{hash_to_hex(signature)}' 139 | path = os.path.join(get_cache_dir(), name) 140 | 141 | # Check runtime cache or file system hit 142 | global runtime_cache 143 | cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs) 144 | if cached_runtime is not None: 145 | if int(os.getenv('DG_JIT_DEBUG', 0)): 146 | print(f'Using cached JIT runtime {name} during build') 147 | return cached_runtime 148 | 149 | # Compile into a temporary CU file 150 | os.makedirs(path, exist_ok=True) 151 | cubin_path = os.path.join(path, 'kernel.cubin') 152 | tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') 153 | 154 | start_time = time.time() 155 | cls.compile(name, code, tmp_cubin_path) 156 | end_time = time.time() 157 | elapsed_time = end_time - start_time 158 | if int(os.getenv('DG_JIT_DEBUG', 0)): 159 | print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') 160 | 161 | # Interleave FFMA reuse 162 | if enable_sass_opt: 163 | interleave_ffma.process(tmp_cubin_path) 164 | 165 | # Atomic replace files 166 | os.replace(tmp_cubin_path, cubin_path) 167 | 168 | # Put cache and return 169 | runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True) 170 | assert runtime is not None 171 | return runtime 172 | 173 | 174 | class NVCCCompiler(Compiler): 175 | @staticmethod 176 | def __version__() -> Tuple[int, int]: 177 | _, version = get_nvcc_compiler() 178 | major, minor = map(int, version.split('.')) 179 | return major, minor 180 | 181 | @classmethod 182 | def signature(cls) -> str: 183 | return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' 184 | 185 | @classmethod 186 | def flags(cls) -> List[str]: 187 | cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] 188 | return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], 189 | '-gencode=arch=compute_90a,code=sm_90a', 190 | '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', 191 | f'--compiler-options={",".join(cxx_flags)}'] 192 | 193 | @classmethod 194 | def compile(cls, name: str, code: str, target_path: str) -> None: 195 | # Write the code 196 | path = os.path.join(get_cache_dir(), name) 197 | src_path = os.path.join(path, 'kernel.cu') 198 | put(src_path, code) 199 | command = [get_nvcc_compiler()[0], 200 | src_path, '-o', target_path, 201 | *cls.flags()] 202 | if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): 203 | print(f'Compiling JIT runtime {name} with command {command}') 204 | 205 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 206 | if result.returncode != 0: 207 | print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') 208 | assert False, f'Failed to compile {src_path}' 209 | 210 | 211 | class NVRTCCompiler(Compiler): 212 | @staticmethod 213 | def __version__() -> Tuple[int, int]: 214 | res, major, minor = nvrtc.nvrtcVersion() 215 | if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: 216 | # Failed to get the actual NVRTC version, use cuda-bindings version instead 217 | major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) 218 | return major, minor 219 | 220 | @classmethod 221 | def signature(cls) -> str: 222 | return f'nvrtc+{cls.__version__()}' 223 | 224 | @staticmethod 225 | def include_dirs() -> List[str]: 226 | if CUDA_HOME is None: 227 | raise RuntimeError('CUDA_HOME is required for NVRTC compilation') 228 | return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] 229 | 230 | @classmethod 231 | def flags(cls) -> List[str]: 232 | flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], 233 | '--gpu-architecture=sm_90a', '-default-device'] 234 | # NOTES: PCH is vital for compilation speed 235 | if cls.__version__() >= (12, 8): 236 | flags += ['--pch'] 237 | if int(os.getenv('DG_JIT_DEBUG', 0)): 238 | flags += ['--pch-verbose=true'] 239 | return flags 240 | 241 | @classmethod 242 | def compile(cls, name: str, code: str, target_path: str) -> None: 243 | # Create program 244 | code_bytes = bytes(code, 'utf-8') 245 | result, program = nvrtc.nvrtcCreateProgram( 246 | code_bytes, bytes(name, 'utf-8'), 0, [], []) 247 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' 248 | 249 | # Compile 250 | options = [bytes(flag, 'utf-8') for flag in cls.flags()] 251 | if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): 252 | print(f'Compiling JIT runtime {name} with options: {options}') 253 | compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] 254 | 255 | # Print compiler log 256 | if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: 257 | result, log_size = nvrtc.nvrtcGetProgramLogSize(program) 258 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' 259 | 260 | log_bytes = bytes(log_size) 261 | result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] 262 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' 263 | print(f'Compiler log: {log_bytes.decode("utf-8")}') 264 | 265 | # Exit if failed 266 | assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' 267 | 268 | # Create CUBIN 269 | result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) 270 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' 271 | cubin_bytes = bytes(cubin_size) 272 | result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] 273 | assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' 274 | 275 | # Write into the file system 276 | put(target_path, cubin_bytes) 277 | 278 | # Destroy handler 279 | assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' 280 | 281 | 282 | def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: 283 | compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler 284 | return compiler_cls.build(name, code, runtime_cls, kwargs) 285 | -------------------------------------------------------------------------------- /deep_gemm/jit/interleave_ffma.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mmap 3 | import os 4 | import re 5 | import subprocess 6 | from torch.utils.cpp_extension import CUDA_HOME 7 | 8 | 9 | def run_cuobjdump(file_path): 10 | command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path] 11 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 12 | assert result.returncode == 0 13 | return result.stdout 14 | 15 | 16 | def extract_ffma(sass): 17 | lines = sass.splitlines() 18 | collected = [] 19 | current = [] 20 | 21 | arch_name, func_name = 'N/A', 'N/A' 22 | skip_next_line = False 23 | for line in lines: 24 | if 'code for' in line: 25 | arch_name = line.lstrip().lstrip('code for ').rstrip() 26 | elif 'Function :' in line: 27 | func_name = line.lstrip().lstrip('Function :').rstrip() 28 | elif 'FFMA' in line: 29 | current.append(line) 30 | skip_next_line = True 31 | elif skip_next_line: 32 | current.append(line) 33 | skip_next_line = False 34 | else: 35 | if len(current) >= 16: 36 | assert len(current) % 2 == 0 37 | collected.append((f'{arch_name}::{func_name}', current)) 38 | current = [] 39 | 40 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 41 | print(f'Found {len(collected)} FFMA segments') 42 | return collected 43 | 44 | 45 | def extract_hex_from_line(line): 46 | match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line) 47 | assert match 48 | return int(match.group(1), 16) 49 | 50 | 51 | def validate(m, offset, le_bytes, num_lines): 52 | assert len(le_bytes) == num_lines // 2 53 | assert m[offset:offset + 16] == le_bytes[0] 54 | for i in range(1, num_lines // 2): 55 | if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]: 56 | return False 57 | return True 58 | 59 | 60 | def parse_registers(line): 61 | line = re.sub(r'/\*.*?\*/', '', line) 62 | line = line.replace(';', '') 63 | tokens = line.strip().split(',') 64 | registers = [] 65 | for token in tokens: 66 | token = token.strip() 67 | words = token.split() 68 | for word in words: 69 | if word.startswith('R'): 70 | reg = word.split('.')[0] 71 | registers.append(reg) 72 | return registers 73 | 74 | 75 | def modify_segment(m, name, ffma_lines): 76 | num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2 77 | assert num_lines % 2 == 0 78 | 79 | le_bytes, new_le_bytes = [], [] 80 | reused_list = [] 81 | dst_reg_set = set() 82 | last_reused, last_dst_reg = False, '' 83 | num_changed = 0 84 | for i in range(num_lines // 2): 85 | dst_reg = parse_registers(ffma_lines[i * 2])[-2] 86 | low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] 87 | low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) 88 | le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 89 | reused = (high_hex & 0x0800000000000000) != 0 90 | if reused: 91 | is_first_occurred = dst_reg not in dst_reg_set 92 | if is_first_occurred or (last_reused and dst_reg == last_dst_reg): 93 | # Modify the `reuse` and `yield` bits 94 | assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' 95 | high_hex ^= 0x0800200000000000 96 | reused = False 97 | num_changed += 1 98 | else: 99 | reused_list.append(i) 100 | dst_reg_set.add(dst_reg) 101 | new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 102 | last_reused, last_dst_reg = reused, dst_reg 103 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 104 | print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') 105 | 106 | # Find the offset 107 | offsets = [] 108 | offset = m.find(le_bytes[0]) 109 | while offset != -1: 110 | offsets.append(offset) 111 | offset = m.find(le_bytes[0], offset + 1) 112 | offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) 113 | 114 | # Replace with `new_le_bytes` 115 | for offset in offsets: 116 | for i in range(num_lines // 2): 117 | m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i] 118 | 119 | 120 | def process(path): 121 | if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): 122 | print(f'Processing {path}') 123 | output = run_cuobjdump(path) 124 | segments = extract_ffma(output) 125 | with open(path, 'r+b') as f: 126 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) 127 | for segment in segments: 128 | modify_segment(mm, *segment) 129 | mm.close() 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') 134 | parser.add_argument('--so', help='Path to the SO file') 135 | args = parser.parse_args() 136 | 137 | process(args.so) 138 | -------------------------------------------------------------------------------- /deep_gemm/jit/runtime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | import torch 5 | import cuda.bindings.driver as cbd 6 | 7 | from typing import Any, Dict, Optional, Type 8 | from torch.utils.cpp_extension import CUDA_HOME 9 | 10 | 11 | class Runtime: 12 | def __init__(self, path: str) -> None: 13 | self.path = path 14 | self.lib = None 15 | self.kernel = None 16 | assert self.is_path_valid(self.path) 17 | 18 | @staticmethod 19 | def is_path_valid(path: str) -> bool: 20 | # Exists and is a directory 21 | if not os.path.exists(path) or not os.path.isdir(path): 22 | return False 23 | 24 | # Contains all necessary files 25 | files = ['kernel.cubin'] 26 | return all(os.path.exists(os.path.join(path, file)) for file in files) 27 | 28 | @staticmethod 29 | def generate(kwargs: Dict[str, Any]) -> str: 30 | raise NotImplemented 31 | 32 | @staticmethod 33 | def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: 34 | raise NotImplemented 35 | 36 | def __call__(self, **kwargs) -> cbd.CUresult: 37 | # Load CUBIN 38 | if self.kernel is None: 39 | start_time = time.time_ns() 40 | 41 | # Load CUBIN 42 | path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') 43 | result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) 44 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' 45 | 46 | # Extract the kernel name 47 | # TODO: use `cuda-bindings` API to do this (requires at least 12.8) 48 | command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] 49 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 50 | assert result.returncode == 0 51 | illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail'] 52 | check_illegal = lambda line: any([name in line for name in illegal_names]) 53 | kernel_names = [line.split()[-1] for line in result.stdout.splitlines() 54 | if line.startswith('STT_FUNC') and not check_illegal(line)] 55 | assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' 56 | 57 | # Load kernel from the library 58 | result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) 59 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' 60 | 61 | end_time = time.time_ns() 62 | elapsed_time = (end_time - start_time) / 1e6 63 | if int(os.getenv('DG_JIT_DEBUG', 0)): 64 | print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') 65 | 66 | # noinspection PyArgumentList 67 | return self.launch(self.kernel, kwargs) 68 | 69 | def __del__(self) -> None: 70 | if self.lib is not None: 71 | res = cbd.cuLibraryUnload(self.lib)[0] 72 | if res != cbd.CUresult.CUDA_SUCCESS: 73 | raise Exception(f'Failed to unload library {self.path}: {res}') 74 | 75 | 76 | class RuntimeCache: 77 | def __init__(self) -> None: 78 | self.cache = {} 79 | 80 | def __setitem__(self, path: str, runtime: Runtime) -> None: 81 | self.cache[path] = runtime 82 | 83 | def get(self, path: str, runtime_cls: Type[Runtime], 84 | name: str = '', kwargs: Dict[str, Any] = None, 85 | force_enable_cache: bool = False) -> Optional[Runtime]: 86 | # In Python runtime 87 | if path in self.cache: 88 | return self.cache[path] 89 | 90 | # Already compiled 91 | use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) 92 | if use_cache and os.path.exists(path) and Runtime.is_path_valid(path): 93 | # Print heuristic for the first time 94 | if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): 95 | simplified_kwargs = dict() 96 | for key, value in kwargs.items() if kwargs is not None else dict().items(): 97 | value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value 98 | value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value 99 | simplified_kwargs[key] = value 100 | print(f'Put kernel {name} with {simplified_kwargs} into runtime cache') 101 | 102 | runtime = runtime_cls(path) 103 | self.cache[path] = runtime 104 | return runtime 105 | return None 106 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemm import gemm_fp8_fp8_bf16_nt 2 | from .m_grouped_gemm import ( 3 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 4 | m_grouped_gemm_fp8_fp8_bf16_nt_masked 5 | ) 6 | from .wgrad_gemm import ( 7 | wgrad_gemm_fp8_fp8_fp32_nt, 8 | k_grouped_wgrad_gemm_fp8_fp8_fp32_nt 9 | ) 10 | from .utils import ( 11 | ceil_div, set_num_sms, get_num_sms, 12 | get_col_major_tma_aligned_tensor, 13 | get_m_alignment_for_contiguous_layout 14 | ) 15 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/gemm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import lru_cache 4 | from typing import Tuple 5 | 6 | from ..jit import build 7 | from .runtime import ( 8 | FP8GemmRuntime, GemmType, 9 | make_2d_tma_a_desc, make_2d_tma_b_desc, 10 | make_2d_tma_d_desc, make_2d_tma_scales_desc) 11 | from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout 12 | 13 | 14 | def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, 15 | require_divisible: bool = False) -> bool: 16 | divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible 17 | return divisible and num_sms % num_tma_multicast == 0 18 | 19 | 20 | def get_swizzle_mode(block_n: int) -> int: 21 | elem_size = 2 22 | for mode_bytes in (128, 64, 32): 23 | if (block_n * elem_size) % mode_bytes == 0: 24 | return mode_bytes 25 | return 0 26 | 27 | 28 | def get_block_n_padding_for_smem_d(block_n: int) -> int: 29 | # NOTES: padding is for solving bank conflicts, but wastes shared memory space 30 | elem_size, requirement = 2, (4, 8) 31 | bank_stride = (block_n * elem_size) // 4 32 | padding = (requirement[0] - bank_stride) % requirement[1] 33 | return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size 34 | 35 | 36 | def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, 37 | is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: 38 | assert block_k == 128 39 | 40 | # Try swizzle first, as it does not waste shared memory 41 | swizzle_mode = get_swizzle_mode(block_n) 42 | block_n_padding = get_block_n_padding_for_smem_d( 43 | block_n) if swizzle_mode == 0 else 0 44 | 45 | # NOTES: `scales_b` in a total manner or per-stage manner 46 | smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) 47 | smem_a_per_stage = block_m * block_k 48 | smem_scales_a_per_stage = block_m * 4 49 | smem_b_per_stage = block_n * block_k 50 | smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 51 | smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 52 | smem_barrier = num_stages * 8 * 2 53 | 54 | smem_size = 0 55 | smem_size += smem_d 56 | smem_size += num_stages * smem_a_per_stage 57 | smem_size += num_stages * smem_scales_a_per_stage 58 | smem_size += num_stages * smem_b_per_stage 59 | smem_size += num_stages * smem_scales_b_per_stage 60 | smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 61 | smem_size += smem_barrier 62 | 63 | # Swizzle and padding are not compatible 64 | assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 65 | 66 | return smem_size, swizzle_mode, block_n_padding 67 | 68 | 69 | @lru_cache(maxsize=None) 70 | def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, 71 | is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, 72 | is_fp32_out: bool = False, is_wgrad: bool = False) -> \ 73 | Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: 74 | if not is_grouped_contiguous: 75 | block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) 76 | else: 77 | block_ms = (get_m_alignment_for_contiguous_layout(), ) 78 | block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) 79 | 80 | # Avoid bank conflicts for FP32 output 81 | if is_fp32_out: 82 | block_ns = [x for x in block_ns if x % 16 == 8] 83 | 84 | fix_wave_saturate = lambda x: num_sms if x == 0 else x 85 | get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) 86 | get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) 87 | 88 | # Decide block sizes by waves 89 | best_block_m, best_block_n = None, None 90 | for block_m in block_ms: 91 | # NOTES: the block sizes cannot be too large, so at least one dim less than 128 92 | for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): 93 | success = False 94 | num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) 95 | if best_block_m is None or best_block_n is None: 96 | success = True 97 | elif num_waves < best_num_waves: 98 | success = True 99 | elif num_waves == best_num_waves: 100 | # Check last wave utilization 101 | util = get_last_wave_util(block_m, block_n) 102 | best_util = get_last_wave_util(best_block_m, best_block_n) 103 | success = util > best_util 104 | if util == best_util: 105 | # Case 1: same `block_m`, smaller `block_n` (wasted) 106 | success |= block_m == best_block_m and block_n < best_block_n 107 | # Case 2: same `block_n`, smaller `block_m` (wasted) 108 | success |= block_n == best_block_n and block_m < best_block_m 109 | # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better 110 | success |= block_m != best_block_m and block_n > best_block_n 111 | best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) 112 | assert best_block_m is not None and best_block_n is not None 113 | 114 | # Always pick the longest one 115 | # NOTES: for double B scales, the best number of stages may be reduced 116 | best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 117 | stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))) 118 | if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: 119 | # Unrolling both stages and `num_former_iters` will cause large code size 120 | stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) 121 | for num_stages in stage_candidates: 122 | best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) 123 | if best_smem_config[0] <= sm90_capacity: 124 | best_num_stages = num_stages 125 | break 126 | assert best_smem_config is not None 127 | assert best_num_stages is not None 128 | 129 | # Decide the number of TMA multicasts and whether broadcast on A 130 | best_tma_multicast_config = (1, True) 131 | 132 | # Try to multicast on the larger block side first 133 | # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even 134 | is_multicast_legal = { 135 | 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), 136 | 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, 137 | } 138 | for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): 139 | if m >= 512 and is_multicast_legal[i]: 140 | best_tma_multicast_config = (2, i == 'A') 141 | break 142 | 143 | # Recompute the minimal number of SMs required 144 | # NOTES: less L2 cache usage and less GPU frequency drop 145 | num_waves = get_num_waves(best_block_m, best_block_n) 146 | num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) 147 | num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] 148 | assert num_min_sms <= num_sms 149 | 150 | return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config 151 | 152 | 153 | def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 154 | rhs: Tuple[torch.Tensor, torch.Tensor], 155 | out: torch.Tensor) -> None: 156 | """ 157 | Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 158 | 159 | Requirements: 160 | LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. 161 | The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8. 162 | RHS and RHS scaling factors are required to be transposed. 163 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 164 | this function will do a transposing with a set of slow PyTorch operations. 165 | 166 | Arguments: 167 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, 168 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. 169 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, 170 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. 171 | out: the BF16 output tensor of shape `[m, n]`, representing the result. 172 | """ 173 | lhs, lhs_scales = lhs 174 | rhs, rhs_scales = rhs 175 | m, k = lhs.shape 176 | n, k_ = rhs.shape 177 | m_, n_ = out.shape 178 | 179 | # Type and shape checks 180 | assert m == m_ and n == n_ and k == k_ 181 | assert n > 0 and k > 0 182 | assert lhs_scales.shape == (m, ceil_div(k, 128)) 183 | assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128)) 184 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 185 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 186 | assert out.dtype == torch.bfloat16 187 | assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 188 | 189 | # LHS scales must be transposed for TMA loads, but not for RHS scales 190 | # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels 191 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 192 | assert rhs_scales.is_contiguous() 193 | 194 | # Do nothing if `m` is zero 195 | if m == 0: 196 | return 197 | 198 | # K must be aligned to 128 199 | aligned_k = ceil_div(k, 128) * 128 200 | 201 | # Auto-tuning with compilation 202 | num_sms = get_num_sms() 203 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) 204 | block_k = 128 205 | num_tma_threads = 128 206 | num_math_threads_per_group = 128 207 | 208 | tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) 209 | tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) 210 | tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) 211 | tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) 212 | 213 | kwargs = { 214 | # Templated arguments 215 | 'GEMM_TYPE': GemmType.Normal, 216 | 'NUM_TMA_THREADS': num_tma_threads, 217 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 218 | 'M': m, 'N': n, 'K': aligned_k, 219 | 'NUM_GROUPS': 1, 220 | 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 221 | 'SWIZZLE_D_MODE': smem_config[1], 222 | 'BLOCK_N_PADDING': smem_config[2], 223 | 'NUM_STAGES': num_stages, 224 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 225 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 226 | # Runtime arguments 227 | 'SCALES_B': rhs_scales, 228 | 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), 229 | 'NUM_SMS': num_sms, 230 | 'SMEM_SIZE': smem_config[0], 231 | 'TENSOR_MAP_A': tensor_map_a, 232 | 'TENSOR_MAP_B': tensor_map_b, 233 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 234 | 'TENSOR_MAP_D': tensor_map_d, 235 | 'STREAM': torch.cuda.current_stream().cuda_stream, 236 | 'DEVICE_INDEX': out.device.index 237 | } 238 | 239 | # Generate, build and run the kernel 240 | code = FP8GemmRuntime.generate(kwargs) 241 | runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) 242 | runtime(**kwargs) 243 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/m_grouped_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from ..jit import build 5 | from .gemm import get_best_configs 6 | from .runtime import ( 7 | FP8GemmRuntime, GemmType, 8 | make_2d_tma_a_desc, make_2d_tma_b_desc, 9 | make_2d_tma_d_desc, make_2d_tma_scales_desc) 10 | from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms 11 | 12 | 13 | def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], 14 | rhs: Tuple[torch.Tensor, torch.Tensor], 15 | out: torch.Tensor, m_indices: torch.Tensor) -> None: 16 | """ 17 | Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 18 | 19 | Requirements: 20 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 21 | RHS and RHS scaling factors are required to be transposed. 22 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 23 | this function will do a transposing with a set of slow PyTorch operations. 24 | On the M axis, inputs are grouped into several batches, of which batch sizes aligned to 25 | `get_m_alignment_for_contiguous_layout()` (128). 26 | 27 | Arguments: 28 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, 29 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. 30 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, 31 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 32 | out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. 33 | m_indices: a tensor of shape `[m_sum]` with type `torch.int`. 34 | `m_indices[i]` records the group which the i-th row of the LHS belongs to, 35 | which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. 36 | Values of `m_indices` in every-m-alignment-block must also be the same. 37 | """ 38 | lhs, lhs_scales = lhs 39 | rhs, rhs_scales = rhs 40 | m, k = lhs.shape 41 | num_groups, n, k_ = rhs.shape 42 | m_, n_ = out.shape 43 | m__ = m_indices.numel() 44 | 45 | # Type and shape checks 46 | assert m == m_ == m__ and k == k_ and n == n_ 47 | assert lhs_scales.shape == (m, ceil_div(k, 128)) 48 | assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) 49 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 50 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 51 | assert out.dtype == torch.bfloat16 52 | assert m_indices.dtype == torch.int32 53 | assert lhs.is_contiguous() and rhs.is_contiguous() 54 | assert out.is_contiguous() and m_indices.is_contiguous() 55 | 56 | # LHS scales must be transposed for TMA load, but not for RHS scales 57 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 58 | assert rhs_scales.is_contiguous() 59 | 60 | # Do nothing if `m` is zero 61 | if m == 0: 62 | return 63 | 64 | # Auto-tuning with compilation 65 | num_sms = get_num_sms() 66 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 67 | m, n, k, 1, num_sms, is_grouped_contiguous=True) 68 | block_k = 128 69 | num_tma_threads = 128 70 | num_math_threads_per_group = 128 71 | 72 | tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups) 73 | tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups) 74 | tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) 75 | tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) 76 | 77 | kwargs = { 78 | # Templated arguments 79 | 'NUM_TMA_THREADS': num_tma_threads, 80 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 81 | 'M': m, 'N': n, 'K': k, 82 | 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 83 | 'SWIZZLE_D_MODE': smem_config[1], 84 | 'BLOCK_N_PADDING': smem_config[2], 85 | 'NUM_GROUPS': num_groups, 86 | 'NUM_STAGES': num_stages, 87 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 88 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 89 | 'GEMM_TYPE': GemmType.GroupedContiguous, 90 | # Runtime arguments 91 | 'SCALES_B': rhs_scales, 92 | 'GROUPED_LAYOUT': m_indices, 93 | 'NUM_SMS': num_sms, 94 | 'SMEM_SIZE': smem_config[0], 95 | 'TENSOR_MAP_A': tensor_map_a, 96 | 'TENSOR_MAP_B': tensor_map_b, 97 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 98 | 'TENSOR_MAP_D': tensor_map_d, 99 | 'STREAM': torch.cuda.current_stream().cuda_stream, 100 | 'DEVICE_INDEX': out.device.index 101 | } 102 | 103 | # Generate, build and run the kernel 104 | code = FP8GemmRuntime.generate(kwargs) 105 | runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) 106 | runtime(**kwargs) 107 | 108 | 109 | def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], 110 | rhs: Tuple[torch.Tensor, torch.Tensor], 111 | out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: 112 | """ 113 | Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 114 | 115 | Requirements: 116 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 117 | RHS and RHS scaling factors are required to be transposed. 118 | The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, 119 | this function will do a transposing with a set of slow PyTorch operations. 120 | Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch 121 | should be separately transposed. 122 | 123 | Arguments: 124 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, 125 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. 126 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. 127 | The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 128 | out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. 129 | masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute 130 | in the i-th group. 131 | expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, 132 | correctly setting this value may lead to better performance. 133 | """ 134 | lhs, lhs_scales = lhs 135 | rhs, rhs_scales = rhs 136 | num_groups, m, k = lhs.shape 137 | num_groups_, n, k_ = rhs.shape 138 | num_groups__, m_, n_ = out.shape 139 | num_groups___ = masked_m.numel() 140 | 141 | # Type and shape checks 142 | assert num_groups == num_groups_ == num_groups__ == num_groups___ 143 | assert m == m_ and n == n_ and k == k_ 144 | assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 145 | assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) 146 | assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) 147 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 148 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 149 | assert out.dtype == torch.bfloat16 150 | assert masked_m.dtype == torch.int32 151 | assert lhs.is_contiguous() and rhs.is_contiguous() 152 | assert out.is_contiguous() and masked_m.is_contiguous() 153 | 154 | # LHS scales must be transposed for TMA load, but not for RHS scales 155 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 156 | assert rhs_scales.is_contiguous() 157 | 158 | # Auto-tuning with compilation 159 | num_sms = get_num_sms() 160 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 161 | expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) 162 | 163 | # Extra checks for TMA store 164 | if num_groups > 1 and m > block_m: 165 | assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' 166 | 167 | block_k = 128 168 | num_tma_threads = 128 169 | num_math_threads_per_group = 128 170 | 171 | tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups) 172 | tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups) 173 | tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) 174 | tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) 175 | 176 | kwargs = { 177 | # Templated arguments 178 | 'NUM_TMA_THREADS': num_tma_threads, 179 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 180 | 'M': m, 'N': n, 'K': k, 181 | 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 182 | 'SWIZZLE_D_MODE': smem_config[1], 183 | 'BLOCK_N_PADDING': smem_config[2], 184 | 'NUM_GROUPS': num_groups, 185 | 'NUM_STAGES': num_stages, 186 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 187 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 188 | 'GEMM_TYPE': GemmType.GroupedMasked, 189 | # Runtime arguments 190 | 'SCALES_B': rhs_scales, 191 | 'GROUPED_LAYOUT': masked_m, 192 | 'NUM_SMS': num_sms, 193 | 'SMEM_SIZE': smem_config[0], 194 | 'TENSOR_MAP_A': tensor_map_a, 195 | 'TENSOR_MAP_B': tensor_map_b, 196 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 197 | 'TENSOR_MAP_D': tensor_map_d, 198 | 'STREAM': torch.cuda.current_stream().cuda_stream, 199 | 'DEVICE_INDEX': out.device.index 200 | } 201 | 202 | # Generate, build and run the kernel 203 | code = FP8GemmRuntime.generate(kwargs) 204 | runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) 205 | runtime(**kwargs) 206 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/runtime.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | import enum 4 | import torch 5 | import cuda.bindings.driver as cbd 6 | from typing import Any, Dict, Tuple 7 | 8 | from .utils import get_tma_aligned_size 9 | from ..jit.runtime import Runtime 10 | 11 | 12 | class GemmType(enum.Enum): 13 | Normal = 0 14 | GroupedContiguous = 1 15 | GroupedMasked = 2 16 | 17 | def __str__(self) -> str: 18 | return { 19 | 0: 'Normal', 20 | 1: 'GroupedContiguous', 21 | 2: 'GroupedMasked', 22 | }[self.value] 23 | 24 | 25 | tmap_type_map: Dict[Any, str] = { 26 | torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 27 | torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, 28 | torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, 29 | torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, 30 | torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 31 | torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, 32 | torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, 33 | torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, 34 | torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 35 | torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 36 | torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 37 | torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 38 | torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 39 | torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 40 | torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, 41 | } 42 | 43 | swizzle_type_map = { 44 | 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, 45 | 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, 46 | 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, 47 | 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, 48 | } 49 | 50 | 51 | def get_num_math_warpgroups(block_m: int) -> int: 52 | return 1 if block_m == 64 else 2 53 | 54 | 55 | def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: 56 | assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' 57 | return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads 58 | 59 | 60 | def make_2d_tma_copy_desc(t: torch.Tensor, 61 | gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t, 62 | smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], 63 | swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: 64 | tensor_dtype = tmap_type_map[t.dtype] 65 | res, tensor_map = cbd.cuTensorMapEncodeTiled( 66 | tensor_dtype, 67 | 2, 68 | t.data_ptr(), 69 | gmem_dims, 70 | (gmem_outer_stride,), 71 | smem_dims, 72 | (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), 73 | cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, 74 | swizzle_type, 75 | cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, 76 | cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, 77 | ) 78 | 79 | if res != cbd.CUresult.CUDA_SUCCESS: 80 | raise Exception(f'Failed to encode tensor map: {res}') 81 | return tensor_map 82 | 83 | 84 | def make_2d_tma_desc(t: torch.Tensor, 85 | gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int, 86 | smem_inner_dim: int, smem_outer_dim: int, 87 | swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: 88 | gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim)) 89 | smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim)) 90 | return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type) 91 | 92 | 93 | def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor, 94 | shape_m: int, shape_k: int, m_stride: int, 95 | block_m: int, block_k: int, 96 | num_groups: int) -> cbd.CUtensorMap: 97 | return make_2d_tma_desc(t, 98 | shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, 99 | block_k, block_m) 100 | 101 | 102 | def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor, 103 | shape_n: int, shape_k: int, n_stride: int, 104 | block_n: int, block_k: int, 105 | num_groups: int) -> cbd.CUtensorMap: 106 | return make_2d_tma_desc(t, 107 | shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride, 108 | block_k, block_n) 109 | 110 | 111 | def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor, 112 | shape_m: int, shape_n: int, m_stride: int, 113 | block_m: int, block_n: int, 114 | num_groups: int, 115 | swizzle_mode: int) -> cbd.CUtensorMap: 116 | # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` 117 | # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required 118 | return make_2d_tma_desc(t, 119 | shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, 120 | block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m, 121 | swizzle_type_map[swizzle_mode]) 122 | 123 | 124 | def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, 125 | shape_mn: int, shape_k: int, 126 | block_mn: int, block_k: int, 127 | num_groups: int) -> cbd.CUtensorMap: 128 | # Make TMA aligned to 16 bytes 129 | shape_mn = get_tma_aligned_size(shape_mn, t.element_size()) 130 | return make_2d_tma_desc(t, 131 | shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn, 132 | block_mn, 1, 133 | cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) 134 | 135 | 136 | class FP8GemmRuntime(Runtime): 137 | def __init__(self, path: str) -> None: 138 | super().__init__(path) 139 | 140 | @staticmethod 141 | def generate(kwargs: Dict[str, Any]) -> str: 142 | code = f''' 143 | #ifdef __CUDACC_RTC__ 144 | #include 145 | #else 146 | #include 147 | #include 148 | #endif 149 | 150 | #include 151 | #include 152 | 153 | #include 154 | 155 | using namespace deep_gemm; 156 | 157 | static void __instantiate_kernel() {{ 158 | auto ptr = reinterpret_cast(&fp8_gemm_kernel< 159 | {kwargs['N']}, 160 | {kwargs['K']}, 161 | {kwargs['BLOCK_M']}, 162 | {kwargs['BLOCK_N']}, 163 | {kwargs['BLOCK_K']}, 164 | {kwargs['BLOCK_N_PADDING']}, 165 | {kwargs['SWIZZLE_D_MODE']}, 166 | {kwargs['NUM_GROUPS']}, 167 | {kwargs['NUM_STAGES']}, 168 | {kwargs['NUM_TMA_THREADS']}, 169 | {kwargs['NUM_MATH_THREADS_PER_GROUP']}, 170 | {kwargs['NUM_TMA_MULTICAST']}, 171 | {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, 172 | GemmType::{kwargs['GEMM_TYPE']} 173 | >); 174 | }}; 175 | ''' 176 | if int(os.getenv('DG_JIT_DEBUG', 0)): 177 | print(f'Generated FP8 GEMM code:\n{code}') 178 | return code 179 | 180 | # noinspection PyMethodOverriding 181 | @staticmethod 182 | def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: 183 | num_tma_threads = 128 184 | num_math_threads_per_group = 128 185 | 186 | result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 187 | kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] 188 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' 189 | 190 | attr_val = cbd.CUlaunchAttributeValue() 191 | attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] 192 | attr_val.clusterDim.y = 1 193 | attr_val.clusterDim.z = 1 194 | attr = cbd.CUlaunchAttribute() 195 | attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION 196 | attr.value = attr_val 197 | 198 | config = cbd.CUlaunchConfig() 199 | config.numAttrs = 1 200 | config.attrs = [attr] 201 | config.gridDimX = kwargs['NUM_SMS'] 202 | config.gridDimY = 1 203 | config.gridDimZ = 1 204 | config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) 205 | config.blockDimY = 1 206 | config.blockDimZ = 1 207 | config.sharedMemBytes = kwargs['SMEM_SIZE'] 208 | config.hStream = kwargs['STREAM'] 209 | 210 | arg_values = ( 211 | kwargs['SCALES_B'].data_ptr(), 212 | kwargs['GROUPED_LAYOUT'].data_ptr(), 213 | kwargs['M'], 214 | kwargs['TENSOR_MAP_A'], 215 | kwargs['TENSOR_MAP_B'], 216 | kwargs['TENSOR_MAP_SCALES_A'], 217 | kwargs['TENSOR_MAP_D'], 218 | ) 219 | arg_types = ( 220 | ctypes.c_void_p, 221 | ctypes.c_void_p, 222 | ctypes.c_uint32, 223 | None, 224 | None, 225 | None, 226 | None, 227 | ) 228 | return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) 229 | 230 | 231 | class FP8WGradGemmRuntime(Runtime): 232 | def __init__(self, path: str) -> None: 233 | super().__init__(path) 234 | 235 | @staticmethod 236 | def generate(kwargs: Dict[str, Any]) -> str: 237 | code = f''' 238 | #ifdef __CUDACC_RTC__ 239 | #include 240 | #else 241 | #include 242 | #include 243 | #endif 244 | 245 | #include 246 | #include 247 | 248 | #include 249 | 250 | using namespace deep_gemm; 251 | 252 | static void __instantiate_kernel() {{ 253 | auto ptr = reinterpret_cast(&fp8_wgrad_gemm_kernel< 254 | {kwargs['M']}, 255 | {kwargs['N']}, 256 | {kwargs['BLOCK_M']}, 257 | {kwargs['BLOCK_N']}, 258 | {kwargs['BLOCK_K']}, 259 | {kwargs['NUM_STAGES']}, 260 | {kwargs['NUM_LAST_STAGES']}, 261 | {kwargs['NUM_TMA_THREADS']}, 262 | {kwargs['NUM_MATH_THREADS_PER_GROUP']}, 263 | {kwargs['NUM_TMA_MULTICAST']}, 264 | {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'} 265 | >); 266 | }}; 267 | ''' 268 | if int(os.getenv('DG_JIT_DEBUG', 0)): 269 | print(f'Generated FP8 WGrad GEMM code:\n{code}') 270 | return code 271 | 272 | # noinspection PyMethodOverriding 273 | @staticmethod 274 | def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: 275 | num_tma_threads = 128 276 | num_math_threads_per_group = 128 277 | 278 | result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 279 | kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] 280 | assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' 281 | 282 | attr_val = cbd.CUlaunchAttributeValue() 283 | attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] 284 | attr_val.clusterDim.y = 1 285 | attr_val.clusterDim.z = 1 286 | attr = cbd.CUlaunchAttribute() 287 | attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION 288 | attr.value = attr_val 289 | 290 | config = cbd.CUlaunchConfig() 291 | config.numAttrs = 1 292 | config.attrs = [attr] 293 | config.gridDimX = kwargs['NUM_SMS'] 294 | config.gridDimY = 1 295 | config.gridDimZ = 1 296 | config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) 297 | config.blockDimY = 1 298 | config.blockDimZ = 1 299 | config.sharedMemBytes = kwargs['SMEM_SIZE'] 300 | config.hStream = kwargs['STREAM'] 301 | 302 | arg_values = ( 303 | kwargs['K'], 304 | kwargs['TENSOR_MAP_A'], 305 | kwargs['TENSOR_MAP_B'], 306 | kwargs['TENSOR_MAP_SCALES_A'], 307 | kwargs['TENSOR_MAP_SCALES_B'], 308 | kwargs['TENSOR_MAP_D'], 309 | ) 310 | arg_types = ( 311 | ctypes.c_uint32, 312 | None, 313 | None, 314 | None, 315 | None, 316 | None, 317 | ) 318 | return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) 319 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | _num_sms = None 4 | 5 | 6 | def set_num_sms(num_sms: int) -> None: 7 | """ 8 | Set the maximum SM count for all GEMM kernels to use. 9 | 10 | Arguments: 11 | num_sms: the desired maximum SM count for all GEMM kernels to use. 12 | """ 13 | global _num_sms 14 | assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count 15 | _num_sms = num_sms 16 | 17 | 18 | def get_num_sms() -> int: 19 | """ 20 | Get the current maximum limit of SM count for all GEMM kernels to use. 21 | If the count is never specified, the function will return the number of device SMs. 22 | 23 | Returns: 24 | Current maximum limit of SM count for all GEMM kernels to use. 25 | """ 26 | global _num_sms 27 | if _num_sms is None: 28 | _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count 29 | return _num_sms 30 | 31 | 32 | def ceil_div(x: int, y: int) -> int: 33 | """ 34 | Perform ceiling division of two integers. 35 | 36 | Args: 37 | x: the dividend. 38 | y: the divisor. 39 | 40 | Returns: 41 | The result of the ceiling division. 42 | """ 43 | return (x + y - 1) // y 44 | 45 | 46 | def get_m_alignment_for_contiguous_layout(): 47 | """ 48 | When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. 49 | Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well 50 | with GEMM block shape. 51 | 52 | Returns: 53 | Group-level alignment requirement for grouped contiguous layout, which is always 128. 54 | """ 55 | return 128 56 | 57 | 58 | def get_tma_aligned_size(x: int, element_size: int) -> int: 59 | """ 60 | Global memory address of TMA must be 16-byte aligned. 61 | Since we use column-major layout for the LHS scaling tensor, 62 | the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. 63 | 64 | Arguments: 65 | x: original M-axis shape of the LHS scaling tensor. 66 | element_size: element size of the LHS scaling tensor. 67 | 68 | Returns: 69 | M-axis shape of the LHS scaling tensor after padding. 70 | """ 71 | tma_alignment_bytes = 16 72 | assert tma_alignment_bytes % element_size == 0 73 | alignment = tma_alignment_bytes // element_size 74 | return ceil_div(x, alignment) * alignment 75 | 76 | 77 | def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: 78 | """ 79 | Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. 80 | If the input tensor is already column-major layout and 16-byte aligned along the M axis 81 | (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. 82 | 83 | Arguments: 84 | x: usually the LHS scaling tensor in GEMM. 85 | 86 | Returns: 87 | The LHS scaling tensor of TMA-aligned transposed format. 88 | """ 89 | # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA 90 | assert x.dim() in (2, 3) 91 | remove_dim = False 92 | m, n = x.shape[-2], x.shape[-1] 93 | aligned_m = get_tma_aligned_size(m, x.element_size()) 94 | if x.dim() == 2: 95 | if x.stride(0) == 1 and x.stride(1) == aligned_m: 96 | return x 97 | x, remove_dim = x.unsqueeze(0), True 98 | 99 | b = x.shape[0] 100 | 101 | # The last kernel gives a column-major TMA aligned layout 102 | if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: 103 | return x.squeeze(0) if remove_dim else x 104 | 105 | # Normal layout requires transposing 106 | aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) 107 | aligned_x[:, :m, :] = x 108 | aligned_x = aligned_x[:, :m, :] 109 | return aligned_x.squeeze(0) if remove_dim else aligned_x 110 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/wgrad_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple 3 | 4 | from ..jit import build 5 | from .runtime import ( 6 | FP8WGradGemmRuntime, GemmType, 7 | make_2d_tma_a_desc, make_2d_tma_b_desc, 8 | make_2d_tma_d_desc, make_2d_tma_scales_desc) 9 | from .gemm import get_best_configs 10 | from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size 11 | 12 | 13 | def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 14 | rhs: Tuple[torch.Tensor, torch.Tensor], 15 | out: torch.Tensor): 16 | """ 17 | Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. 18 | Results will be accumulated into the output tensor. 19 | 20 | Requirements: 21 | LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. 22 | The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4. 23 | RHS and RHS scaling factors are required to be transposed. 24 | The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format. 25 | If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. 26 | 27 | Arguments: 28 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, 29 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. 30 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, 31 | the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`. 32 | out: the FP32 output tensor of shape `[m, n]`, which will be accumulated. 33 | """ 34 | lhs, lhs_scales = lhs 35 | rhs, rhs_scales = rhs 36 | m, k = lhs.shape 37 | n, k_ = rhs.shape 38 | m_, n_ = out.shape 39 | 40 | # Type and shape checks 41 | assert m == m_ and n == n_ and k == k_ 42 | assert n > 0 and m > 0 43 | assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m) 44 | assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n) 45 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 46 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 47 | assert out.dtype == torch.float 48 | assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 49 | 50 | # LHS and RHS scales must be transposed for TMA load 51 | # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels 52 | def get_valid_scales(scales: torch.Tensor, mn: int): 53 | if scales.shape == (ceil_div(k, 128), mn): 54 | # For k-grouped GEMMs 55 | scales = scales.permute(1, 0) 56 | assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn 57 | else: 58 | scales = get_col_major_tma_aligned_tensor(scales) 59 | return scales 60 | 61 | lhs_scales = get_valid_scales(lhs_scales, m) 62 | rhs_scales = get_valid_scales(rhs_scales, n) 63 | 64 | # Do nothing if `k` is zero 65 | if k == 0: 66 | return 67 | 68 | # K must be aligned to 128 69 | aligned_k = ceil_div(k, 128) * 128 70 | 71 | # Auto-tuning with compilation 72 | num_sms = get_num_sms() 73 | num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( 74 | m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) 75 | num_last_stages = ceil_div(k, 128) % num_stages 76 | block_k = 128 77 | num_tma_threads = 128 78 | num_math_threads_per_group = 128 79 | 80 | tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) 81 | tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) 82 | tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) 83 | tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) 84 | tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1) 85 | 86 | kwargs = { 87 | # Templated arguments 88 | 'GEMM_TYPE': GemmType.Normal, 89 | 'NUM_TMA_THREADS': num_tma_threads, 90 | 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 91 | 'M': m, 'N': n, 'K': aligned_k, 92 | 'NUM_GROUPS': 1, 93 | 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 94 | 'NUM_STAGES': num_stages, 95 | 'NUM_LAST_STAGES': num_last_stages, 96 | 'NUM_TMA_MULTICAST': tma_multicast_config[0], 97 | 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 98 | # Runtime arguments 99 | 'NUM_SMS': num_sms, 100 | 'SMEM_SIZE': smem_config[0], 101 | 'TENSOR_MAP_A': tensor_map_a, 102 | 'TENSOR_MAP_B': tensor_map_b, 103 | 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 104 | 'TENSOR_MAP_SCALES_B': tensor_map_scales_b, 105 | 'TENSOR_MAP_D': tensor_map_d, 106 | 'STREAM': torch.cuda.current_stream().cuda_stream, 107 | 'DEVICE_INDEX': out.device.index 108 | } 109 | 110 | # Generate, build and run the kernel 111 | code = FP8WGradGemmRuntime.generate(kwargs) 112 | runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) 113 | runtime(**kwargs) 114 | 115 | 116 | def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 117 | rhs: Tuple[torch.Tensor, torch.Tensor], 118 | out: torch.Tensor, 119 | batch_sizes: List[int]): 120 | """ 121 | Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. 122 | Results will be accumulated into the output tensor. 123 | 124 | Requirements: 125 | This function handles multiple batches with varying k-dimensions, processing each batch sequentially. 126 | Each batch's LHS, RHS, and output tensors must be contiguous. 127 | The RHS and RHS scaling factors are required to be transposed. 128 | The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format. 129 | 130 | Arguments: 131 | lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, 132 | and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. 133 | The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, 134 | representing the per-128-channel scaling factors. 135 | rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, 136 | and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. 137 | The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, 138 | representing the per-128-channel scaling factors. 139 | out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated. 140 | batch_sizes: A list of integers specifying the k-dimension for each batch. 141 | """ 142 | lhs, lhs_scales = lhs[0].view(-1), lhs[1] 143 | rhs, rhs_scales = rhs[0].view(-1), rhs[1] 144 | num_batches, m, n = out.shape 145 | 146 | lhs_offset, rhs_offset, scales_offset = 0, 0, 0 147 | 148 | for i in range(num_batches): 149 | k = batch_sizes[i] 150 | lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) 151 | rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) 152 | lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] 153 | rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] 154 | wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) 155 | 156 | lhs_offset += m * k 157 | rhs_offset += n * k 158 | scales_offset += ceil_div(k, 128) 159 | -------------------------------------------------------------------------------- /deep_gemm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def bench(fn, num_warmups: int = 5, num_tests: int = 10, 9 | high_precision: bool = False): 10 | # Flush L2 cache with 256 MB data 11 | torch.cuda.synchronize() 12 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') 13 | cache.zero_() 14 | 15 | # Warmup 16 | for _ in range(num_warmups): 17 | fn() 18 | 19 | # Add a large kernel to eliminate the CPU launch overhead 20 | if high_precision: 21 | x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 22 | y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 23 | x @ y 24 | 25 | # Testing 26 | start_event = torch.cuda.Event(enable_timing=True) 27 | end_event = torch.cuda.Event(enable_timing=True) 28 | start_event.record() 29 | for i in range(num_tests): 30 | fn() 31 | end_event.record() 32 | torch.cuda.synchronize() 33 | 34 | return start_event.elapsed_time(end_event) / num_tests 35 | 36 | 37 | class empty_suppress: 38 | def __enter__(self): 39 | return self 40 | 41 | def __exit__(self, *_): 42 | pass 43 | 44 | 45 | class suppress_stdout_stderr: 46 | def __enter__(self): 47 | self.outnull_file = open(os.devnull, 'w') 48 | self.errnull_file = open(os.devnull, 'w') 49 | 50 | self.old_stdout_fileno_undup = sys.stdout.fileno() 51 | self.old_stderr_fileno_undup = sys.stderr.fileno() 52 | 53 | self.old_stdout_fileno = os.dup(sys.stdout.fileno()) 54 | self.old_stderr_fileno = os.dup(sys.stderr.fileno()) 55 | 56 | self.old_stdout = sys.stdout 57 | self.old_stderr = sys.stderr 58 | 59 | os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) 60 | os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) 61 | 62 | sys.stdout = self.outnull_file 63 | sys.stderr = self.errnull_file 64 | return self 65 | 66 | def __exit__(self, *_): 67 | sys.stdout = self.old_stdout 68 | sys.stderr = self.old_stderr 69 | 70 | os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) 71 | os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 72 | 73 | os.close(self.old_stdout_fileno) 74 | os.close(self.old_stderr_fileno) 75 | 76 | self.outnull_file.close() 77 | self.errnull_file.close() 78 | 79 | 80 | def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, 81 | trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, 82 | with_multiple_kernels: bool = False): 83 | # Conflict with Nsight Systems 84 | using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) 85 | 86 | # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle 87 | flush_l2_size = int(8e9 // 4) 88 | 89 | # For some auto-tuning kernels with prints 90 | fn() 91 | 92 | # Profile 93 | suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress 94 | with suppress(): 95 | schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None 96 | profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() 97 | with profiler: 98 | for i in range(2): 99 | # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead 100 | if barrier_comm_profiling: 101 | lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 102 | rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 103 | lhs @ rhs 104 | dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) 105 | for _ in range(num_tests): 106 | if flush_l2: 107 | torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() 108 | fn() 109 | 110 | if not using_nsys: 111 | profiler.step() 112 | 113 | # Return 1 if using Nsight Systems 114 | if using_nsys: 115 | return 1 116 | 117 | # Parse the profiling table 118 | assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) 119 | is_tupled = isinstance(kernel_names, tuple) 120 | prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') 121 | kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names 122 | assert all([isinstance(name, str) for name in kernel_names]) 123 | if not with_multiple_kernels: 124 | for name in kernel_names: 125 | assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' 126 | 127 | # Save chrome traces 128 | if trace_path is not None: 129 | profiler.export_chrome_trace(trace_path) 130 | 131 | # Return average kernel times 132 | units = {'ms': 1e3, 'us': 1e6} 133 | kernel_times = [] 134 | for name in kernel_names: 135 | total_time = 0 136 | total_num = 0 137 | for line in prof_lines: 138 | if name in line: 139 | time_str = line.split()[-2] 140 | num_str = line.split()[-1] 141 | for unit, scale in units.items(): 142 | if unit in time_str: 143 | total_time += float(time_str.replace(unit, '')) / scale * int(num_str) 144 | total_num += int(num_str) 145 | break 146 | kernel_times.append(total_time / total_num) 147 | 148 | return tuple(kernel_times) if is_tupled else kernel_times[0] 149 | 150 | 151 | def calc_diff(x, y): 152 | x, y = x.double(), y.double() 153 | denominator = (x * x + y * y).sum() 154 | sim = 2 * (x * y).sum() / denominator 155 | return 1 - sim 156 | 157 | 158 | def count_bytes(tensors): 159 | total = 0 160 | for t in tensors: 161 | if isinstance(t, tuple): 162 | total += count_bytes(t) 163 | else: 164 | total += t.numel() * t.element_size() 165 | return total 166 | -------------------------------------------------------------------------------- /figures/design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepGEMM/8dfa3298274bfe6b242f6f8a3e6f3eff2707dd9f/figures/design.png -------------------------------------------------------------------------------- /indexing/main.cu: -------------------------------------------------------------------------------- 1 | #include "deep_gemm/fp8_gemm.cuh" 2 | #include "deep_gemm/fp8_wgrad_gemm.cuh" 3 | 4 | using namespace deep_gemm; 5 | 6 | int main() { 7 | return 0; 8 | } 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | import shutil 4 | import subprocess 5 | from setuptools.command.build_py import build_py 6 | from setuptools.command.develop import develop 7 | 8 | current_dir = os.path.dirname(os.path.realpath(__file__)) 9 | jit_include_dirs = ('deep_gemm/include/deep_gemm', ) 10 | third_party_include_dirs = ( 11 | 'third-party/cutlass/include/cute', 12 | 'third-party/cutlass/include/cutlass', 13 | ) 14 | 15 | 16 | class PostDevelopCommand(develop): 17 | def run(self): 18 | develop.run(self) 19 | self.make_jit_include_symlinks() 20 | 21 | @staticmethod 22 | def make_jit_include_symlinks(): 23 | # Make symbolic links of third-party include directories 24 | for d in third_party_include_dirs: 25 | dirname = d.split('/')[-1] 26 | src_dir = f'{current_dir}/{d}' 27 | dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' 28 | assert os.path.exists(src_dir) 29 | if os.path.exists(dst_dir): 30 | assert os.path.islink(dst_dir) 31 | os.unlink(dst_dir) 32 | os.symlink(src_dir, dst_dir, target_is_directory=True) 33 | 34 | 35 | class CustomBuildPy(build_py): 36 | def run(self): 37 | # First, prepare the include directories 38 | self.prepare_includes() 39 | 40 | # Then run the regular build 41 | build_py.run(self) 42 | 43 | def prepare_includes(self): 44 | # Create temporary build directory instead of modifying package directory 45 | build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') 46 | os.makedirs(build_include_dir, exist_ok=True) 47 | 48 | # Copy third-party includes to the build directory 49 | for d in third_party_include_dirs: 50 | dirname = d.split('/')[-1] 51 | src_dir = os.path.join(current_dir, d) 52 | dst_dir = os.path.join(build_include_dir, dirname) 53 | 54 | # Remove existing directory if it exists 55 | if os.path.exists(dst_dir): 56 | shutil.rmtree(dst_dir) 57 | 58 | # Copy the directory 59 | shutil.copytree(src_dir, dst_dir) 60 | 61 | 62 | if __name__ == '__main__': 63 | # noinspection PyBroadException 64 | try: 65 | cmd = ['git', 'rev-parse', '--short', 'HEAD'] 66 | revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() 67 | except: 68 | revision = '' 69 | 70 | setuptools.setup( 71 | name='deep_gemm', 72 | version='1.0.0' + revision, 73 | packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], 74 | package_data={ 75 | 'deep_gemm': [ 76 | 'include/deep_gemm/*', 77 | 'include/cute/**/*', 78 | 'include/cutlass/**/*', 79 | ] 80 | }, 81 | cmdclass={ 82 | 'develop': PostDevelopCommand, 83 | 'build_py': CustomBuildPy, 84 | }, 85 | ) 86 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | # PyTorch has its own NVRTC, which may have a lower version than the system 2 | # So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch 3 | import cuda.bindings.nvrtc as nvrtc 4 | print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') 5 | 6 | import random 7 | import torch 8 | from typing import List, Tuple 9 | 10 | import deep_gemm 11 | from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor 12 | from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout 13 | 14 | 15 | def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 16 | assert x.dim() == 2 17 | m, n = x.shape 18 | pad_size = (128 - (n % 128)) % 128 19 | x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x 20 | x_view = x.view(m, -1, 128) 21 | x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) 22 | fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) 23 | return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) 24 | 25 | 26 | def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 27 | assert x.dim() == 2 28 | m, n = x.shape 29 | x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) 30 | x_padded[:m, :n] = x 31 | x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) 32 | x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) 33 | x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) 34 | return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) 35 | 36 | 37 | def construct(m: int, k: int, n: int) -> \ 38 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 39 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 40 | y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) 41 | out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) 42 | ref_out = x @ y.t() 43 | 44 | x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) 45 | # Transpose earlier so that the testing will not trigger transposing kernels 46 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 47 | return x_fp8, y_fp8, out, ref_out 48 | 49 | 50 | def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ 51 | Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: 52 | alignment = get_m_alignment_for_contiguous_layout() 53 | group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] 54 | m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) 55 | 56 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 57 | y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) 58 | m_indices = torch.empty(m, device='cuda', dtype=torch.int32) 59 | out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) 60 | ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) 61 | 62 | start = 0 63 | for i, group_m in enumerate(group_ms): 64 | actual_end = start + group_m 65 | aligned_end = start + ceil_div(group_m, alignment) * alignment 66 | m_indices[start:actual_end] = i 67 | m_indices[actual_end:aligned_end] = -1 68 | ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() 69 | start = aligned_end 70 | ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) 71 | 72 | assert m % 4 == 0, f'TMA alignment error: {m}' 73 | x_fp8 = per_token_cast_to_fp8(x) 74 | y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) 75 | for i in range(num_groups): 76 | y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) 77 | 78 | return m, x_fp8, y_fp8, m_indices, out, ref_out 79 | 80 | 81 | def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \ 82 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: 83 | x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) 84 | y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) 85 | out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) 86 | ref_out = torch.einsum('gmk,gnk->gmn', x, y) 87 | 88 | assert max_m % 4 == 0, f'TMA alignment error: {max_m}' 89 | x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float)) 90 | y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) 91 | for i in range(num_groups): 92 | x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) 93 | y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) 94 | 95 | # Transpose earlier so that the testing will not trigger transposing kernels 96 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 97 | 98 | # Construct mask 99 | masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) 100 | for j in range(num_groups): 101 | masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) 102 | assert masked_m.amax().item() <= max_m 103 | return x_fp8, y_fp8, masked_m, out, ref_out 104 | 105 | 106 | def construct_wgrad(m: int, k: int, n: int) -> \ 107 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: 108 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 109 | y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) 110 | residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10 111 | out = residual.clone() 112 | ref_out = residual + (x.float() @ y.float().t()) 113 | 114 | x_fp8 = per_token_cast_to_fp8(x) 115 | y_fp8 = per_token_cast_to_fp8(y) 116 | 117 | # NOTES: please do inplace add on the `out` later 118 | return x_fp8, y_fp8, residual, out, ref_out 119 | 120 | 121 | def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ 122 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]: 123 | num_groups, total_k = len(k_sizes), sum(k_sizes) 124 | 125 | x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16) 126 | y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16) 127 | out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) 128 | ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) 129 | 130 | # Fill tensors with data and compute reference output 131 | x_offset, y_offset = 0, 0 132 | for idx, k in enumerate(k_sizes): 133 | x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 134 | y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) 135 | 136 | x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten()) 137 | y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten()) 138 | ref_out[idx] = x_chunk.float() @ y_chunk.float().t() 139 | 140 | x_offset += m * k 141 | y_offset += n * k 142 | 143 | x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) 144 | y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) 145 | 146 | total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) 147 | x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) 148 | y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) 149 | 150 | # Cast to FP8 and prepare scale factors 151 | x_offset, y_offset, scale_offset = 0, 0, 0 152 | for k in k_sizes: 153 | x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k)) 154 | y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k)) 155 | 156 | x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) 157 | y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) 158 | 159 | num_scales = ceil_div(k, 128) 160 | x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) 161 | y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) 162 | 163 | x_offset += m * k 164 | y_offset += n * k 165 | scale_offset += num_scales 166 | 167 | return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes 168 | 169 | 170 | def test_gemm() -> None: 171 | print('Testing GEMM:') 172 | for m in (64, 128, 4096): 173 | for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: 174 | x_fp8, y_fp8, out, ref_out = construct(m, k, n) 175 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 176 | diff = calc_diff(out, ref_out) 177 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' 178 | 179 | # noinspection PyShadowingNames 180 | def test_func(): 181 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 182 | 183 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 184 | print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' 185 | f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' 186 | f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') 187 | print() 188 | 189 | 190 | def test_m_grouped_gemm_contiguous() -> None: 191 | print('Testing grouped contiguous GEMM:') 192 | 193 | for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), 194 | (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), 195 | (32, 256, 7168, 4096), (32, 256, 2048, 7168)): 196 | # NOTES: we should mask the unfilled part before calculating difference 197 | m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) 198 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 199 | out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) 200 | diff = calc_diff(out, ref_out) 201 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' 202 | 203 | # noinspection PyShadowingNames 204 | def test_func(): 205 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 206 | 207 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 208 | valid_m = (m_indices != -1).sum().item() 209 | print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 210 | f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' 211 | f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') 212 | print() 213 | 214 | 215 | def test_m_grouped_gemm_masked() -> None: 216 | print('Testing grouped masked GEMM:') 217 | 218 | for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)): 219 | for k, n in ((7168, 4096), (2048, 7168), ): 220 | # Test correctness 221 | for i in range(10): 222 | x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n) 223 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) 224 | for j in range(num_groups): 225 | diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) 226 | assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' 227 | 228 | # noinspection PyShadowingNames 229 | def test_func(): 230 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) 231 | 232 | # Test performance with fixed shapes 233 | # noinspection PyUnboundLocalVariable 234 | valid_m = masked_m.sum().item() 235 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 236 | print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 237 | f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' 238 | f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') 239 | print() 240 | 241 | 242 | def test_wgrad_gemm(): 243 | print('Testing weight gradient GEMM:') 244 | 245 | for k in (4096, 8192): 246 | for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)): 247 | # Test correctness 248 | x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) 249 | deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) 250 | diff = calc_diff(out, ref_out) 251 | assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' 252 | 253 | # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) 254 | x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) 255 | 256 | # noinspection PyShadowingNames 257 | def test_func(): 258 | deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) 259 | 260 | t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True) 261 | print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' 262 | f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' 263 | f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') 264 | print() 265 | 266 | 267 | def test_k_grouped_wgrad_gemm(): 268 | print('Testing grouped weight gradient GEMM:') 269 | 270 | for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)): 271 | for m, n in ((7168, 4096), (2048, 7168)): 272 | # Vary k sizes around base_k 273 | k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)] 274 | k_sizes.append(base_k * num_groups - sum(k_sizes)) 275 | 276 | # Test correctness 277 | x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) 278 | deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) 279 | 280 | for idx in range(num_groups): 281 | diff = calc_diff(out[idx], ref_out[idx]) 282 | assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}' 283 | 284 | # Construct new tensors to avoid L2 cache acceleration 285 | x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) 286 | total_k = sum(k_sizes) 287 | 288 | def test_func(): 289 | deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) 290 | 291 | t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups 292 | print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | ' 293 | f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, ' 294 | f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s') 295 | print() 296 | 297 | 298 | if __name__ == '__main__': 299 | torch.backends.cuda.matmul.allow_tf32 = True 300 | torch.backends.cudnn.allow_tf32 = True 301 | torch.manual_seed(0) 302 | random.seed(0) 303 | 304 | print('Library path:') 305 | print(f' > {deep_gemm.__path__}\n') 306 | 307 | test_gemm() 308 | test_m_grouped_gemm_contiguous() 309 | test_m_grouped_gemm_masked() 310 | 311 | test_wgrad_gemm() 312 | test_k_grouped_wgrad_gemm() 313 | -------------------------------------------------------------------------------- /tests/test_jit.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | import torch 4 | import cuda.bindings.driver as cbd 5 | from typing import Any, Dict 6 | 7 | from deep_gemm import jit 8 | 9 | # Essential debugging staffs 10 | os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') 11 | os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') 12 | 13 | 14 | class VectorAddRuntime(jit.Runtime): 15 | def __init__(self, path: str) -> None: 16 | super().__init__(path) 17 | 18 | @staticmethod 19 | def generate(kwargs: Dict[str, Any]) -> str: 20 | return f""" 21 | #ifdef __CUDACC_RTC__ 22 | #include 23 | #else 24 | #include 25 | #endif 26 | 27 | #include 28 | #include 29 | 30 | template 31 | __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ 32 | uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; 33 | if (i < n) {{ 34 | c[i] = a[i] + b[i]; 35 | }} 36 | }} 37 | 38 | static void __instantiate_kernel() {{ 39 | auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>); 40 | }} 41 | """ 42 | 43 | # noinspection PyShadowingNames,PyMethodOverriding 44 | @staticmethod 45 | def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: 46 | assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape 47 | assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device 48 | assert kwargs['A'].dim() == 1 49 | 50 | config = cbd.CUlaunchConfig() 51 | config.gridDimX = (kwargs['A'].numel() + 127) // 128 52 | config.gridDimY = 1 53 | config.gridDimZ = 1 54 | config.blockDimX = 128 55 | config.blockDimY = 1 56 | config.blockDimZ = 1 57 | config.hStream = kwargs['STREAM'] 58 | 59 | arg_values = ( 60 | kwargs['A'].data_ptr(), 61 | kwargs['B'].data_ptr(), 62 | kwargs['C'].data_ptr(), 63 | kwargs['A'].numel(), 64 | ) 65 | arg_types = ( 66 | ctypes.c_void_p, 67 | ctypes.c_void_p, 68 | ctypes.c_void_p, 69 | ctypes.c_uint32, 70 | ) 71 | 72 | return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0] 73 | 74 | 75 | if __name__ == '__main__': 76 | print('Generated code:') 77 | kwargs = {'T': 'float'} 78 | code = VectorAddRuntime.generate(kwargs) 79 | print(code) 80 | print() 81 | 82 | for compiler_name in ('NVCC', 'NVRTC'): 83 | # Get compiler 84 | compiler_cls = getattr(jit, f'{compiler_name}Compiler') 85 | print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}') 86 | 87 | # Build 88 | print('Building ...') 89 | func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs) 90 | 91 | # Run and check 92 | a = torch.randn((1024, ), dtype=torch.float32, device='cuda') 93 | b = torch.randn((1024, ), dtype=torch.float32, device='cuda') 94 | c = torch.empty_like(a) 95 | ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) 96 | assert ret == cbd.CUresult.CUDA_SUCCESS, ret 97 | torch.testing.assert_close(c, a + b) 98 | print(f'JIT test for {compiler_name} passed\n') 99 | --------------------------------------------------------------------------------