├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── deep_gemm ├── __init__.py ├── include │ ├── deep_gemm │ │ ├── fp8_gemm.cuh │ │ ├── mma_utils.cuh │ │ ├── reorder_b.cuh │ │ ├── scheduler.cuh │ │ ├── tma_utils.cuh │ │ └── utils.cuh │ └── l2_torch_alloc │ │ └── sideaware.cu ├── jit │ ├── __init__.py │ ├── compiler.py │ ├── interleave_ffma.py │ ├── runtime.py │ └── template.py ├── jit_kernels │ ├── __init__.py │ ├── gemm.py │ ├── m_grouped_gemm.py │ ├── preprocess.py │ ├── sideaware.py │ ├── tuner.py │ └── utils.py └── utils.py ├── figures └── design.png ├── setup.py └── tests ├── test_core.py └── test_jit.py /.gitignore: -------------------------------------------------------------------------------- 1 | cmake-build-* 2 | .idea 3 | .DS_Store 4 | build 5 | dist 6 | *.egg-info 7 | *.pyc 8 | 9 | # Third-party links created by `setup.py develop` 10 | deep_gemm/include/cute 11 | deep_gemm/include/cutlass 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/cutlass"] 2 | path = third-party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeeperGEMM 2 | 3 | New version, it's pretty cool (if you consider inline PTX cool). 4 | 5 | L2 Side Optimization for B Matrix is now based on cuda-side-boost: https://github.com/ademeure/cuda-side-boost 6 | 7 | Not planning to support this any further or provide in-depth documentation due to lack of time. There are also optimizations on the DeepGEMM repo that'd be hard to integrate - they've definitely closed the gap, but DeeperGEMM remains faster for many shapes. 8 | 9 | If you have a specific question, feel free to get in touch though! (unless it's about the inline PTX to load the B matrix because it was provided to me directly by god, or maybe the devil, I'm not sure). 
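For anyone just trying to call it: the Python entry points are re-exported from `deep_gemm/__init__.py`. Below is a rough sketch of how a call is meant to go. The names are real, but the exact signatures live under `deep_gemm/jit_kernels/` and follow upstream DeepGEMM's conventions, so treat the argument shapes and the `sideaware_init` / `preprocess_reorder_b` calls here as assumptions rather than documentation.

```python
# Sketch only: names come from deep_gemm/__init__.py; argument shapes and signatures are
# assumptions based on upstream DeepGEMM (per-128-channel FP8 scaling, NT layout).
import torch
import deep_gemm

m, k, n = 4096, 7168, 4096
a = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)            # activations, (m, k) row-major
a_scales = torch.rand(m, k // 128, device='cuda', dtype=torch.float32)
a_scales = deep_gemm.get_col_major_tma_aligned_tensor(a_scales)         # upstream wants TMA-aligned A scales
b = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)            # weights, (n, k) row-major
b_scales = torch.rand(n // 128, k // 128, device='cuda', dtype=torch.float32)
out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

deep_gemm.sideaware_init()                                  # assumed: one-time setup of the L2 side index
b, b_scales = deep_gemm.preprocess_reorder_b(b, b_scales)   # assumed signature: reorder B for L2 side awareness
deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scales), (b, b_scales), out)
```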
10 | 11 | # Previous Release Notes 12 | 13 | - L2 cache side awareness for B matrix (store tiles of B on the same side of the L2 as the SMs processing that part of the matrix). 14 | - Improved TMA load pipelining that doesn’t rely on unrolling, so there are no bubbles between tiles, and reuse memory for D. 15 | - Optimized tile output overlapped with 1st matmuls of next tile, and shared memory padding reducing bank conflicts by 8x. 16 | - Improved GEMM pipelining with dual accumulators so every warpgroup has a GEMM running nearly all the time (instead of relying on inter-warpgroup parallelism). 17 | - Optional support for 256x block size (instead of 128x) that halves the number of FMAs (probably okay for inference?) with highly optimized pipelining between GEMMs. 18 | - Lots of small optimzations adding up to a lot of performance 19 | 20 | UPDATE: I'm planning to release an updated & bugfixed version in a few days with a custom PyTorch memory allocator that will significantly reduce the overhead of "L2 side awareness" and provide a simple way to create L2 aware elementwise 1:1 kernels. 21 | 22 | I'll also write explanations of the existing optimisations, feel free to let me know if you have any thoughts or things that don't make sense! 23 | 24 | ## DeeperGEMM 25 | 26 | Deepseek’s DeepGEMM delivers great FP8 matrix multiplication performance with fine-grained scaling on NVIDIA Hopper GPUs using custom PTX, something that seemed nearly impossible previously given the lack of hardware support for micro-tensor scaling (added in Blackwell). 27 | 28 | It’s very fast… so let’s make it *even way faster*! 29 | 30 | ## Architecture 31 | 32 | See previous versions of README.md for a small amount of not very useful additional information. 33 | -------------------------------------------------------------------------------- /deep_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import jit 4 | from .jit_kernels import ( 5 | gemm_fp8_fp8_bf16_nt, 6 | preprocess_reorder_b, 7 | preprocess_reorder_b_grouped, 8 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 9 | m_grouped_gemm_fp8_fp8_bf16_nt_masked, 10 | ceil_div, 11 | set_num_sms, get_num_sms, 12 | get_col_major_tma_aligned_tensor, 13 | get_m_alignment_for_contiguous_layout, 14 | sideaware_init, sideaware_enabled, sideaware_compile, 15 | sideaware_torch_side_index, sideaware_gpu_side_index, sideaware_cpu_side_index, 16 | sideaware_info, sideaware_info_raw, 17 | ) 18 | from .utils import bench, bench_kineto, calc_diff 19 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/fp8_gemm.cuh: -------------------------------------------------------------------------------- 1 | 2 | 3 | #pragma clang diagnostic push 4 | #pragma clang diagnostic ignored "-Wunknown-attributes" 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "mma_utils.cuh" 15 | #include "scheduler.cuh" 16 | #include "tma_utils.cuh" 17 | #include "utils.cuh" 18 | 19 | namespace deep_gemm { 20 | 21 | // TODO - these settings shouldn't be here. 22 | constexpr bool DOUBLE_PUMP = true; // todo - figure out how we can make this *always* faster (not just usually so...) 
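// DOUBLE_PUMP: keep two WGMMA accumulators in flight per warpgroup (accum0/accum1 below) so each warpgroup
// nearly always has a GEMM running, instead of relying on inter-warpgroup parallelism (see release notes).
// DP_SCALE_256 (next line): assume each pair of 128-wide K blocks shares the same A/B scales, so promotion on
// the CUDA cores happens once per 256 K columns instead of once per 128, halving the promotion FMAs.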
23 | constexpr bool DP_SCALE_256 = true; // todo - assumes A/B scales are always the same for 2 blocks, need test data here 24 | 25 | constexpr int MAX_SM = 132; 26 | constexpr int PADDING_N = 16; // padding for D to avoid STSM bank conflicts (todo - clearer conditions etc.) 27 | constexpr int NUM_TILES_INITIAL = 32; // calxulate m/n for INITIAL tiles in parallel in prologue 28 | constexpr int NUM_TILES_STORAGE = 64; // 1 more every time we load B scales in a round-robin buffer 29 | 30 | // Register reconfigurations (24/208 is OK for most shapes but slower for 4096 x 24576 x 1536 on CUDA 12.8?) 31 | constexpr int kNumTMARegisters = 32; 32 | constexpr int kNumMathRegisters = 224; 33 | 34 | // Highly situational, when we flush L2 by writing new data (current default) rather than reading existing data, 35 | // we can get >10% for M=64/128, but that seems unrealistic (same reason why L2 opt might help less in production) 36 | #define POLICY_TINY_A_READ_B "createpolicy.fractional.L2::evict_first.L2::evict_unchanged.b64 policy, 1.0;\n" 37 | #define POLICY_TINY_A_WRITE_D "createpolicy.fractional.L2::evict_unchanged.L2::evict_unchanged.b64 policy, 1.0;\n" 38 | 39 | #define POLICY_BIG_A_READ_B "createpolicy.fractional.L2::evict_unchanged.L2::evict_unchanged.b64 policy, 1.0;\n" 40 | #define POLICY_BIG_A_WRITE_D "createpolicy.fractional.L2::evict_first.L2::evict_unchanged.b64 policy, 1.0;\n" 41 | 42 | enum class Layout { 43 | RowMajor, 44 | ColMajor 45 | }; 46 | 47 | template 48 | __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { 49 | DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); 50 | return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; 51 | } 52 | 53 | typedef struct { 54 | uint8_t sm_side_and_idx[MAX_SM]; 55 | } param_side_index_t; 56 | 57 | template 63 | __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) 64 | fp8_gemm_kernel(__nv_bfloat16* gmem_d, __nv_fp8_e4m3* gmem_b, float* scales_b, int* grouped_layout, int* zeroed_scratch, 65 | uint32_t shape_m, 66 | const __grid_constant__ CUtensorMap tensor_map_a, 67 | const __grid_constant__ CUtensorMap tensor_map_b, 68 | const __grid_constant__ CUtensorMap tensor_map_scales_a, 69 | const __grid_constant__ CUtensorMap tensor_map_d, 70 | const __grid_constant__ CUtensorMap tensor_map_d_padded, 71 | int block_idx_base, int aggregate_grid_size, 72 | __grid_constant__ const param_side_index_t sideaware) { 73 | #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) 74 | // 75 | constexpr bool L2_SIDE_OPTIMIZATION = l2_optimization && (kGemmType != GemmType::GroupedMasked); 76 | constexpr uint32_t SHAPE_N_HALF = SHAPE_N / 2; 77 | constexpr uint32_t CLUSTER_BLOCK_N = BLOCK_N * (kTMAMulticastEnabled ? 2 : 1); 78 | constexpr uint32_t SHAPE_N_LOWER = ((SHAPE_N_HALF + CLUSTER_BLOCK_N - 1) / CLUSTER_BLOCK_N) * CLUSTER_BLOCK_N; 79 | constexpr uint32_t SHAPE_N_UPPER = SHAPE_N - SHAPE_N_LOWER; 80 | constexpr uint32_t SHAPE_N_MAX = L2_SIDE_OPTIMIZATION ? SHAPE_N_LOWER : SHAPE_N; 81 | 82 | // Scaling checks 83 | DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); 84 | DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); 85 | 86 | // The compiler can optionally optimize everything if shape_m is known at compile time 87 | shape_m = FORCED_M ? 
FORCED_M : shape_m; 88 | // TODO: Re-enable hybrid cluster support 89 | //if constexpr (!kTMAMulticastEnabled) { 90 | //aggregate_grid_size = gridDim.x; 91 | block_idx_base = 0; 92 | //} 93 | 94 | // Types 95 | using WGMMA = typename FP8MMASelector::type; 96 | using Barrier = cutlass::arch::ClusterTransactionBarrier; 97 | 98 | constexpr uint32_t PADDING = (BLOCK_N == 64 || BLOCK_N == 96 || BLOCK_N == 128) ? PADDING_N : 0; 99 | constexpr uint32_t BLOCK_N_PADDED = BLOCK_N + PADDING; 100 | 101 | // Shared memory 102 | static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); 103 | static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); 104 | static constexpr uint32_t SMEM_D_SIZE_PADDED = BLOCK_M * BLOCK_N_PADDED * sizeof(__nv_bfloat16); 105 | static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); 106 | static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); 107 | static constexpr uint32_t SMEM_AB_SIZE_PER_STAGE_RAW = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; 108 | static constexpr uint32_t SMEM_AB_SIZE_PER_STAGE = SMEM_D_SIZE_PADDED > SMEM_AB_SIZE_PER_STAGE_RAW ? SMEM_D_SIZE_PADDED : SMEM_AB_SIZE_PER_STAGE_RAW; 109 | static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); 110 | static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); 111 | static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); 112 | 113 | // Configs 114 | constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; 115 | constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); 116 | constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; 117 | constexpr uint32_t kNumMathWarps = kNumMathThreads / 32; 118 | constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); 119 | 120 | // Align to 1024 bytes for swizzle-128B 121 | extern __shared__ __align__(1024) uint8_t smem_buffer[]; 122 | DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); 123 | 124 | // Fill shared memory *base* pointers 125 | // Everything is contiguous in memory, so we can do address calculations instead storing loads of pointers 126 | // this is needed for performance without fully unrolling the loops (otherwise we see LDL/STL for array indexing) 127 | // A and B are interleaved (A[0], B[0], A[1], B[1], ...) so they can be reused for D, everything else is contiguous 128 | __nv_fp8_e4m3* smem_a_base = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer); 129 | __nv_fp8_e4m3* smem_b_base = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_A_SIZE_PER_STAGE); 130 | float* smem_scales_a_base = reinterpret_cast(smem_buffer + kNumStages * (SMEM_AB_SIZE_PER_STAGE)); 131 | float* smem_scales_b_base = reinterpret_cast(smem_buffer + kNumStages * (SMEM_AB_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); 132 | constexpr int kNumScalesB = SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 
1 : 2); 133 | 134 | // Fill barriers (base pointers only) 135 | DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); 136 | DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers"); 137 | auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b_base) + (2*SMEM_SCALES_B_SIZE)); 138 | auto full_barriers_base = barrier_start_ptr; 139 | auto empty_barriers_base = barrier_start_ptr + kNumStages; 140 | auto full_barrier_scales_b_base = barrier_start_ptr + kNumStages * 2; 141 | auto empty_barrier_scales_b_base = barrier_start_ptr + kNumStages * 2 + 2; // double-buffered 142 | uint4* smem_tile_scheduling = reinterpret_cast(empty_barrier_scales_b_base + 2); 143 | 144 | const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); 145 | const uint32_t lane_idx = get_lane_id(); 146 | uint32_t cluster_size; // get cluster_nctaid 147 | asm volatile("mov.u32 %0, %cluster_nctaid.x;" : "=r"(cluster_size)); 148 | 149 | // Block scheduler 150 | uint32_t m_block_idx, n_block_idx; 151 | auto scheduler = Scheduler 152 | (shape_m, grouped_layout, block_idx_base + blockIdx.x); 153 | 154 | 155 | if constexpr (L2_SIDE_OPTIMIZATION) { 156 | int smid; 157 | asm("mov.u32 %0, %smid;\n" : "=r"(smid) :); 158 | int side = sideaware.sm_side_and_idx[smid] & 1; 159 | scheduler.block_idx = sideaware.sm_side_and_idx[smid] >> 1; 160 | scheduler.n_block_offset = !side * (SHAPE_N_LOWER / BLOCK_N); 161 | scheduler.grid_size /= 2; 162 | } 163 | 164 | // Pre-compute tile m/n for NUM_TILES_INITIAL tiles during warmup 165 | // Future tiles will be computed in the same thread that loads B scales (both are per-tile rather than per-block) 166 | if (threadIdx.x < NUM_TILES_INITIAL) { 167 | if (threadIdx.x*scheduler.grid_size < scheduler.num_blocks+scheduler.grid_size || kGemmType == GemmType::GroupedMasked) { 168 | scheduler.current_iter = threadIdx.x - 1; 169 | scheduler.get_next_block(m_block_idx, n_block_idx); 170 | smem_tile_scheduling[threadIdx.x].x = m_block_idx; 171 | smem_tile_scheduling[threadIdx.x].y = n_block_idx; 172 | if constexpr (kGemmType == GemmType::GroupedMasked) { 173 | smem_tile_scheduling[threadIdx.x].z = scheduler.curr_group_idx; 174 | smem_tile_scheduling[threadIdx.x].w = 0; 175 | } 176 | scheduler.current_iter = -1; 177 | } 178 | } 179 | 180 | // Helper lambda to fetch next tile from shared memory (precomputed way in advance) 181 | // this only works because we use static scheduling without work stealing, which has its own benefits... 182 | auto fetch_next_tile = [&](uint32_t& m_block_idx, uint32_t& n_block_idx) -> bool { 183 | scheduler.current_iter++; 184 | int idx = scheduler.current_iter % NUM_TILES_STORAGE; 185 | 186 | m_block_idx = smem_tile_scheduling[idx].x; 187 | n_block_idx = smem_tile_scheduling[idx].y; 188 | if constexpr (kGemmType == GemmType::GroupedMasked) { 189 | scheduler.curr_group_idx = smem_tile_scheduling[idx].z; 190 | } 191 | return (m_block_idx != 0xFFFFFFFF); 192 | }; 193 | 194 | // Initialize barriers (split over threads, it can be a lot with 10+ stages!) 195 | // Keep threads associated with 1st sub-processor (threads 0 to 31) free since they do other work above 196 | if (threadIdx.x == kNumMathThreads + 32) { 197 | #pragma unroll 198 | for (int i = 0; i < kNumStages; ++ i) { 199 | full_barriers_base[i].init(2); 200 | } 201 | cutlass::arch::fence_view_async_shared(); 202 | (kTMAMulticastEnabled) ? 
cutlass::arch::fence_barrier_init() : void(); 203 | } else if (threadIdx.x == kNumMathThreads + 64) { 204 | #pragma unroll 205 | for (int i = 0; i < kNumStages; ++ i) { 206 | empty_barriers_base[i].init(cluster_size * kNumMathThreads / 32); 207 | } 208 | cutlass::arch::fence_view_async_shared(); 209 | (kTMAMulticastEnabled) ? cutlass::arch::fence_barrier_init() : void(); 210 | } else if (threadIdx.x == kNumMathThreads + 96) { 211 | // Prefetch TMA descriptors at very beginning 212 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); 213 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); 214 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); 215 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); 216 | cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d_padded)); 217 | #pragma unroll 218 | for (int i = 0; i < 2; ++ i) { 219 | full_barrier_scales_b_base[i].init(1); 220 | empty_barrier_scales_b_base[i].init(kNumMathThreads / 32); 221 | } 222 | cutlass::arch::fence_view_async_shared(); 223 | (kTMAMulticastEnabled) ? cutlass::arch::fence_barrier_init() : void(); 224 | } 225 | 226 | // Synchronize all threads to make barrier visible in normal memory model (as late as possible) 227 | (kTMAMulticastEnabled) ? cute::cluster_sync() : __syncthreads(); 228 | 229 | // Updates pipeline stage & parity in as few SASS instructions as possible 230 | int s = 0, last_s = -1, parity = -1; // persistent context across loop iterations 231 | auto next_stage = [&]() { 232 | bool wrap = (s == kNumStages-1); 233 | last_s = s; 234 | s++; 235 | if (wrap) { 236 | s = 0; 237 | parity++; 238 | } 239 | }; 240 | 241 | if (threadIdx.x >= kNumMathThreads) { 242 | // TMA warp-groups for loading data - we split the threads into 243 | // 1) Calculating future tile m/n and loading B scales per tile (1 thread) 244 | // 2) Loading A data & scales (1 thread) 245 | // 3) Loading B data (multiple threads/warps, expensive due to L2 side awareness, optimized PTX) 246 | // 247 | // (3) previously supported multiple warps to parallelise the L2 side calculations... 248 | // but slower after other crazy optimizations so back to 1 warp (but multiple threads per warp) 249 | 250 | cutlass::arch::warpgroup_reg_dealloc(); 251 | parity = 1; // producer starts with parity=1 (no wait) 252 | 253 | if (warp_idx == kNumMathWarps) { 254 | // TODO - explain this code, or better yet, rewrite all of it! 255 | // TODO - add back "fast path" when everything is aligned, it was much faster *sigh* (or rewrite the whole thing!!!) 256 | constexpr int CHUNK_SIZE = (BLOCK_N % 16) ? 8 : ((BLOCK_N % 32) ? 16 : 32); // largest chunk size that can divide BLOCK_N 257 | constexpr int NUM_CHUNKS = (BLOCK_N + CHUNK_SIZE - 1) / CHUNK_SIZE; 258 | if constexpr (L2_SIDE_OPTIMIZATION && NUM_CHUNKS > 1) { 259 | if (lane_idx >= NUM_CHUNKS) { 260 | return; 261 | } 262 | } else { 263 | elect_or_exit(); 264 | } 265 | 266 | // Create a lambda to handle loading B data with different starting k_idx 267 | int loader_idx = 0; 268 | int loader_tid = threadIdx.x-kNumMathThreads; 269 | 270 | constexpr bool aligned_n = (SHAPE_N % BLOCK_N == 0 && SHAPE_N_LOWER % BLOCK_N == 0 && BLOCK_N % CHUNK_SIZE == 0); 271 | 272 | if constexpr (L2_SIDE_OPTIMIZATION) { 273 | int current_shape_n = (scheduler.n_block_offset) ? 
SHAPE_N : SHAPE_N_LOWER; 274 | int start_page_offset = reinterpret_cast(gmem_b) % (2048*1024); 275 | 276 | __nv_fp8_e4m3* b_page_start = gmem_b - start_page_offset; 277 | uintptr_t b_page_start_u64 = reinterpret_cast(b_page_start); 278 | uint32_t b_page_start_u32[2]; 279 | b_page_start_u32[0] = (uint32_t)(b_page_start_u64 & 0xFFFFFFFF); 280 | b_page_start_u32[1] = (uint32_t)(b_page_start_u64 >> 32L); 281 | 282 | smem_b_base += (NUM_CHUNKS > 1) ? lane_idx * (CHUNK_SIZE * BLOCK_K) : 0; 283 | 284 | int global_base_offset = start_page_offset; 285 | int lane_chunk_start = (NUM_CHUNKS > 1) ? (lane_idx * CHUNK_SIZE) : 0; 286 | 287 | // Persistently schedule over blocks to load B 288 | #pragma unroll 1 289 | while (fetch_next_tile(m_block_idx, n_block_idx)) { 290 | int n = n_block_idx * BLOCK_N; 291 | 292 | int remaining_n = current_shape_n - n; 293 | n += lane_chunk_start; 294 | int n_side = (n >= SHAPE_N_HALF) ? 1 : 0; 295 | int n_half = (n_side * (-SHAPE_N_HALF)) + n; 296 | int n_dst_base = n_half + (n_half & ~31); // shift everything after bit 5 to the left by 1 297 | uint32_t tile_base_offset = global_base_offset + (n_dst_base * 128); 298 | if constexpr (kGemmType == GemmType::GroupedContiguous) { 299 | int group_offset = __ldg(grouped_layout + m_block_idx * BLOCK_M); 300 | tile_base_offset += (SHAPE_N * SHAPE_K) * group_offset; 301 | } 302 | 303 | // Declare early so we can use a lambda to compile 2 optimized paths based on their values 304 | int num_bytes_total; 305 | int num_bytes; 306 | 307 | // ---------------------------------------------------------------------------------------- 308 | // Check sideaware_kernel.cuh for a simpler version of the side aware algorithm without PTX 309 | // ---------------------------------------------------------------------------------------- 310 | // Custom PTX implementation of side-aware memory copy for B matrix 311 | // optimized for efficient SASS at a random driver in time (12.8) 312 | // ... because why not? 313 | // ---------------------------------------------------------------------------------------- 314 | auto load_b_for_every_k = [&]() { 315 | #pragma unroll 1 316 | for (int k_idx = 0; k_idx < SHAPE_K_SCALES; k_idx++, tile_base_offset += BLOCK_K * SHAPE_N) { 317 | uint32_t smem_int_mbar = cute::cast_smem_ptr_to_uint(&full_barriers_base[s]); 318 | uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_b_base + s * SMEM_AB_SIZE_PER_STAGE); 319 | uint32_t smem_int_empty_mbar = cute::cast_smem_ptr_to_uint(&empty_barriers_base[s]); 320 | uint32_t gmem_address_u32[2]; 321 | 322 | // wait on mbarrier (can't remember if this ended up any better than the baseline barrier.wait) 323 | asm volatile( 324 | "{\n" 325 | " .reg .pred p;\n" 326 | " LAB_WAIT:\n" 327 | " mbarrier.try_wait.parity.shared::cta.b64 p, [%0], %1, 10000000;\n" 328 | " @p bra DONE;\n" 329 | " bra LAB_WAIT;\n" 330 | " DONE:\n" 331 | "}\n" 332 | : : "r"(smem_int_empty_mbar), "r"(parity) : "memory" 333 | ); 334 | 335 | // optimized PTX version of next_stage() 336 | // this was required because we had multiple warps interleaving iterations of the loop 337 | // not sure if this is actually faster than the baseline version with a single warp or not 338 | // but it's here now, so enjoy I guess? 
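// Plain C++ equivalent of the PTX below (same as the next_stage() lambda above, minus the last_s bookkeeping):
//   if (s >= kNumStages - 1) { s += 1 - kNumStages; parity++; } else { s++; }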
339 | asm volatile( 340 | "{\n" 341 | " .reg .pred p;\n" 342 | " setp.ge.s32 p, %0, %2;\n" 343 | " @p add.s32 %0, %0, %3;\n" 344 | " @p add.s32 %1, %1, 1;\n" 345 | " @!p add.s32 %0, %0, 1;\n" 346 | "}\n" 347 | : "+r"(s), "+r"(parity) 348 | : "n"(kNumStages - 1), "n"(1 - kNumStages) 349 | : "memory" 350 | ); 351 | 352 | if (num_bytes > 0) { 353 | // Check sideaware_kernel.cuh for a saner version of the side aware algorithm without PTX 354 | // 355 | // Determine the desired side based on the l2_hash_bits of the address (using popc) 356 | // then use this to adjust the offset as required (see sideaware_kernel.cuh) 357 | // the black magic part of this PTX is related to how we handle the unaligned case 358 | // more efficiently than in my other non-PTX version of the algorithm 359 | // it uses 'add.cc' and 'addc.u32' in a clever way to save a few instructions 360 | asm volatile( 361 | "{\n" 362 | " .reg .u32 lower_bits;\n" 363 | " .reg .u32 tmp;\n" 364 | " .reg .b64 address;\n" 365 | " add.cc.u32 lower_bits, %2, %4;\n" 366 | " and.b32 tmp, lower_bits, %5;\n" 367 | " xor.b32 tmp, tmp, %6;\n" 368 | " popc.b32 tmp, tmp;\n" 369 | " and.b32 tmp, tmp, 0x1;\n" 370 | " xor.b32 tmp, tmp, 0x1;\n" // invert due to bit 21 in sideaware.cu hash 371 | " mad.lo.u32 %0, tmp, 4096, lower_bits;\n" 372 | " addc.u32 %1, %3, 0;\n" 373 | : "=r"(gmem_address_u32[0])/*0*/, "=r"(gmem_address_u32[1])/*1*/ 374 | : "r"(b_page_start_u32[0]) /*2*/, "r"(b_page_start_u32[1]) /*3*/, 375 | "r"(tile_base_offset) /*4*/, "r"(l2_hash_bits) /*5*/, "r"(n_side) /*6*/ 376 | : "memory" 377 | ); 378 | 379 | if constexpr (NUM_CHUNKS == 1) { 380 | // Fast path with a single thread active per warp 381 | asm volatile( 382 | " mov.b64 address, {%0, %1};\n" 383 | " cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%2], [address], %3, [%4], policy;\n" 384 | " mbarrier.arrive.expect_tx.shared::cta.b64 _, [%4], %5; \n" 385 | "}\n" 386 | : : "r"(gmem_address_u32[0]), "r"(gmem_address_u32[1]), 387 | "r"(smem_int_ptr), "r"(num_bytes), "r"(smem_int_mbar), "r"(num_bytes_total) 388 | : "memory" 389 | ); 390 | } else { 391 | // Slow path with multiple threads where the compiler will create a loop for the TMA 392 | // the SASS isn't optimal but extremely difficult to improve further with 'just' PTX 393 | // (+ setp/@p to only mbarrier.arrive on a single thread) 394 | asm volatile( 395 | " .reg .pred p;\n" 396 | " setp.eq.u32 p, %6, 0;\n" 397 | " mov.b64 address, {%0, %1};\n" 398 | " cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%2], [address], %3, [%4], policy;\n" 399 | "@p mbarrier.arrive.expect_tx.shared::cta.b64 _, [%4], %5; \n" 400 | "}\n" 401 | : : "r"(gmem_address_u32[0]), "r"(gmem_address_u32[1]), 402 | "r"(smem_int_ptr), "r"(num_bytes), "r"(smem_int_mbar), "r"(num_bytes_total), 403 | "r"(lane_idx) 404 | : "memory" 405 | ); 406 | } 407 | } 408 | } 409 | }; 410 | 411 | // Weird perf issues I didn't understand creating the policy in the loop, so gave up and put it here 412 | // A sane person would use a variable rather than PTX scope via {}, but I never claimed to be sane. 
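// What the PTX below does: build the L2 cache-hint policy once for all B loads issued by this warp.
// shape_m == BLOCK_M (a single row of M tiles) selects POLICY_TINY_A_READ_B (evict_first for B),
// anything larger selects POLICY_BIG_A_READ_B (evict_unchanged); 'policy' stays live until the closing '}' scope below.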
413 | asm volatile("{\n .reg .b64 policy;\n" 414 | " .reg .pred p;\n" 415 | " setp.eq.u32 p, %0, %1;\n" 416 | " @p " POLICY_TINY_A_READ_B 417 | " @!p " POLICY_BIG_A_READ_B 418 | : : "r"(shape_m), "r"(BLOCK_M) : "memory"); 419 | 420 | if (!aligned_n && remaining_n < BLOCK_N) { 421 | // Slow path with dynamic num_byte 422 | int n_to_load_warp = max(0, min(remaining_n, BLOCK_N)); 423 | int n_to_load_lane = (n_to_load_warp - lane_chunk_start); 424 | 425 | num_bytes_total = n_to_load_warp * BLOCK_K; 426 | num_bytes = (n_to_load_lane > CHUNK_SIZE) ? (CHUNK_SIZE*BLOCK_K) : (n_to_load_lane*BLOCK_K); 427 | load_b_for_every_k(); 428 | } else { 429 | // Fast path where the compiler realises num_bytes is known at compile time 430 | num_bytes_total = BLOCK_N * BLOCK_K; 431 | num_bytes = CHUNK_SIZE * BLOCK_K; 432 | load_b_for_every_k(); 433 | } 434 | asm volatile("}\n" : : ); // end 'policy' variable scope 435 | } 436 | } else { 437 | // Legacy approach without L2 side optimization 438 | while (fetch_next_tile(m_block_idx, n_block_idx)) { 439 | for (int k_idx = 0; k_idx < SHAPE_K_SCALES; k_idx++) { 440 | auto& full_barrier = full_barriers_base[s]; 441 | uint64_t* full_barrier64 = reinterpret_cast(&full_barrier); 442 | empty_barriers_base[s].wait(parity); 443 | tma_copy(&tensor_map_b, full_barrier64, smem_b_base + s * SMEM_AB_SIZE_PER_STAGE, 444 | k_idx * BLOCK_K, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), cluster_size); 445 | full_barriers_base[s].arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); 446 | s++; 447 | if (s >= kNumStages) { 448 | s -= kNumStages; 449 | parity++; 450 | } 451 | } 452 | } 453 | } 454 | 455 | // To safely deconstruct distributed shared barriers, we need another round of empty waits 456 | if constexpr (kTMAMulticastEnabled) { 457 | if (loader_idx == 0 && lane_idx == 0) { 458 | for (int i = 0; i < kNumStages+1; i++) { 459 | empty_barriers_base[s].wait(parity); 460 | next_stage(); 461 | } 462 | } 463 | } 464 | } else if (threadIdx.x == kNumMathThreads + 64) { 465 | elect_or_exit(); // tell nvcc this is single-threaded (bad codegen otherwise on 12.8) 466 | // Persistently schedule over blocks to load A/A_scales 467 | while (fetch_next_tile(m_block_idx, n_block_idx)) { 468 | if (!kTMAMulticastEnabled || cute::block_rank_in_cluster() == 0) { 469 | #pragma unroll 2 // TODO: only if divisible by 2 470 | for (int k_idx = 0; k_idx < SHAPE_K_SCALES; k_idx++) { 471 | // Wait consumer release 472 | empty_barriers_base[s].wait(parity); 473 | // Issue TMA A with broadcasting 474 | auto& full_barrier = full_barriers_base[s]; 475 | uint64_t* full_barrier64 = reinterpret_cast(&full_barrier); 476 | tma_copy(&tensor_map_a, full_barrier64, smem_a_base + s * SMEM_AB_SIZE_PER_STAGE, 477 | k_idx * BLOCK_K, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), cluster_size); 478 | tma_copy(&tensor_map_scales_a, full_barrier64, smem_scales_a_base + s * BLOCK_M, 479 | m_block_idx * BLOCK_M, scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx), cluster_size); 480 | full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); 481 | next_stage(); 482 | } 483 | } else { 484 | // "block_rank_in_cluster() != 0" only need to release barriers at the right time, no TMAs needed 485 | #pragma unroll 1 486 | for (int k_idx = 0; k_idx < SHAPE_K_SCALES; k_idx++) { 487 | empty_barriers_base[s].wait(parity); 488 | full_barriers_base[s].arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); 489 | next_stage(); 490 | } 491 | } 492 | } 493 | // To 
safely deconstruct distributed shared barriers, we need another round of empty waits 494 | if constexpr (kTMAMulticastEnabled) { 495 | for (int i = 0; i < kNumStages + 1; i++) { 496 | empty_barriers_base[s].wait(parity); 497 | next_stage(); 498 | } 499 | } 500 | } else if (threadIdx.x == kNumMathThreads + 96) { 501 | // Load B Scales per-tile + future tile scheduling (many tiles in advance) 502 | elect_or_exit(); 503 | auto future_scheduler = scheduler; // deep copy of scheduler 504 | future_scheduler.current_iter = NUM_TILES_INITIAL - 1; 505 | 506 | // Load scales B via TMA Load per-tile rather than per-K_BLOCK 507 | // previously done with global memory loads, which was OK but forced synchronization between warpgroups 508 | // hardcoded to always be doubled-buffered (equivalent to s=2) which is more than enough (but s=1 isn't!) 509 | #pragma unroll 1 510 | while (fetch_next_tile(m_block_idx, n_block_idx)) { 511 | auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); 512 | auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; 513 | 514 | // Decide the number of scales B to load 515 | // this is inside the loop because it's index-dependent for N != 128 and non-uniform ScaleB 516 | DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); 517 | uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; 518 | if constexpr (not kMustUseUniformedScaleB) { 519 | num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; 520 | num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; 521 | } 522 | uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); 523 | 524 | // explicit idx/parity calculation since (same or fewer instructions than increment+wrap) 525 | int barrier_idx = scheduler.current_iter & 1; 526 | int scales_parity = (scheduler.current_iter & 2) ? 
0 : 1; // init=1 for producer (0 for consumer) 527 | empty_barrier_scales_b_base[barrier_idx].wait(scales_parity); 528 | 529 | auto& full_barrier = full_barrier_scales_b_base[barrier_idx]; 530 | cute::SM90_BULK_COPY_G2S::copy(local_scales_b, reinterpret_cast(&full_barrier), 531 | smem_scales_b_base + barrier_idx * kNumScalesB, num_scales_b * sizeof(float)); 532 | full_barrier.arrive_and_expect_tx(num_scales_b * sizeof(float)); 533 | 534 | ////// TODO - explain future tile scheduling, currently single threaded with implicit synchronization 535 | DG_STATIC_ASSERT(NUM_TILES_INITIAL > kNumStages+8, "NUM_TILES_INITIAL should be much greater than kNumStages"); 536 | uint32_t future_m_block_idx, future_n_block_idx; 537 | future_scheduler.get_next_block(future_m_block_idx, future_n_block_idx); 538 | 539 | int tile_smem_idx = future_scheduler.current_iter % NUM_TILES_STORAGE; 540 | smem_tile_scheduling[tile_smem_idx].x = future_m_block_idx; 541 | smem_tile_scheduling[tile_smem_idx].y = future_n_block_idx; 542 | if constexpr (kGemmType == GemmType::GroupedMasked) { 543 | smem_tile_scheduling[tile_smem_idx].z = future_scheduler.curr_group_idx; 544 | smem_tile_scheduling[tile_smem_idx].w = 0; 545 | } 546 | } 547 | } 548 | } else { 549 | // Math warp-groups for WGMMA 550 | cutlass::arch::warpgroup_reg_alloc(); 551 | parity = 0; // consumer starts with parity=0 (must wait) 552 | 553 | // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers 554 | const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); 555 | const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; 556 | 557 | // Accumulation for WGMMA or CUDA promotion 558 | // We use 2 temporary accumulators for WGMMAs and one final accumulator for CUDA promotion 559 | float accum0[WGMMA::kNumAccum], accum1[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; 560 | __nv_bfloat162 final_accum_bf16[WGMMA::kNumAccum / 2]; 561 | 562 | // The descriptors are basically always the same plus an offset, so we precompute them here 563 | uint64_t desc_a_base = make_smem_desc(smem_a_base + math_wg_idx * WGMMA::M * BLOCK_K, 1); 564 | uint64_t desc_b_base = make_smem_desc(smem_b_base, 1); 565 | 566 | // Empty barrier arrival 567 | auto empty_barrier_arrive = [&](Barrier* barrier) { 568 | if constexpr (!kTMAMulticastEnabled) { 569 | lane_idx == 0 ? barrier->arrive() : void(); 570 | } else { 571 | lane_idx < cluster_size ? barrier->arrive(lane_idx) : void(); 572 | } 573 | }; 574 | 575 | // Keep track of the previous tile's position to do its TMA store in the next loop iteration 576 | uint32_t old_global_idx; 577 | int old_n_block_idx = -1; 578 | 579 | //--------------------------------------------------------------------------------- 580 | // Lambda to store tile N-1 in iteration N and the final tile after the loop 581 | // final_accum (64 registers for N=128) ==> SMEM via STSM ==> global via TMA 582 | //--------------------------------------------------------------------------------- 583 | auto final_accum_to_bf16 = [&]() { 584 | for (int i = 0; i < WGMMA::kNumAccum / 2; ++ i) { 585 | final_accum_bf16[i] = __float22bfloat162_rn({final_accum[i*2+0], final_accum[i*2+1]}); 586 | } 587 | }; 588 | auto store_tile = [&] (int tile_s, int start=0, int end=WGMMA::kNumAccum, bool skip_to_bf16=false) { 589 | int current_shape_n = (scheduler.n_block_offset) ? SHAPE_N : SHAPE_N_LOWER; 590 | 591 | // Write final_accum to shared memory using STSM 592 | // Padded to avoid up to 8x(!) 
shared memory bank conflicts 593 | auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_a_base + tile_s * (SMEM_AB_SIZE_PER_STAGE)); 594 | bool partially_oob = (old_n_block_idx * BLOCK_N) > (SHAPE_N - BLOCK_N) && (SHAPE_N % BLOCK_N) > 0; 595 | uint32_t BLOCK_N_STORE = partially_oob ? BLOCK_N : BLOCK_N_PADDED; 596 | 597 | // Only process part of the tile at a time if possible 598 | if (start == 0 && !skip_to_bf16) final_accum_to_bf16(); 599 | 600 | // Write back to shared memory using STSM 601 | DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); 602 | #pragma unroll 603 | for (auto i = start / 8; i < end / 8; ++ i) { 604 | SM90_U32x4_STSM_N::copy( 605 | final_accum_bf16[i * 4 + 0], final_accum_bf16[i * 4 + 1], 606 | final_accum_bf16[i * 4 + 2], final_accum_bf16[i * 4 + 3], 607 | smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N_STORE + i * 16 + 8 * (lane_idx / 16) 608 | ); 609 | } 610 | 611 | // TMA store on final iteration 612 | if (end >= WGMMA::kNumAccum && start < WGMMA::kNumAccum) { 613 | if constexpr (WGMMA::kNumAccum % 8 != 0) { 614 | SM90_U32x2_STSM_N::copy( 615 | final_accum_bf16[WGMMA::kNumAccum / 8 * 4 + 0], final_accum_bf16[WGMMA::kNumAccum / 8 * 4 + 1], 616 | smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N_STORE + WGMMA::kNumAccum / 8 * 16 617 | ); 618 | } 619 | 620 | cute::tma_store_fence(); 621 | 622 | // sync per-warpgroup rather than per-threadgroup and issue the TMA per warpgroup 623 | // this prevents both warpgroups from being idle at the same time as much as possible 624 | asm volatile("bar.sync %0, 128;\n" :: "r"(math_wg_idx)); 625 | 626 | // Use TMA store to write back to global memory (per warpgroup rather than per threadgroup) 627 | // using threads which aren't part of a subprocessor that's active in the producer warpgroup 628 | if ((threadIdx.x == 96 || threadIdx.x == 224)) { 629 | uint64_t gmem_int_desc = reinterpret_cast(partially_oob ? &tensor_map_d : &tensor_map_d_padded); 630 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_d + (math_wg_idx * BLOCK_N_STORE * 64))); 631 | 632 | // A sane person wouldn't use PTX here, but we're well past that point by now 633 | asm volatile("{\n .reg .b64 policy;\n" 634 | " .reg .pred p;\n" 635 | " setp.eq.u32 p, %0, %1;\n" 636 | " @p " POLICY_TINY_A_WRITE_D 637 | " @!p " POLICY_BIG_A_WRITE_D 638 | " setp.eq.u32 p, %2, 1;\n" 639 | " @p cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.L2::cache_hint [%3, {%5, %6}], [%4], policy;\n" 640 | " @!p cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.L2::cache_hint [%3, {0, %7, %6}], [%4], policy;\n" 641 | "}\n" 642 | : : "r"(shape_m) /*0*/, "r"(BLOCK_M) /*1*/, "r"((uint32_t)partially_oob) /*2*/, 643 | "l"(gmem_int_desc) /*3*/, "r"(smem_int_ptr) /* 4 */, 644 | "r"(old_n_block_idx * BLOCK_N) /* 5 (2D only)*/, 645 | "r"(old_global_idx + (math_wg_idx * 64)) /* 6 (2D/3D)*/, 646 | "r"(old_n_block_idx) /* 7 (3D only)*/ 647 | : "memory"); 648 | 649 | cute::tma_store_arrive(); 650 | cute::tma_store_wait<0>(); 651 | } 652 | asm volatile("bar.sync %0, 128;\n" :: "r"(math_wg_idx)); 653 | } 654 | }; 655 | 656 | // Persistently schedule over blocks 657 | while (fetch_next_tile(m_block_idx, n_block_idx)) { 658 | // determine the B-scales address for this tile & barrier idx/parity for this tile 659 | int barrier_scales_b = (scheduler.current_iter & 1); 660 | int parity_scales_b = (scheduler.current_iter & 2) ? 
1 : 0; // init=0 for consumer 661 | float* smem_scales_b = smem_scales_b_base + barrier_scales_b * kNumScalesB; 662 | // Decide the number of scales B to use (varies when N != 128 and non-uniform ScaleB) 663 | uint32_t num_former_iters = BLOCK_N / 8; 664 | if constexpr (not kMustUseUniformedScaleB) { 665 | num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; 666 | } 667 | // wait for the TMA load to fill them 668 | full_barrier_scales_b_base[barrier_scales_b].wait(parity_scales_b); 669 | 670 | // persistent across calls (write in wgmma_prepare_scales, read in promote_with_scales) 671 | //float scale_0_0[2], scale_1_0[2], scale_0_1[2], scale_1_1[2]; 672 | float scale_a_0[2], scale_a_1[2], scale_b_0[2], scale_b_1[2]; 673 | 674 | // --------------------------------------------------------------------------------- 675 | // Lambda to execute WGMMA on tensor cores and prepare scales for promotion 676 | // --------------------------------------------------------------------------------- 677 | auto wgmma_prepare_scales = [&](int idx, float* scales_b, bool accum_same_scale=false) { 678 | // Precompute descriptors to reduce instructions in inner loop 679 | uint64_t desc_a = desc_a_base + (s * (SMEM_AB_SIZE_PER_STAGE >> 4)); 680 | uint64_t desc_b = desc_b_base + (s * (SMEM_AB_SIZE_PER_STAGE >> 4)); 681 | float* accum = idx ? accum1 : accum0; 682 | 683 | // Wait TMA arrivals 684 | full_barriers_base[s].wait(parity); 685 | 686 | // Commit WGMMA instructions 687 | for (int i = 0; i < WGMMA::kNumAccum; ++i) 688 | warpgroup_fence_operand(accum[i]); 689 | warpgroup_arrive(); 690 | for (int k = 0; k < (BLOCK_K / WGMMA::K); k++) { 691 | WGMMA::wgmma(desc_a, desc_b, accum, k || accum_same_scale); 692 | desc_a += (WGMMA::K >> 4); 693 | desc_b += (WGMMA::K >> 4); 694 | } 695 | warpgroup_commit_batch(); 696 | for (int i = 0; i < WGMMA::kNumAccum; ++i) 697 | warpgroup_fence_operand(accum[i]); 698 | 699 | // Read A & B scales (OK between warpgroup_arrive and empty_barrier_arrive with WGMMA double-buffering) 700 | if (!accum_same_scale) { 701 | scale_a_0[idx] = ld_shared(smem_scales_a_base + s*BLOCK_M + r_0); 702 | scale_a_1[idx] = ld_shared(smem_scales_a_base + s*BLOCK_M + r_1); 703 | scale_b_0[idx] = ld_shared(scales_b); 704 | if (!kMustUseUniformedScaleB) { 705 | scale_b_1[idx] = ld_shared(scales_b + SHAPE_K_SCALES); 706 | } 707 | } 708 | }; 709 | 710 | // --------------------------------------------------------------------------------- 711 | // Lambda to promote with scaling factors on CUDA cores 712 | // --------------------------------------------------------------------------------- 713 | auto promote_with_scales = [&](int idx, bool add = true) { 714 | float* dst = final_accum; 715 | float* src = idx ? accum1 : accum0; 716 | 717 | // Calculate scaling factors used when promoting into final_accum later in promote_with_scales 718 | float scale_0_0 = scale_a_0[idx] * scale_b_0[idx]; 719 | float scale_1_0 = scale_a_1[idx] * scale_b_0[idx]; 720 | float scale_0_1, scale_1_1; 721 | if (!kMustUseUniformedScaleB) { 722 | scale_0_1 = scale_a_0[idx] * scale_b_1[idx]; 723 | scale_1_1 = scale_a_1[idx] * scale_b_1[idx]; 724 | } 725 | for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { 726 | bool predicate = kMustUseUniformedScaleB or i < num_former_iters; 727 | dst[i*4+0] = (add ? dst[i*4+0] : 0) + (predicate ? scale_0_0 : scale_0_1) * src[i*4+0]; 728 | dst[i*4+1] = (add ? dst[i*4+1] : 0) + (predicate ? scale_0_0 : scale_0_1) * src[i*4+1]; 729 | dst[i*4+2] = (add ? dst[i*4+2] : 0) + (predicate ? 
scale_1_0 : scale_1_1) * src[i*4+2]; 730 | dst[i*4+3] = (add ? dst[i*4+3] : 0) + (predicate ? scale_1_0 : scale_1_1) * src[i*4+3]; 731 | } 732 | }; 733 | 734 | constexpr int idx_0 = 0; 735 | constexpr int idx_1 = DP_SCALE_256 ? 0 : 1; 736 | constexpr int idx_2 = DP_SCALE_256 ? 1 : 0; 737 | constexpr int idx_3 = DP_SCALE_256 ? 1 : 1; 738 | 739 | if constexpr (DOUBLE_PUMP) { 740 | // Double Pumped (new-ish path) 741 | assert(SHAPE_K_SCALES % 2 == 0); 742 | 743 | wgmma_prepare_scales(idx_0, smem_scales_b + 0, false); 744 | next_stage(); 745 | wgmma_prepare_scales(idx_1, smem_scales_b + 1, DP_SCALE_256); 746 | 747 | int tile_s = last_s; 748 | final_accum_to_bf16(); 749 | 750 | warpgroup_wait<1>(); 751 | next_stage(); 752 | 753 | if (old_n_block_idx != -1) { 754 | if constexpr (kNumMathThreads > 128) { 755 | asm volatile("bar.sync 2, %0;\n" :: "n"(kNumMathThreads)); 756 | if (math_wg_idx == 0) { 757 | store_tile(tile_s, 0, WGMMA::kNumAccum, true); 758 | empty_barrier_arrive(&empty_barriers_base[tile_s]); 759 | } 760 | } else { 761 | store_tile(tile_s, 0, WGMMA::kNumAccum, true); 762 | empty_barrier_arrive(&empty_barriers_base[tile_s]); 763 | } 764 | 765 | if constexpr (!DP_SCALE_256) { promote_with_scales(0, false); } 766 | wgmma_prepare_scales(idx_2, smem_scales_b + 2, false); 767 | 768 | if (kNumMathThreads > 128 && math_wg_idx == 1) { 769 | store_tile(tile_s, 0, WGMMA::kNumAccum, true); 770 | empty_barrier_arrive(&empty_barriers_base[tile_s]); 771 | } 772 | } else { 773 | empty_barrier_arrive(&empty_barriers_base[tile_s]); 774 | if constexpr (!DP_SCALE_256) { promote_with_scales(0, false); } 775 | wgmma_prepare_scales(idx_2, smem_scales_b + 2, false); 776 | } 777 | 778 | warpgroup_wait<1>(); 779 | empty_barrier_arrive(&empty_barriers_base[last_s]); 780 | next_stage(); 781 | 782 | if constexpr (!DP_SCALE_256) { promote_with_scales(1);} 783 | wgmma_prepare_scales(idx_3, smem_scales_b + 3, DP_SCALE_256); 784 | 785 | if constexpr (DP_SCALE_256) { promote_with_scales(0, false); } 786 | warpgroup_wait<1>(); 787 | empty_barrier_arrive(&empty_barriers_base[last_s]); 788 | next_stage(); 789 | if constexpr (!DP_SCALE_256) { promote_with_scales(0); } 790 | 791 | #pragma unroll 2 792 | for (int k_loop = 4; k_loop < SHAPE_K_SCALES; k_loop += 4) { 793 | wgmma_prepare_scales(idx_0, smem_scales_b + k_loop + 0, false); 794 | warpgroup_wait<1>(); 795 | empty_barrier_arrive(&empty_barriers_base[last_s]); 796 | 797 | if constexpr (!DP_SCALE_256) { promote_with_scales(1); } 798 | 799 | next_stage(); 800 | wgmma_prepare_scales(idx_1, smem_scales_b + k_loop + 1, DP_SCALE_256); 801 | 802 | if constexpr (DP_SCALE_256) { promote_with_scales(1); } 803 | warpgroup_wait<1>(); 804 | empty_barrier_arrive(&empty_barriers_base[last_s]); 805 | next_stage(); 806 | if constexpr (!DP_SCALE_256) { promote_with_scales(0);} 807 | 808 | wgmma_prepare_scales(idx_2, smem_scales_b + k_loop + 2, false); 809 | warpgroup_wait<1>(); 810 | empty_barrier_arrive(&empty_barriers_base[last_s]); 811 | 812 | if constexpr (!DP_SCALE_256) { promote_with_scales(1); } 813 | 814 | next_stage(); 815 | wgmma_prepare_scales(idx_3, smem_scales_b + k_loop + 3, DP_SCALE_256); 816 | 817 | if constexpr (DP_SCALE_256) { promote_with_scales(0); } 818 | warpgroup_wait<0>(); // sigh, damnit nvcc! 
TODO: 1 but hacked to 0 via SASS because we know better 819 | empty_barrier_arrive(&empty_barriers_base[last_s]); 820 | next_stage(); 821 | if constexpr (!DP_SCALE_256) { promote_with_scales(0); } 822 | } 823 | 824 | // TODO: tail when K is not a multiple of 4 825 | if constexpr (SHAPE_K_SCALES == 4) { 826 | warpgroup_wait<0>(); 827 | } 828 | empty_barrier_arrive(&empty_barriers_base[last_s]); 829 | promote_with_scales(1); 830 | 831 | } else { 832 | // Not Double Pumped (old-ish path) 833 | // WGMMA 0 834 | wgmma_prepare_scales(0, smem_scales_b + 0); 835 | 836 | if constexpr (SHAPE_K_SCALES > 1) { 837 | // WGMMA 1 838 | next_stage(); 839 | wgmma_prepare_scales(1, smem_scales_b + 1); 840 | 841 | // Wait for WGMMA 0 (not the one we just issued) and let the producer know it can reuse its memory 842 | warpgroup_wait<1>(); 843 | if constexpr (kNumMathThreads > 128) { 844 | asm volatile("bar.sync 2, %0;\n" :: "n"(kNumMathThreads)); 845 | } 846 | if (old_n_block_idx != -1) { 847 | store_tile(last_s); 848 | } 849 | empty_barrier_arrive(&empty_barriers_base[last_s]); 850 | } else { 851 | // Special case: single K_BLOCK so we don't need any other WGMMAs and we can just wait on WGMMA 0 852 | warpgroup_wait<0>(); 853 | if constexpr (kNumMathThreads > 128) { 854 | asm volatile("bar.sync 2, %0;\n" :: "n"(kNumMathThreads)); 855 | } 856 | if (old_n_block_idx != -1) { 857 | store_tile(s); 858 | } 859 | empty_barrier_arrive(&empty_barriers_base[s]); 860 | } 861 | 862 | // Promote without accumulation for WGMMA 0 (so we don't need to clear the registers) 863 | promote_with_scales(0, false); 864 | 865 | // KEY LOOP: This is where most of the WGMMAs usually happen (1 iteration per 2 K_BLOCK) 866 | #pragma unroll kNumUnroll 867 | for (int k_loop = 2; k_loop < SHAPE_K_SCALES-1; k_loop += 2) { 868 | next_stage(); 869 | wgmma_prepare_scales(0, smem_scales_b + k_loop); 870 | 871 | warpgroup_wait<1>(); // Wait for previous WGMMA (but not the one we just issued) and notify producer 872 | empty_barrier_arrive(empty_barriers_base + last_s); 873 | promote_with_scales(1); 874 | 875 | next_stage(); 876 | wgmma_prepare_scales(1, smem_scales_b + k_loop + 1); 877 | 878 | // Workaround to avoid NVCC/ptxas's warning "wgmma.mma_async instructions are serialized" 879 | // If we don't wait for all WGMMAs at loop boundary, compiler screws things up (on 12.8) 880 | // TODO: is there any way at all to avoid this when not fully unrolled? 881 | if (k_loop/2 % kNumUnroll == 0) warpgroup_wait<0>(); 882 | else warpgroup_wait<1>(); 883 | empty_barrier_arrive(empty_barriers_base + last_s); 884 | promote_with_scales(0); 885 | } 886 | // TODO: do we need a "dynamic tail" to avoid "instructions are serialized" with partial unroll? 887 | 888 | next_stage(); 889 | if constexpr (SHAPE_K_SCALES % 2 == 1 && SHAPE_K_SCALES > 1) { 890 | // Special case: K_BLOCK is not a multiple of 2 (e.g. 
K=384 with K_BLOCK=128) 891 | // We need to do one last WGMMA to complete the tile 892 | wgmma_prepare_scales(0, smem_scales_b + SHAPE_K_SCALES - 1); 893 | warpgroup_wait<1>(); // implicit in warpgroup_wait<0> workaround above 894 | empty_barrier_arrive(empty_barriers_base + last_s); 895 | promote_with_scales(1); 896 | 897 | next_stage(); 898 | warpgroup_wait<0>(); 899 | empty_barrier_arrive(empty_barriers_base + last_s); 900 | promote_with_scales(0); 901 | } else { 902 | // Usual case: we just need to promote the results of the last WGMMA 903 | warpgroup_wait<0>(); // implicit in warpgroup_wait<0> workaround above 904 | empty_barrier_arrive(empty_barriers_base + last_s); 905 | promote_with_scales(1); 906 | } 907 | } 908 | 909 | if (lane_idx == 0) { // Not a cluster barrier so can't use empty_barrier_arrive 910 | empty_barrier_scales_b_base[barrier_scales_b].arrive(); 911 | } 912 | 913 | // Write D for this tile while processing the next tile in the next loop iteration 914 | // Need to wait for space to store D (reusing memory from a stage of A/B) and store n/idx 915 | old_global_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); 916 | old_n_block_idx = n_block_idx; 917 | } 918 | 919 | if (scheduler.current_iter > 0) { 920 | // Store the final tile to global memory 921 | store_tile(s); 922 | empty_barrier_arrive(&empty_barriers_base[s]); 923 | } 924 | } 925 | #else 926 | if (blockIdx.x == 0 and threadIdx.x == 0) 927 | DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); 928 | #endif 929 | } 930 | 931 | template 936 | class Gemm { 937 | private: 938 | using Barrier = cuda::barrier; 939 | 940 | public: 941 | Gemm() = default; 942 | 943 | static void run(__nv_bfloat16* gmem_d, __nv_fp8_e4m3* gmem_b, 944 | float* scales_b, int* grouped_layout, 945 | uint32_t shape_m, 946 | const CUtensorMap& tma_a_desc, 947 | const CUtensorMap& tma_b_desc, 948 | const CUtensorMap& tma_scales_a_desc, 949 | const CUtensorMap& tma_d_desc, 950 | const CUtensorMap& tma_d_padded_desc, 951 | cudaStream_t stream, int num_sms, uint32_t smem_size, 952 | unsigned char* gpu_side_index) { 953 | // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps 954 | constexpr uint32_t kNumTMAThreads = 256; 955 | constexpr uint32_t kNumMathThreadsPerGroup = 128; 956 | auto kernel = fp8_gemm_kernel 1) ? true : false, kGemmType, l2_hash_bits, l2_optimization, 959 | FORCED_M>; 960 | DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); 961 | 962 | // This will leak but considered OK since it's less memory than the code itself! 
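// ('This' being the static zeroed_scratch buffer just below: 256 ints allocated lazily on first launch and never freed)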
963 | // Kernel has a sacred duty to return this memory as zero so it can easily be reused 964 | static int* zeroed_scratch = nullptr; 965 | if (zeroed_scratch == nullptr) { 966 | cudaMalloc(&zeroed_scratch, 256 * sizeof(int)); 967 | cudaMemset(zeroed_scratch, 0, 256 * sizeof(int)); 968 | } 969 | 970 | static bool init_side_index = false; 971 | static param_side_index_t param_sideaware; 972 | if constexpr (l2_optimization) { 973 | if (!init_side_index) { 974 | cudaMemcpy(param_sideaware.sm_side_and_idx, gpu_side_index, MAX_SM * sizeof(unsigned char), cudaMemcpyDeviceToHost); 975 | init_side_index = true; 976 | } 977 | } 978 | assert(reinterpret_cast(gmem_b) % 8192 == 0); 979 | 980 | // Cluster launch 981 | cudaLaunchConfig_t config; 982 | config.blockDim = get_num_threads_per_sm(BLOCK_M); 983 | config.dynamicSmemBytes = smem_size; 984 | config.stream = stream; 985 | 986 | // Clusters for TMA multicast 987 | if constexpr (kNumTMAMulticast <= 1) { 988 | cudaLaunchAttribute attr; 989 | attr.id = cudaLaunchAttributeClusterDimension; 990 | attr.val.clusterDim = {1, 1, 1}; 991 | config.attrs = &attr; 992 | config.numAttrs = 1; 993 | config.gridDim = num_sms; 994 | auto status = cudaLaunchKernelEx(&config, kernel, 995 | gmem_d, gmem_b, scales_b, grouped_layout, zeroed_scratch, shape_m, 996 | tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_d_padded_desc, 997 | 0, num_sms, param_sideaware); 998 | DG_HOST_ASSERT(status == cudaSuccess); 999 | } /*else if ((SHAPE_N % (BLOCK_N * 8)) == 0) { 1000 | // TODO: Add back support for Hybrid Cluster Size with L2 side optimization 1001 | // [...] see older commits [...] 1002 | }*/ else { 1003 | cudaLaunchAttribute attr; 1004 | attr.id = cudaLaunchAttributeClusterDimension; 1005 | attr.val.clusterDim = {1, 1, 1}; 1006 | config.attrs = &attr; 1007 | config.numAttrs = 1; 1008 | 1009 | config.gridDim = num_sms; 1010 | attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; 1011 | config.stream = stream; 1012 | auto status = cudaLaunchKernelEx(&config, kernel, 1013 | gmem_d, gmem_b, scales_b, grouped_layout, zeroed_scratch, shape_m, 1014 | tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_d_padded_desc, 1015 | 0, num_sms, param_sideaware); 1016 | DG_HOST_ASSERT(status == cudaSuccess); 1017 | } 1018 | } 1019 | 1020 | template 1021 | static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { 1022 | return make_2d_tma_desc(global_address, Layout::RowMajor, 1023 | shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); 1024 | } 1025 | 1026 | template 1027 | static CUtensorMap make_2d_tma_b_desc(T* global_address) { 1028 | return make_2d_tma_desc(global_address, Layout::ColMajor, 1029 | SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); 1030 | } 1031 | 1032 | template 1033 | static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { 1034 | // TODO - I'm going to be slightly too honest and admit it's 2AM and I can't remember why I replaced BLOCK_M with 64 1035 | // but otherwise it crashes, so... 1036 | return make_2d_tma_desc(global_address, Layout::RowMajor, 1037 | shape_m * (kGemmType == GemmType::GroupedMasked ? 
kNumGroups : 1), SHAPE_N, min(/*BLOCK_M*/ 64, shape_m), BLOCK_N, 1038 | CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); 1039 | } 1040 | 1041 | template 1042 | static CUtensorMap make_3d_tma_d_desc(T* global_address, uint32_t shape_m) { 1043 | return make_3d_tma_padded_desc(global_address, 1044 | shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, min(/*BLOCK_M*/ 64, shape_m), BLOCK_N, 1045 | CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); 1046 | } 1047 | 1048 | template 1049 | static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { 1050 | // Make TMA aligned to 16 bytes 1051 | constexpr uint32_t kAlignment = 16 / sizeof(T); 1052 | shape_m = ceil_div(shape_m, kAlignment) * kAlignment; 1053 | 1054 | return make_2d_tma_desc(global_address, Layout::ColMajor, 1055 | shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, 1056 | CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); 1057 | } 1058 | 1059 | template 1060 | static CUtensorMap make_2d_tma_desc( 1061 | T* global_address, Layout layout, 1062 | uint32_t gmem_rows, uint32_t gmem_cols, 1063 | uint32_t smem_rows, uint32_t smem_cols, 1064 | CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { 1065 | if (layout == Layout::RowMajor) { 1066 | uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; 1067 | uint32_t smem_dim[2] = {smem_cols, smem_rows}; 1068 | return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); 1069 | } else { 1070 | uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; 1071 | uint32_t smem_dim[2] = {smem_rows, smem_cols}; 1072 | return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); 1073 | } 1074 | } 1075 | 1076 | template 1077 | static CUtensorMap make_3d_tma_padded_desc( 1078 | T* global_address, 1079 | uint32_t gmem_rows, uint32_t gmem_cols, 1080 | uint32_t smem_rows, uint32_t smem_cols, 1081 | CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { 1082 | uint32_t padding = (smem_cols == 64 || smem_cols == 96 || smem_cols == 128) ? 
PADDING_N : 0; 1083 | uint64_t gmem_dim[3] = {smem_cols, gmem_cols/smem_cols, gmem_rows}; 1084 | uint32_t smem_dim[3] = {smem_cols+padding, 1, smem_rows}; 1085 | uint64_t stride_in_bytes[2] = {smem_cols * sizeof(T), gmem_cols * sizeof(T)}; 1086 | return make_3d_tma_copy_desc(global_address, gmem_dim, stride_in_bytes, smem_dim, swizzle_type); 1087 | } 1088 | }; 1089 | 1090 | }; // namespace deep_gemm 1091 | 1092 | #pragma clang diagnostic pop 1093 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/mma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "utils.cuh" 6 | 7 | namespace deep_gemm { 8 | 9 | struct SM90_64x16x32_F32E4M3E4M3_SS { 10 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 11 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 12 | bool scale_d) { 13 | asm volatile("{\n" 14 | ".reg .pred p;\n" 15 | "setp.ne.b32 p, %10, 0;\n" 16 | "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" 17 | "{%0, %1, %2, %3, %4, %5, %6, %7}," 18 | " %8," 19 | " %9," 20 | " p , 1, 1;\n" 21 | "}\n" 22 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) 23 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 24 | } 25 | 26 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 27 | wgmma(desc_a, desc_b, 28 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 29 | scale_d); 30 | } 31 | 32 | static constexpr int M = 64; 33 | static constexpr int N = 16; 34 | static constexpr int K = 32; 35 | static constexpr int kNumAccum = M * N / 128; 36 | }; 37 | 38 | struct SM90_64x24x32_F32E4M3E4M3_SS { 39 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 40 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 41 | float& d08, float& d09, float& d10, float& d11, 42 | bool scale_d) { 43 | asm volatile("{\n" 44 | ".reg .pred p;\n" 45 | "setp.ne.b32 p, %14, 0;\n" 46 | "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3" 47 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 48 | " %8, %9, %10, %11}," 49 | " %12," 50 | " %13," 51 | " p , 1, 1;\n" 52 | "}\n" 53 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 54 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) 55 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 56 | } 57 | 58 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 59 | wgmma(desc_a, desc_b, 60 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 61 | d[8], d[9], d[10], d[11], 62 | scale_d); 63 | } 64 | 65 | static constexpr int M = 64; 66 | static constexpr int N = 24; 67 | static constexpr int K = 32; 68 | static constexpr int kNumAccum = M * N / 128; 69 | }; 70 | 71 | struct SM90_64x32x32_F32E4M3E4M3_SS { 72 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 73 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 74 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 75 | bool scale_d) { 76 | asm volatile("{\n" 77 | ".reg .pred p;\n" 78 | "setp.ne.b32 p, %18, 0;\n" 79 | "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" 80 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 81 | " %8, %9, %10, %11, %12, %13, %14, %15}," 82 | " 
%16," 83 | " %17," 84 | " p , 1, 1;\n" 85 | "}\n" 86 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 87 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) 88 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 89 | } 90 | 91 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 92 | wgmma(desc_a, desc_b, 93 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 94 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 95 | scale_d); 96 | } 97 | 98 | static constexpr int M = 64; 99 | static constexpr int N = 32; 100 | static constexpr int K = 32; 101 | static constexpr int kNumAccum = M * N / 128; 102 | }; 103 | 104 | struct SM90_64x40x32_F32E4M3E4M3_SS { 105 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 106 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 107 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 108 | float& d16, float& d17, float& d18, float& d19, 109 | bool scale_d) { 110 | asm volatile("{\n" 111 | ".reg .pred p;\n" 112 | "setp.ne.b32 p, %22, 0;\n" 113 | "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3" 114 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 115 | " %8, %9, %10, %11, %12, %13, %14, %15, " 116 | " %16, %17, %18, %19}," 117 | " %20," 118 | " %21," 119 | " p , 1, 1;\n" 120 | "}\n" 121 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 122 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 123 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) 124 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 125 | } 126 | 127 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 128 | wgmma(desc_a, desc_b, 129 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 130 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 131 | d[16], d[17], d[18], d[19], 132 | scale_d); 133 | } 134 | 135 | static constexpr int M = 64; 136 | static constexpr int N = 40; 137 | static constexpr int K = 32; 138 | static constexpr int kNumAccum = M * N / 128; 139 | }; 140 | 141 | struct SM90_64x48x32_F32E4M3E4M3_SS { 142 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 143 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 144 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 145 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 146 | bool scale_d) { 147 | asm volatile("{\n" 148 | ".reg .pred p;\n" 149 | "setp.ne.b32 p, %26, 0;\n" 150 | "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" 151 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 152 | " %8, %9, %10, %11, %12, %13, %14, %15, " 153 | " %16, %17, %18, %19, %20, %21, %22, %23}," 154 | " %24," 155 | " %25," 156 | " p , 1, 1;\n" 157 | "}\n" 158 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 159 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 160 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) 161 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 162 | } 163 | 164 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 165 | 
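// Illustrative aside (not in the original source): every wrapper struct in this
// file spreads an M x N fp32 tile over the 128 threads of a warpgroup, so each
// thread owns kNumAccum = M * N / 128 accumulator registers -- 24 for this
// 64x48 tile, which is why this overload forwards exactly d[0]..d[23].
// A typical issue pattern (sketch) is: warpgroup_arrive(); one wgmma(...) per
// k-step; warpgroup_commit_batch(); warpgroup_wait<0>(); -- see the helper
// functions defined further down in this file.
static_assert(kNumAccum == (64 * 48) / 128, "64x48 WGMMA tile -> 24 fp32 accumulators per thread");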
wgmma(desc_a, desc_b, 166 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 167 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 168 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 169 | scale_d); 170 | } 171 | 172 | static constexpr int M = 64; 173 | static constexpr int N = 48; 174 | static constexpr int K = 32; 175 | static constexpr int kNumAccum = M * N / 128; 176 | }; 177 | 178 | struct SM90_64x56x32_F32E4M3E4M3_SS { 179 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 180 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 181 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 182 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 183 | float& d24, float& d25, float& d26, float& d27, 184 | bool scale_d) { 185 | asm volatile("{\n" 186 | ".reg .pred p;\n" 187 | "setp.ne.b32 p, %30, 0;\n" 188 | "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" 189 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 190 | " %8, %9, %10, %11, %12, %13, %14, %15, " 191 | " %16, %17, %18, %19, %20, %21, %22, %23, " 192 | " %24, %25, %26, %27}, " 193 | " %28," 194 | " %29," 195 | " p , 1, 1;\n" 196 | "}\n" 197 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 198 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 199 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 200 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) 201 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 202 | } 203 | 204 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 205 | wgmma(desc_a, desc_b, 206 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 207 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 208 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 209 | d[24], d[25], d[26], d[27], 210 | scale_d); 211 | } 212 | 213 | static constexpr int M = 64; 214 | static constexpr int N = 56; 215 | static constexpr int K = 32; 216 | static constexpr int kNumAccum = M * N / 128; 217 | }; 218 | 219 | struct SM90_64x64x32_F32E4M3E4M3_SS { 220 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 221 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 222 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 223 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 224 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 225 | bool scale_d) { 226 | asm volatile("{\n" 227 | ".reg .pred p;\n" 228 | "setp.ne.b32 p, %34, 0;\n" 229 | "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" 230 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 231 | " %8, %9, %10, %11, %12, %13, %14, %15, " 232 | " %16, %17, %18, %19, %20, %21, %22, %23, " 233 | " %24, %25, %26, %27, %28, %29, %30, %31}, " 234 | " %32," 235 | " %33," 236 | " p , 1, 1;\n" 237 | "}\n" 238 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 239 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 240 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 241 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), 
"+f"(d30), "+f"(d31) 242 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 243 | } 244 | 245 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 246 | wgmma(desc_a, desc_b, 247 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 248 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 249 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 250 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 251 | scale_d); 252 | } 253 | 254 | static constexpr int M = 64; 255 | static constexpr int N = 64; 256 | static constexpr int K = 32; 257 | static constexpr int kNumAccum = M * N / 128; 258 | }; 259 | 260 | struct SM90_64x72x32_F32E4M3E4M3_SS { 261 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 262 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 263 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 264 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 265 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 266 | float& d32, float& d33, float& d34, float& d35, 267 | bool scale_d) { 268 | asm volatile("{\n" 269 | ".reg .pred p;\n" 270 | "setp.ne.b32 p, %38, 0;\n" 271 | "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3" 272 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 273 | " %8, %9, %10, %11, %12, %13, %14, %15, " 274 | " %16, %17, %18, %19, %20, %21, %22, %23, " 275 | " %24, %25, %26, %27, %28, %29, %30, %31, " 276 | " %32, %33, %34, %35}, " 277 | " %36," 278 | " %37," 279 | " p , 1, 1;\n" 280 | "}\n" 281 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 282 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 283 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 284 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 285 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) 286 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 287 | } 288 | 289 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 290 | wgmma(desc_a, desc_b, 291 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 292 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 293 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 294 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 295 | d[32], d[33], d[34], d[35], 296 | scale_d); 297 | } 298 | 299 | static constexpr int M = 64; 300 | static constexpr int N = 72; 301 | static constexpr int K = 32; 302 | static constexpr int kNumAccum = M * N / 128; 303 | }; 304 | 305 | struct SM90_64x80x32_F32E4M3E4M3_SS { 306 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 307 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 308 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 309 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 310 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 311 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 312 | bool scale_d) { 313 | asm volatile("{\n" 314 | ".reg .pred p;\n" 315 | "setp.ne.b32 p, %42, 
0;\n" 316 | "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3" 317 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 318 | " %8, %9, %10, %11, %12, %13, %14, %15, " 319 | " %16, %17, %18, %19, %20, %21, %22, %23, " 320 | " %24, %25, %26, %27, %28, %29, %30, %31, " 321 | " %32, %33, %34, %35, %36, %37, %38, %39}, " 322 | " %40," 323 | " %41," 324 | " p , 1, 1;\n" 325 | "}\n" 326 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 327 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 328 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 329 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 330 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) 331 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 332 | } 333 | 334 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 335 | wgmma(desc_a, desc_b, 336 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 337 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 338 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 339 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 340 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 341 | scale_d); 342 | } 343 | 344 | static constexpr int M = 64; 345 | static constexpr int N = 80; 346 | static constexpr int K = 32; 347 | static constexpr int kNumAccum = M * N / 128; 348 | }; 349 | 350 | struct SM90_64x88x32_F32E4M3E4M3_SS { 351 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 352 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 353 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 354 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 355 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 356 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 357 | float& d40, float& d41, float& d42, float& d43, 358 | bool scale_d) { 359 | asm volatile("{\n" 360 | ".reg .pred p;\n" 361 | "setp.ne.b32 p, %46, 0;\n" 362 | "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3" 363 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 364 | " %8, %9, %10, %11, %12, %13, %14, %15, " 365 | " %16, %17, %18, %19, %20, %21, %22, %23, " 366 | " %24, %25, %26, %27, %28, %29, %30, %31, " 367 | " %32, %33, %34, %35, %36, %37, %38, %39, " 368 | " %40, %41, %42, %43}, " 369 | " %44," 370 | " %45," 371 | " p , 1, 1;\n" 372 | "}\n" 373 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 374 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 375 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 376 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 377 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 378 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) 379 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 380 | } 381 | 382 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 383 | wgmma(desc_a, desc_b, 384 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 385 | d[8], d[9], d[10], d[11], 
d[12], d[13], d[14], d[15], 386 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 387 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 388 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 389 | d[40], d[41], d[42], d[43], 390 | scale_d); 391 | } 392 | 393 | static constexpr int M = 64; 394 | static constexpr int N = 88; 395 | static constexpr int K = 32; 396 | static constexpr int kNumAccum = M * N / 128; 397 | }; 398 | 399 | struct SM90_64x96x32_F32E4M3E4M3_SS { 400 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 401 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 402 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 403 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 404 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 405 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 406 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 407 | bool scale_d) { 408 | asm volatile("{\n" 409 | ".reg .pred p;\n" 410 | "setp.ne.b32 p, %50, 0;\n" 411 | "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" 412 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 413 | " %8, %9, %10, %11, %12, %13, %14, %15, " 414 | " %16, %17, %18, %19, %20, %21, %22, %23, " 415 | " %24, %25, %26, %27, %28, %29, %30, %31, " 416 | " %32, %33, %34, %35, %36, %37, %38, %39, " 417 | " %40, %41, %42, %43, %44, %45, %46, %47}, " 418 | " %48," 419 | " %49," 420 | " p , 1, 1;\n" 421 | "}\n" 422 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 423 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 424 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 425 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 426 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 427 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) 428 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 429 | } 430 | 431 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 432 | wgmma(desc_a, desc_b, 433 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 434 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 435 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 436 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 437 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 438 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 439 | scale_d); 440 | } 441 | 442 | static constexpr int M = 64; 443 | static constexpr int N = 96; 444 | static constexpr int K = 32; 445 | static constexpr int kNumAccum = M * N / 128; 446 | }; 447 | 448 | struct SM90_64x104x32_F32E4M3E4M3_SS { 449 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 450 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 451 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 452 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 453 | float& d24, float& d25, float& d26, float& d27, float& d28, float& 
d29, float& d30, float& d31, 454 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 455 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 456 | float& d48, float& d49, float& d50, float& d51, 457 | bool scale_d) { 458 | asm volatile("{\n" 459 | ".reg .pred p;\n" 460 | "setp.ne.b32 p, %54, 0;\n" 461 | "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3" 462 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 463 | " %8, %9, %10, %11, %12, %13, %14, %15, " 464 | " %16, %17, %18, %19, %20, %21, %22, %23, " 465 | " %24, %25, %26, %27, %28, %29, %30, %31, " 466 | " %32, %33, %34, %35, %36, %37, %38, %39, " 467 | " %40, %41, %42, %43, %44, %45, %46, %47, " 468 | " %48, %49, %50, %51}, " 469 | " %52," 470 | " %53," 471 | " p , 1, 1;\n" 472 | "}\n" 473 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 474 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 475 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 476 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 477 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 478 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 479 | "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) 480 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 481 | } 482 | 483 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 484 | wgmma(desc_a, desc_b, 485 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 486 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 487 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 488 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 489 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 490 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 491 | d[48], d[49], d[50], d[51], 492 | scale_d); 493 | } 494 | 495 | static constexpr int M = 64; 496 | static constexpr int N = 104; 497 | static constexpr int K = 32; 498 | static constexpr int kNumAccum = M * N / 128; 499 | }; 500 | 501 | struct SM90_64x112x32_F32E4M3E4M3_SS { 502 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 503 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 504 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 505 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 506 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 507 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 508 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 509 | float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, 510 | bool scale_d) { 511 | asm volatile("{\n" 512 | ".reg .pred p;\n" 513 | "setp.ne.b32 p, %58, 0;\n" 514 | "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" 515 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 516 | " %8, %9, %10, %11, %12, %13, %14, %15, " 517 | " %16, %17, %18, %19, %20, %21, %22, %23, " 518 | " %24, %25, %26, %27, %28, %29, %30, %31, " 519 | " %32, %33, %34, %35, %36, %37, %38, %39, " 520 | " %40, %41, %42, %43, %44, %45, %46, %47, 
" 521 | " %48, %49, %50, %51, %52, %53, %54, %55}, " 522 | " %56," 523 | " %57," 524 | " p , 1, 1;\n" 525 | "}\n" 526 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 527 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 528 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 529 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 530 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 531 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 532 | "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) 533 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 534 | } 535 | 536 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 537 | wgmma(desc_a, desc_b, 538 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 539 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 540 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 541 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 542 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 543 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 544 | d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], 545 | scale_d); 546 | } 547 | 548 | static constexpr int M = 64; 549 | static constexpr int N = 112; 550 | static constexpr int K = 32; 551 | static constexpr int kNumAccum = M * N / 128; 552 | }; 553 | 554 | struct SM90_64x120x32_F32E4M3E4M3_SS { 555 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 556 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 557 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 558 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 559 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 560 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 561 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 562 | float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, 563 | float& d56, float& d57, float& d58, float& d59, 564 | bool scale_d) { 565 | asm volatile("{\n" 566 | ".reg .pred p;\n" 567 | "setp.ne.b32 p, %62, 0;\n" 568 | "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3" 569 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 570 | " %8, %9, %10, %11, %12, %13, %14, %15, " 571 | " %16, %17, %18, %19, %20, %21, %22, %23, " 572 | " %24, %25, %26, %27, %28, %29, %30, %31, " 573 | " %32, %33, %34, %35, %36, %37, %38, %39, " 574 | " %40, %41, %42, %43, %44, %45, %46, %47, " 575 | " %48, %49, %50, %51, %52, %53, %54, %55, " 576 | " %56, %57, %58, %59}, " 577 | " %60," 578 | " %61," 579 | " p , 1, 1;\n" 580 | "}\n" 581 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 582 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 583 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 584 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 585 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), 586 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 587 | "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), 588 | "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) 589 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 590 | } 591 | 592 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 593 | wgmma(desc_a, desc_b, 594 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 595 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 596 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 597 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 598 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 599 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 600 | d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], 601 | d[56], d[57], d[58], d[59], 602 | scale_d); 603 | } 604 | 605 | static constexpr int M = 64; 606 | static constexpr int N = 120; 607 | static constexpr int K = 32; 608 | static constexpr int kNumAccum = M * N / 128; 609 | }; 610 | 611 | struct SM90_64x128x32_F32E4M3E4M3_SS { 612 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 613 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 614 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 615 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 616 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 617 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 618 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 619 | float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, 620 | float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, 621 | bool scale_d) { 622 | asm volatile("{\n" 623 | ".reg .pred p;\n" 624 | "setp.ne.b32 p, %66, 0;\n" 625 | "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" 626 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 627 | " %8, %9, %10, %11, %12, %13, %14, %15, " 628 | " %16, %17, %18, %19, %20, %21, %22, %23, " 629 | " %24, %25, %26, %27, %28, %29, %30, %31, " 630 | " %32, %33, %34, %35, %36, %37, %38, %39, " 631 | " %40, %41, %42, %43, %44, %45, %46, %47, " 632 | " %48, %49, %50, %51, %52, %53, %54, %55, " 633 | " %56, %57, %58, %59, %60, %61, %62, %63}, " 634 | " %64," 635 | " %65," 636 | " p , 1, 1;\n" 637 | "}\n" 638 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 639 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 640 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 641 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 642 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 643 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 644 | "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), 645 | "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) 646 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 647 | } 648 | 649 | __device__ static void 
wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 650 | wgmma(desc_a, desc_b, 651 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 652 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 653 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 654 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 655 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 656 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 657 | d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], 658 | d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], 659 | scale_d); 660 | } 661 | 662 | static constexpr int M = 64; 663 | static constexpr int N = 128; 664 | static constexpr int K = 32; 665 | static constexpr int kNumAccum = M * N / 128; 666 | }; 667 | 668 | struct SM90_64x192x32_F32E4M3E4M3_SS { 669 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, 670 | float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, 671 | float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, 672 | float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, 673 | float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, 674 | float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, 675 | float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, 676 | float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, 677 | float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, 678 | float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, 679 | float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, 680 | float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, 681 | float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, 682 | bool scale_d) { 683 | asm volatile("{\n" 684 | ".reg .pred p;\n" 685 | "setp.ne.b32 p, %98, 0;\n" 686 | "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" 687 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 688 | " %8, %9, %10, %11, %12, %13, %14, %15, " 689 | " %16, %17, %18, %19, %20, %21, %22, %23, " 690 | " %24, %25, %26, %27, %28, %29, %30, %31, " 691 | " %32, %33, %34, %35, %36, %37, %38, %39, " 692 | " %40, %41, %42, %43, %44, %45, %46, %47, " 693 | " %48, %49, %50, %51, %52, %53, %54, %55, " 694 | " %56, %57, %58, %59, %60, %61, %62, %63, " 695 | " %64, %65, %66, %67, %68, %69, %70, %71, " 696 | " %72, %73, %74, %75, %76, %77, %78, %79, " 697 | " %80, %81, %82, %83, %84, %85, %86, %87, " 698 | " %88, %89, %90, %91, %92, %93, %94, %95}, " 699 | " %96," 700 | " %97," 701 | " p , 1, 1;\n" 702 | "}\n" 703 | : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 704 | "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 705 | "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), 706 | "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 707 | "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), 708 | "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 709 | "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), 710 | "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), 711 | "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), 712 | "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), 713 | "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), 714 | "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) 715 | : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); 716 | } 717 | 718 | __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { 719 | wgmma(desc_a, desc_b, 720 | d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], 721 | d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], 722 | d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], 723 | d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], 724 | d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], 725 | d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 726 | d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], 727 | d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], 728 | d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], 729 | d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], 730 | d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], 731 | d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], 732 | scale_d); 733 | } 734 | 735 | static constexpr int M = 64; 736 | static constexpr int N = 192; 737 | static constexpr int K = 32; 738 | static constexpr int kNumAccum = M * N / 128; 739 | }; 740 | 741 | template 742 | struct SM90_U32x2_STSM_N { 743 | __device__ __forceinline__ static void 744 | copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { 745 | const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; 746 | asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" 747 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); 748 | } 749 | }; 750 | 751 | template 752 | struct SM90_U32x4_STSM_N { 753 | __device__ __forceinline__ static void 754 | copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { 755 | const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), 756 | *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; 757 | asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" 758 | :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); 759 | } 760 | }; 761 | 762 | __device__ void warpgroup_arrive() { 763 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 764 | } 765 | 766 | __device__ void warpgroup_commit_batch() { 767 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 768 | } 769 | 770 | __device__ void warpgroup_fence_operand(float& reg) { 771 | asm volatile("" : "+f"(reg) :: "memory"); 772 | } 773 | 774 | __forceinline__ __device__ uint32_t get_lane_id() { 775 | uint32_t lane_id; 776 | asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); 777 | return lane_id; 778 | } 779 | 780 | __device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { 781 | uint32_t ret; 782 | asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 783 | return ret; 784 | } 785 | 786 | __device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { 787 | int4 ret; 788 | asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), 
"=r"(ret.w) : "l"(ptr)); 789 | return ret; 790 | } 791 | 792 | __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { 793 | float ret; 794 | asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); 795 | return ret; 796 | } 797 | 798 | __device__ __forceinline__ void st_shared(const float* ptr, float val) { 799 | asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); 800 | } 801 | 802 | __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { 803 | asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); 804 | } 805 | 806 | template 807 | __device__ void warpgroup_wait() { 808 | DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); 809 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); 810 | } 811 | 812 | union GmmaDescriptor { 813 | __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} 814 | 815 | __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} 816 | 817 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} 818 | 819 | __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} 820 | 821 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { 822 | desc_ = t.desc_; 823 | return *this; 824 | } 825 | 826 | __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { 827 | desc_ = t.desc_; 828 | return *this; 829 | } 830 | 831 | uint64_t desc_; 832 | uint32_t reg32_[2]; 833 | uint16_t reg16_[4]; 834 | 835 | struct { 836 | uint16_t start_address_: 14, : 2; 837 | uint16_t leading_byte_offset_: 14, : 2; 838 | uint16_t stride_byte_offset_: 14, : 2; 839 | uint8_t : 1, base_offset_: 3, : 4; 840 | uint8_t : 6, layout_type_: 2; 841 | } bitfield; 842 | 843 | // Decay to an `uint64_t` 844 | __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } 845 | }; 846 | 847 | template 848 | __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, 849 | int leading_byte_offset = 0, 850 | int stride_byte_offset = 1024) { 851 | GmmaDescriptor desc; 852 | auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 853 | desc.bitfield.start_address_ = uint_ptr >> 4; 854 | desc.bitfield.layout_type_ = layout_type; 855 | desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; 856 | desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; 857 | desc.bitfield.base_offset_ = 0; 858 | return desc; 859 | } 860 | 861 | template 862 | struct FP8MMASelector { 863 | static constexpr auto select_type() { 864 | if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS(); 865 | if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS(); 866 | if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS(); 867 | if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS(); 868 | if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS(); 869 | if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS(); 870 | if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS(); 871 | if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS(); 872 | if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS(); 873 | if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS(); 874 | if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS(); 875 | if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS(); 876 | if constexpr (N == 112) return 
SM90_64x112x32_F32E4M3E4M3_SS(); 877 | if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); 878 | if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); 879 | if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); 880 | } 881 | 882 | using type = decltype(select_type()); 883 | }; 884 | 885 | } // namespace deep_gemm 886 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/reorder_b.cuh: -------------------------------------------------------------------------------- 1 | // TODO - rewrite all of this!!! 2 | // need to change how L2 SM/page side info is tested and passed around 3 | // (could be sideband or hidden inside parity bits, but careful page side implicitly depends on SM side!) 4 | // This should be fused with FP8 conversion/transpose kernels if possible 5 | // and done persistently so it's written from a SM on the correct side 6 | 7 | 8 | #pragma clang diagnostic push 9 | #pragma clang diagnostic ignored "-Wunknown-attributes" 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "tma_utils.cuh" 20 | #include "utils.cuh" 21 | 22 | namespace deep_gemm { 23 | 24 | enum class Layout { 25 | RowMajor, 26 | ColMajor 27 | }; 28 | enum class GemmType { 29 | Normal, 30 | GroupedContiguous, 31 | GroupedMasked 32 | }; 33 | 34 | // TODO: L2 side aware B with 128x32 blocks (contiguous 4KiB) 35 | template 36 | __global__ void optimize_B(__nv_fp8_e4m3* gmem_b_out, __nv_fp8_e4m3* gmem_b_in, int* grouped_layout, 37 | const __grid_constant__ CUtensorMap tensor_map_b, 38 | int shape_n, int shape_k, int num_groups) { 39 | // Currently don't support L2 side optimization for grouped masked GEMM (possible in theory?) 40 | constexpr bool L2_SIDE_OPTIMIZATION = l2_optimization && (kGemmType != GemmType::GroupedMasked); 41 | uint32_t shape_n_half = shape_n / 2; // this works because N%64==0 and fp8_gemm.cuh loads in chunks of 32xN 42 | 43 | using Barrier = cutlass::arch::ClusterTransactionBarrier; 44 | int laneid; 45 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 46 | asm volatile("mov.u32 %0, %laneid;" : "=r"(laneid)); 47 | 48 | extern __shared__ __align__(1024) uint8_t smem_buffer[]; 49 | Barrier* barrier = reinterpret_cast(smem_buffer + 1024); 50 | 51 | if constexpr (L2_SIDE_OPTIMIZATION) { 52 | int blocks_per_group = (shape_n/8U) * (shape_k/BLOCK_K); 53 | int group_idx = blockIdx.x / blocks_per_group; 54 | int idx_in_group = blockIdx.x % blocks_per_group; 55 | int n = (idx_in_group * 8) % shape_n; 56 | int k = ((idx_in_group * 8) / shape_n) * 128; 57 | 58 | if (threadIdx.x == 0 && laneid == 0) { 59 | barrier->init(1); 60 | cutlass::arch::fence_view_async_shared(); 61 | tma_copy(&tensor_map_b, reinterpret_cast(barrier), smem_buffer, k, n + (group_idx * shape_n)); 62 | barrier->arrive_and_expect_tx(1024); 63 | } 64 | __syncthreads(); 65 | barrier->wait(0); 66 | 67 | int4 data_int4 = reinterpret_cast(smem_buffer)[threadIdx.x]; 68 | __nv_fp8_e4m3 *data_fp8 = reinterpret_cast<__nv_fp8_e4m3*>(&data_int4); 69 | 70 | int n_side = (n >= shape_n_half) ? 1 : 0; 71 | int n_half = n_side ? 
(n - shape_n_half) : n; 72 | int n_dst_base = (n_half & 31) + (n_half & ~31) * 2; 73 | int offset = (n_dst_base * 128) + (k * shape_n); 74 | 75 | offset += group_idx * shape_n * shape_k; 76 | 77 | int local_side = __popc(reinterpret_cast(&gmem_b_out[offset]) & l2_hash_bits) & 1; 78 | int upper_4kib = n_side ^ local_side ^ 1; // extra ^1 with new sideaware.cu because bit 21 is in the hash 79 | 80 | int4* address = reinterpret_cast(gmem_b_out + (offset + upper_4kib * 4096)); 81 | address[threadIdx.x] = data_int4; 82 | } else { 83 | // simple memcpy for non-optimized case 84 | int offset = tid * 16; 85 | uint4* out4 = reinterpret_cast(gmem_b_out + offset); 86 | *out4 = *(reinterpret_cast(gmem_b_in + offset)); 87 | } 88 | } 89 | 90 | template 91 | class ReorderB { 92 | private: 93 | using Barrier = cuda::barrier; 94 | 95 | public: 96 | ReorderB() = default; 97 | 98 | static void run(__nv_fp8_e4m3* gmem_b_out, __nv_fp8_e4m3* gmem_b_in, int* grouped_layout, const CUtensorMap& tma_b_desc_in, 99 | uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_sms, int num_groups) { 100 | // Is the address 8KiB aligned? 101 | uint64_t address = reinterpret_cast(gmem_b_out); 102 | assert(address % 8192 == 0); 103 | 104 | // Calculate number of tiles and smem size 105 | uint32_t kNumTiles = ceil_div(shape_n, 8U) * ceil_div(shape_k, BLOCK_K) * num_groups; 106 | constexpr uint32_t smem_size = BLOCK_K * 8 * sizeof(__nv_fp8_e4m3) + 1024; 107 | 108 | auto kernel = optimize_B; 109 | DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); 110 | 111 | // Launch 112 | cudaLaunchConfig_t config; 113 | config.blockDim = 64; 114 | config.dynamicSmemBytes = smem_size; 115 | config.stream = stream; 116 | config.gridDim = kNumTiles; 117 | 118 | cudaLaunchAttribute attr; 119 | attr.id = cudaLaunchAttributeClusterDimension; 120 | attr.val.clusterDim = {1, 1, 1}; 121 | config.attrs = &attr; 122 | config.numAttrs = 1; 123 | 124 | auto status = cudaLaunchKernelEx(&config, kernel, gmem_b_out, gmem_b_in, grouped_layout, tma_b_desc_in, shape_n, shape_k, num_groups); 125 | DG_HOST_ASSERT(status == cudaSuccess); 126 | 127 | cudaDeviceSynchronize(); 128 | } 129 | 130 | template 131 | static CUtensorMap make_2d_tma_b_desc(T* global_address, int shape_n, int shape_k, int num_groups=1) { 132 | return make_2d_tma_desc(global_address, Layout::ColMajor, 133 | shape_k, shape_n * (kGemmType != GemmType::Normal ? 
num_groups : 1), BLOCK_K, 8); 134 | } 135 | 136 | template 137 | static CUtensorMap make_2d_tma_desc( 138 | T* global_address, Layout layout, 139 | uint32_t gmem_rows, uint32_t gmem_cols, 140 | uint32_t smem_rows, uint32_t smem_cols, 141 | CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { 142 | if (layout == Layout::RowMajor) { 143 | uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; 144 | uint32_t smem_dim[2] = {smem_cols, smem_rows}; 145 | return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); 146 | } else { 147 | uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; 148 | uint32_t smem_dim[2] = {smem_rows, smem_cols}; 149 | return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); 150 | } 151 | } 152 | }; 153 | 154 | }; // namespace deep_gemm 155 | 156 | #pragma clang diagnostic pop 157 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/scheduler.cuh: -------------------------------------------------------------------------------- 1 | #include "utils.cuh" 2 | 3 | namespace deep_gemm { 4 | 5 | enum class GemmType { 6 | Normal, 7 | GroupedContiguous, 8 | GroupedMasked 9 | }; 10 | 11 | #pragma clang diagnostic push 12 | #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" 13 | template 19 | struct Scheduler { 20 | int current_iter = -1; 21 | uint32_t num_aligned_m_blocks; 22 | 23 | // For normal GEMM 24 | // Maybe not used in the masked grouped GEMM 25 | uint32_t num_blocks; 26 | 27 | // For grouped GEMM 28 | int* grouped_layout; 29 | // Only used for masked layout 30 | uint32_t curr_group_idx, curr_cumsum; 31 | // with hybrid cluster sizes, we can't use blockIdx.x/gridDim.x directly 32 | // e.g. with 15 clusters of 8 + 4 clusters of 2, latter will have: block_idx = 120+blockIdx.x 33 | int block_idx, grid_size; 34 | int n_block_offset; 35 | __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m, 36 | int* grouped_layout = nullptr, 37 | int block_idx = -1, int grid_size = -1, 38 | int n_block_offset = 0) { 39 | num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); 40 | if constexpr (kGemmType == GemmType::Normal) { 41 | num_blocks = num_aligned_m_blocks * kNumNBlocks; 42 | } else if (kGemmType == GemmType::GroupedContiguous) { 43 | num_blocks = num_aligned_m_blocks * kNumNBlocks; 44 | this->grouped_layout = grouped_layout; 45 | } else if (kGemmType == GemmType::GroupedMasked) { 46 | curr_group_idx = curr_cumsum = 0; 47 | this->grouped_layout = grouped_layout; 48 | } 49 | this->block_idx = block_idx >= 0 ? block_idx : blockIdx.x; 50 | this->grid_size = grid_size >= 0 ? 
grid_size : gridDim.x; 51 | this->n_block_offset = n_block_offset; 52 | } 53 | 54 | __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { 55 | // TODO: check for this statically host side (since cluster size is now dynamic) 56 | //DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); 57 | 58 | // Swizzle for better L2 usages 59 | auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; // HACK: TODO: temporary optimization for m=64 60 | auto group_idx = block_idx / num_blocks_per_group; 61 | auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; 62 | auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); 63 | auto in_group_idx = block_idx % num_blocks_per_group; 64 | m_block_idx = in_group_idx / num_n_blocks_in_group; 65 | n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; 66 | } 67 | 68 | template 69 | __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, 70 | const uint32_t& block_idx, const uint32_t& m_block_idx=0) { 71 | if constexpr (kGemmType == GemmType::Normal) { 72 | return block_idx * block_size; 73 | } else if (kGemmType == GemmType::GroupedContiguous) { 74 | auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); 75 | return offset * shape_dim + block_idx * block_size; 76 | } else if (kGemmType == GemmType::GroupedMasked) { 77 | return curr_group_idx * shape_dim + block_idx * block_size; 78 | } 79 | } 80 | 81 | __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { 82 | const auto next_block_idx = (++ current_iter) * grid_size + block_idx; 83 | 84 | if constexpr (kGemmType == GemmType::GroupedMasked) { 85 | uint32_t num_m_blocks; 86 | while (true) { 87 | // End of the task 88 | if (curr_group_idx == kNumGroups) { 89 | m_block_idx = 0xFFFFFFFF; 90 | return false; 91 | } 92 | 93 | // Within current group 94 | num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); 95 | auto current_m_block_cumsum = curr_cumsum + num_m_blocks; 96 | if (next_block_idx < current_m_block_cumsum * kNumNBlocks) 97 | break; 98 | 99 | // Move to check the next group 100 | curr_group_idx ++, curr_cumsum = current_m_block_cumsum; 101 | } 102 | 103 | get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); 104 | } else { 105 | if (next_block_idx >= num_blocks) { 106 | m_block_idx = 0xFFFFFFFF; 107 | return false; 108 | } 109 | 110 | get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); 111 | } 112 | n_block_idx += n_block_offset; 113 | return true; 114 | } 115 | }; 116 | #pragma clang diagnostic pop 117 | 118 | } // namespace deep_gemm 119 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/tma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "utils.cuh" 11 | 12 | namespace deep_gemm { 13 | 14 | template 15 | constexpr CUtensorMapDataType get_CUtensorMapDataType() { 16 | if constexpr (std::is_same::value) { 17 | return CU_TENSOR_MAP_DATA_TYPE_UINT8; 18 | } else if constexpr (std::is_same::value) { 19 | return CU_TENSOR_MAP_DATA_TYPE_UINT8; 20 | } else if 
constexpr (std::is_same::value) { 21 | return CU_TENSOR_MAP_DATA_TYPE_UINT8; 22 | } else if constexpr (std::is_same::value) { 23 | return CU_TENSOR_MAP_DATA_TYPE_UINT16; 24 | } else if constexpr (std::is_same::value) { 25 | return CU_TENSOR_MAP_DATA_TYPE_UINT32; 26 | } else if constexpr (std::is_same::value) { 27 | return CU_TENSOR_MAP_DATA_TYPE_UINT64; 28 | } else if constexpr (std::is_same::value) { 29 | return CU_TENSOR_MAP_DATA_TYPE_INT32; 30 | } else if constexpr (std::is_same::value) { 31 | return CU_TENSOR_MAP_DATA_TYPE_INT64; 32 | } else if constexpr (std::is_same::value) { 33 | return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; 34 | } else if constexpr (std::is_same::value) { 35 | return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; 36 | } else if constexpr (std::is_same::value) { 37 | return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; 38 | } else if constexpr (std::is_same::value) { 39 | return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; 40 | } 41 | } 42 | 43 | PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { 44 | // Get pointer to `cuTensorMapEncodeTiled` 45 | cudaDriverEntryPointQueryResult driver_status; 46 | void* cuTensorMapEncodeTiled_ptr = nullptr; 47 | 48 | #if CUDA_VERSION >= 12050 49 | cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, 50 | cudaEnableDefault, &driver_status); 51 | #else 52 | cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 53 | cudaEnableDefault, &driver_status); 54 | #endif 55 | 56 | if (driver_status != cudaDriverEntryPointSuccess) 57 | throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); 58 | return reinterpret_cast(cuTensorMapEncodeTiled_ptr); 59 | } 60 | 61 | template 62 | CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], 63 | uint64_t stride_in_bytes, uint32_t smem_dim[2], 64 | CUtensorMapSwizzle swizzle_type, 65 | PFN_cuTensorMapEncodeTiled encode_func = nullptr) { 66 | CUtensorMap tensor_map{}; 67 | constexpr uint32_t rank = 2; 68 | uint64_t global_stride[rank - 1] = {stride_in_bytes}; 69 | uint32_t elem_strides[rank] = {1, 1}; 70 | 71 | if (encode_func == nullptr) 72 | encode_func = get_cuTensorMapEncodeTiled(); 73 | 74 | auto result = encode_func( 75 | &tensor_map, get_CUtensorMapDataType::type>(), rank, 76 | global_address, gmem_dim, global_stride, smem_dim, elem_strides, 77 | CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, 78 | CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, 79 | // No longer need 256B promotion for B since we load consecutive 4096B 80 | // TODO - do we want to add this back only for A & other tensors? 
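// (Illustrative sketch, not in the original code: 256B promotion could be
//  re-enabled per tensor by threading a flag through make_2d_tma_copy_desc, e.g.
//    auto promo = want_256b_promotion ? CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B
//                                     : CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE;
//  and passing `promo` here in place of the hard-coded NONE. `want_256b_promotion`
//  is a hypothetical parameter, not something this repository defines.)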
81 | //CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, 82 | CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 83 | DG_HOST_ASSERT(result == CUDA_SUCCESS); 84 | return tensor_map; 85 | } 86 | 87 | template 88 | CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3], 89 | uint64_t stride_in_bytes[2], uint32_t smem_dim[3], 90 | CUtensorMapSwizzle swizzle_type, 91 | PFN_cuTensorMapEncodeTiled encode_func = nullptr) { 92 | CUtensorMap tensor_map{}; 93 | constexpr uint32_t rank = 3; 94 | uint64_t global_stride[rank - 1] = {stride_in_bytes[0], stride_in_bytes[1]}; 95 | uint32_t elem_strides[rank] = {1, 1, 1}; 96 | 97 | if (encode_func == nullptr) 98 | encode_func = get_cuTensorMapEncodeTiled(); 99 | 100 | auto result = encode_func( 101 | &tensor_map, get_CUtensorMapDataType::type>(), rank, 102 | global_address, gmem_dim, global_stride, smem_dim, elem_strides, 103 | CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, 104 | CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, 105 | //CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, 106 | CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 107 | DG_HOST_ASSERT(result == CUDA_SUCCESS); 108 | return tensor_map; 109 | } 110 | 111 | template 112 | __device__ __forceinline__ void 113 | tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, 114 | int32_t const& crd_0, int32_t const& crd_1, uint32_t num_multicast = 1) { 115 | constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); 116 | if constexpr (!kTMAMulticastEnabled) { 117 | cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); 118 | } else { // need to check for rank_in_cluster() == 0 outside the function(/loop) 119 | cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); 120 | } 121 | } 122 | 123 | } // namespace deep_gemm 124 | -------------------------------------------------------------------------------- /deep_gemm/include/deep_gemm/utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #ifdef __CLION_IDE__ 6 | __host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) 
{ asm volatile("trap;"); } 7 | #define printf host_device_printf 8 | #endif 9 | 10 | class AssertionException : public std::exception { 11 | private: 12 | std::string message{}; 13 | 14 | public: 15 | explicit AssertionException(const std::string& message) : message(message) {} 16 | 17 | const char *what() const noexcept override { return message.c_str(); } 18 | }; 19 | 20 | #ifndef DG_HOST_ASSERT 21 | #define DG_HOST_ASSERT(cond) \ 22 | do { \ 23 | if (not (cond)) { \ 24 | printf("Assertion failed: %s:%d, condition: %s\n", \ 25 | __FILE__, __LINE__, #cond); \ 26 | throw AssertionException("Assertion failed: " #cond); \ 27 | } \ 28 | } while (0) 29 | #endif 30 | 31 | #ifndef DG_DEVICE_ASSERT 32 | #define DG_DEVICE_ASSERT(cond) \ 33 | do { \ 34 | if (not (cond)) { \ 35 | printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ 36 | asm("trap;"); \ 37 | } \ 38 | } while (0) 39 | #endif 40 | 41 | #ifndef DG_STATIC_ASSERT 42 | #define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) 43 | #endif 44 | 45 | template 46 | __device__ __host__ constexpr T ceil_div(T a, T b) { 47 | return (a + b - 1) / b; 48 | } 49 | 50 | __device__ void elect_or_exit() { 51 | // this helps the compiler not be silly (tested on 12.8) 52 | // threadIdx.x == constant should be enough but it isn't... :( 53 | asm volatile ( 54 | "{\n\t" 55 | ".reg .pred P;\n\t" 56 | "elect.sync _|P, 0xFFFFFFFF;\n\t" 57 | "@!P exit;\n\t" 58 | "}\n" :: ); 59 | } 60 | -------------------------------------------------------------------------------- /deep_gemm/jit/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import get_nvcc_compiler, build 2 | from .template import cpp_format, generate 3 | from .runtime import Runtime 4 | -------------------------------------------------------------------------------- /deep_gemm/jit/compiler.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import functools 3 | import os 4 | import re 5 | import subprocess 6 | import uuid 7 | from torch.utils.cpp_extension import CUDA_HOME 8 | from typing import Tuple 9 | 10 | from . 
import interleave_ffma 11 | from .runtime import Runtime, RuntimeCache 12 | from .template import typename_map 13 | 14 | runtime_cache = RuntimeCache() 15 | 16 | 17 | def hash_to_hex(s: str) -> str: 18 | md5 = hashlib.md5() 19 | md5.update(s.encode('utf-8')) 20 | return md5.hexdigest()[0:12] 21 | 22 | 23 | @functools.lru_cache(maxsize=None) 24 | def get_jit_include_dir() -> str: 25 | return f'{os.path.dirname(os.path.abspath(__file__))}/../include' 26 | 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_deep_gemm_version() -> str: 30 | # Update include directories 31 | include_dir = f'{get_jit_include_dir()}/deep_gemm' 32 | assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' 33 | md5 = hashlib.md5() 34 | for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): 35 | with open(f'{include_dir}/{filename}', 'rb') as f: 36 | md5.update(f.read()) 37 | 38 | # Update `interleave_ffma.py` 39 | with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f: 40 | md5.update(f.read()) 41 | return md5.hexdigest()[0:12] 42 | 43 | 44 | @functools.lru_cache(maxsize=None) 45 | def get_nvcc_compiler() -> Tuple[str, str]: 46 | paths = [] 47 | if os.getenv('DG_NVCC_COMPILER'): 48 | paths.append(os.getenv('DG_NVCC_COMPILER')) 49 | paths.append(f'{CUDA_HOME}/bin/nvcc') 50 | 51 | # Try to find the first available NVCC compiler 52 | least_version_required = '12.3' 53 | version_pattern = re.compile(r'release (\d+\.\d+)') 54 | for path in paths: 55 | if os.path.exists(path): 56 | match = version_pattern.search(os.popen(f'{path} --version').read()) 57 | version = match.group(1) 58 | assert match, f'Cannot get the version of NVCC compiler {path}' 59 | assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' 60 | return path, version 61 | raise RuntimeError('Cannot find any available NVCC compiler') 62 | 63 | 64 | @functools.lru_cache(maxsize=None) 65 | def get_default_user_dir(): 66 | if 'DG_CACHE_DIR' in os.environ: 67 | path = os.getenv('DG_CACHE_DIR') 68 | os.makedirs(path, exist_ok=True) 69 | return path 70 | return os.path.expanduser('~') + '/.deep_gemm' 71 | 72 | 73 | @functools.lru_cache(maxsize=None) 74 | def get_tmp_dir(): 75 | return f'{get_default_user_dir()}/tmp' 76 | 77 | 78 | @functools.lru_cache(maxsize=None) 79 | def get_cache_dir(): 80 | return f'{get_default_user_dir()}/cache' 81 | 82 | 83 | def make_tmp_dir(): 84 | tmp_dir = get_tmp_dir() 85 | os.makedirs(tmp_dir, exist_ok=True) 86 | return tmp_dir 87 | 88 | 89 | def put(path, data, is_binary=False): 90 | # Write and do POSIX atomic replace 91 | tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' 92 | with open(tmp_file_path, 'wb' if is_binary else 'w') as f: 93 | f.write(data) 94 | os.replace(tmp_file_path, path) 95 | 96 | 97 | def build(name: str, arg_defs: tuple, code: str) -> Runtime: 98 | # Compiler flags 99 | nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', 100 | '-gencode=arch=compute_90a,code=sm_90a', 101 | '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), 102 | # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases 103 | '--diag-suppress=177,174,940'] 104 | cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi'] 105 | flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] 106 
| include_dirs = [get_jit_include_dir()] 107 | 108 | # Build signature 109 | enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 110 | signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' 111 | name = f'kernel.{name}.{hash_to_hex(signature)}' 112 | path = f'{get_cache_dir()}/{name}' 113 | 114 | # Check runtime cache or file system hit 115 | global runtime_cache 116 | if runtime_cache[path] is not None: 117 | if os.getenv('DG_JIT_DEBUG', None): 118 | print(f'Using cached JIT runtime {name} during build') 119 | return runtime_cache[path] 120 | 121 | # Write the code 122 | os.makedirs(path, exist_ok=True) 123 | args_path = f'{path}/kernel.args' 124 | src_path = f'{path}/kernel.cu' 125 | put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) 126 | put(src_path, code) 127 | 128 | # Compile into a temporary SO file 129 | so_path = f'{path}/kernel.so' 130 | tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so' 131 | 132 | # Compile 133 | command = [get_nvcc_compiler()[0], 134 | src_path, '-o', tmp_so_path, 135 | *flags, 136 | *[f'-I{d}' for d in include_dirs]] 137 | if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): 138 | print(f'Compiling JIT runtime {name} with command {command}') 139 | return_code = subprocess.check_call(command) 140 | assert return_code == 0, f'Failed to compile {src_path}' 141 | 142 | # Interleave FFMA reuse 143 | if enable_sass_opt: 144 | interleave_ffma.process(tmp_so_path) 145 | 146 | # Atomic replace SO file 147 | os.replace(tmp_so_path, so_path) 148 | 149 | # Put cache and return 150 | runtime_cache[path] = Runtime(path) 151 | return runtime_cache[path] 152 | -------------------------------------------------------------------------------- /deep_gemm/jit/interleave_ffma.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mmap 3 | import os 4 | import re 5 | import subprocess 6 | from torch.utils.cpp_extension import CUDA_HOME 7 | 8 | 9 | def run_cuobjdump(file_path): 10 | command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path] 11 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 12 | assert result.returncode == 0 13 | return result.stdout 14 | 15 | 16 | def extract_ffma(sass): 17 | lines = sass.splitlines() 18 | collected = [] 19 | current = [] 20 | 21 | arch_name, func_name = 'N/A', 'N/A' 22 | skip_next_line = False 23 | for line in lines: 24 | if 'code for' in line: 25 | arch_name = line.lstrip().lstrip('code for ').rstrip() 26 | elif 'Function :' in line: 27 | func_name = line.lstrip().lstrip('Function :').rstrip() 28 | elif 'FFMA' in line: 29 | current.append(line) 30 | skip_next_line = True 31 | elif skip_next_line: 32 | current.append(line) 33 | skip_next_line = False 34 | else: 35 | if len(current) >= 16: 36 | assert len(current) % 2 == 0 37 | collected.append((f'{arch_name}::{func_name}', current)) 38 | current = [] 39 | 40 | if os.getenv('DG_PRINT_REG_REUSE', None): 41 | print(f'Found {len(collected)} FFMA segments') 42 | return collected 43 | 44 | 45 | def extract_hex_from_line(line): 46 | match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line) 47 | assert match 48 | return int(match.group(1), 16) 49 | 50 | 51 | def validate(m, offset, le_bytes, num_lines): 52 | assert len(le_bytes) == num_lines // 2 53 | assert m[offset:offset 
+ 16] == le_bytes[0] 54 | for i in range(1, num_lines // 2): 55 | if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]: 56 | return False 57 | return True 58 | 59 | 60 | def parse_registers(line): 61 | line = re.sub(r'/\*.*?\*/', '', line) 62 | line = line.replace(';', '') 63 | tokens = line.strip().split(',') 64 | registers = [] 65 | for token in tokens: 66 | token = token.strip() 67 | words = token.split() 68 | for word in words: 69 | if word.startswith('R'): 70 | reg = word.split('.')[0] 71 | registers.append(reg) 72 | return registers 73 | 74 | 75 | def modify_segment(yield_every_n, m, name, ffma_lines): 76 | num_lines = len(ffma_lines) 77 | assert num_lines % 2 == 0 78 | 79 | le_bytes, new_le_bytes = [], [] 80 | reused_list = [] 81 | dst_reg_set = set() 82 | last_reused, last_dst_reg = False, '' 83 | num_changed = 0 84 | ffma_since_yield = 0 85 | for i in range(num_lines // 2): 86 | dst_reg = parse_registers(ffma_lines[i * 2])[-2] 87 | low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] 88 | low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) 89 | le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 90 | reused = (high_hex & 0x0800000000000000) != 0 91 | 92 | yielding = (high_hex & 0x0800200000000000) == 0 93 | ffma_since_yield = (not yielding) and ffma_since_yield + 1 or 0 94 | 95 | if reused: 96 | is_first_occurred = dst_reg not in dst_reg_set 97 | if (is_first_occurred or (last_reused and dst_reg == last_dst_reg)) and ffma_since_yield >= yield_every_n: 98 | # Modify the `reuse` and `yield` bits 99 | assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' 100 | high_hex ^= 0x0800200000000000 101 | reused = False 102 | num_changed += 1 103 | ffma_since_yield = 0 104 | else: 105 | reused_list.append(i) 106 | dst_reg_set.add(dst_reg) 107 | new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) 108 | last_reused, last_dst_reg = reused, dst_reg 109 | if os.getenv('DG_PRINT_REG_REUSE', None): 110 | print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') 111 | # Find the offset 112 | offsets = [] 113 | offset = m.find(le_bytes[0]) 114 | while offset != -1: 115 | offsets.append(offset) 116 | offset = m.find(le_bytes[0], offset + 1) 117 | offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) 118 | 119 | # Replace with `new_le_bytes` 120 | for offset in offsets: 121 | for i in range(num_lines // 2): 122 | m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i] 123 | 124 | 125 | def process(path): 126 | # New default of 1 reuse every 4 FFMAs (previously 1 every 2 FFMAs) 127 | # DeeperGEMM is less latency sensitive than the original :) 128 | # So it's less important to prioritize the producer warps by yielding in the consumers 129 | # Increasing the amount of reuse will very slightly improve power efficiency, it's a trade-off 130 | # (in reality, this makes ~zero difference, but the previous default bothered me intellectually) 131 | yield_every_n = os.getenv('DG_FFMA_YIELD_EVERY_N_INSTRUCTIONS', 4) 132 | if os.getenv('DG_PRINT_REG_REUSE', None): 133 | print(f'Processing {path}') 134 | output = run_cuobjdump(path) 135 | segments = extract_ffma(output) 136 | with open(path, 'r+b') as f: 137 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) 138 | for segment in segments: 139 | modify_segment(yield_every_n, mm, *segment) 140 | mm.close() 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = 
argparse.ArgumentParser(description='Interleave FFMA reg reuse') 145 | parser.add_argument('--so', help='Path to the SO file') 146 | args = parser.parse_args() 147 | 148 | process(args.so) 149 | -------------------------------------------------------------------------------- /deep_gemm/jit/runtime.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | import torch 4 | from typing import Optional 5 | 6 | from .template import map_ctype 7 | 8 | 9 | class Runtime: 10 | def __init__(self, path: str) -> None: 11 | self.path = path 12 | self.lib = None 13 | self.args = None 14 | 15 | assert self.is_path_valid(self.path) 16 | 17 | @staticmethod 18 | def is_path_valid(path: str) -> bool: 19 | # Exists and is a directory 20 | if not os.path.exists(path) or not os.path.isdir(path): 21 | return False 22 | 23 | # Contains all necessary files 24 | files = ['kernel.cu', 'kernel.args', 'kernel.so'] 25 | return all(os.path.exists(os.path.join(path, file)) for file in files) 26 | 27 | def __call__(self, *args) -> int: 28 | # Load SO file 29 | if self.lib is None or self.args is None: 30 | self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so')) 31 | with open(os.path.join(self.path, 'kernel.args'), 'r') as f: 32 | self.args = eval(f.read()) 33 | 34 | # Check args and launch 35 | assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' 36 | cargs = [] 37 | for arg, (name, dtype) in zip(args, self.args): 38 | if isinstance(arg, torch.Tensor): 39 | assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' 40 | else: 41 | assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' 42 | cargs.append(map_ctype(arg)) 43 | 44 | return_code = ctypes.c_int(0) 45 | self.lib.launch(*cargs, ctypes.byref(return_code)) 46 | return return_code.value 47 | 48 | 49 | class RuntimeCache: 50 | def __init__(self) -> None: 51 | self.cache = {} 52 | 53 | def __getitem__(self, path: str) -> Optional[Runtime]: 54 | # In Python runtime 55 | if path in self.cache: 56 | return self.cache[path] 57 | 58 | # Already compiled 59 | if os.path.exists(path) and Runtime.is_path_valid(path): 60 | runtime = Runtime(path) 61 | self.cache[path] = runtime 62 | return runtime 63 | return None 64 | 65 | def __setitem__(self, path, runtime) -> None: 66 | self.cache[path] = runtime 67 | -------------------------------------------------------------------------------- /deep_gemm/jit/template.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import ctypes 3 | import os 4 | import torch 5 | 6 | from typing import Any, Iterable, Dict, Tuple 7 | 8 | 9 | # Name map for Python `eval` 10 | typename_map: Dict[Any, str] = { 11 | **{t: t.__name__ for t in (bool, int, float)}, 12 | torch.int: 'torch.int', 13 | torch.float: 'torch.float', 14 | torch.bfloat16: 'torch.bfloat16', 15 | torch.float8_e4m3fn: 'torch.float8_e4m3fn', 16 | torch.uint8: 'torch.uint8', 17 | torch.cuda.Stream: 'torch.cuda.Stream', 18 | } 19 | 20 | # `ctype` map for Python casting 21 | ctype_map: Dict[Any, Any] = { 22 | **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, 23 | **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream, torch.uint8)}, 24 | } 25 | 26 | 27 | # Type map for both Python API and source code usages 28 | genc_map = { 29 | bool: ('bool', 'bool'), 30 | int: ('int', 'int'), 
31 | float: ('float', 'float'), 32 | torch.int: ('void*', 'int*'), 33 | torch.float: ('void*', 'float*'), 34 | torch.bfloat16: ('void*', '__nv_bfloat16*'), 35 | torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), 36 | torch.uint8: ('void*', 'unsigned char*'), 37 | torch.cuda.Stream: ('void*', 'cudaStream_t'), 38 | } 39 | 40 | 41 | def map_ctype(value: Any) -> Any: 42 | ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)] 43 | if isinstance(value, torch.Tensor): 44 | return ctype(value.data_ptr()) 45 | if isinstance(value, torch.cuda.Stream): 46 | return ctype(value.cuda_stream) 47 | return ctype(value) 48 | 49 | 50 | def cpp_format(template: str, keys: Dict[str, Any]) -> str: 51 | # We don't use `str.format` because it's not safe for C++ {} braces 52 | new_template = copy.deepcopy(template) 53 | for key, value in keys.items(): 54 | new_template = new_template.replace(f'{{{key}}}', f'{value}') 55 | return new_template 56 | 57 | 58 | def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: 59 | # Common prefix 60 | code = '// DeepGEMM auto-generated JIT CUDA source file\n\n' 61 | 62 | # Includes 63 | preload_sys_includes = ['', '', '', ''] 64 | preload_package_includes = ['"cutlass/cutlass.h"'] 65 | 66 | assert isinstance(includes, list) or isinstance(includes, tuple) 67 | sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')]))) 68 | package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')]))) 69 | code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n' 70 | code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n' 71 | 72 | # Function signature 73 | raw = '__raw_' 74 | get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n 75 | code += f'extern "C" void launch(' 76 | code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ]) 77 | code += ') {\n' 78 | 79 | # Cast raw types 80 | code += ' // Cast raw types (if needed)\n' 81 | for arg_name, arg_type in arg_defs: 82 | if genc_map[arg_type][0] != genc_map[arg_type][1]: 83 | code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n' 84 | 85 | # Function body 86 | code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')]) 87 | 88 | # End the function 89 | code += '}\n\n' 90 | 91 | # Debug print 92 | if os.getenv('DG_JIT_DEBUG', None): 93 | print(f'Generated code:\n{code}') 94 | 95 | return code 96 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemm import gemm_fp8_fp8_bf16_nt 2 | from .m_grouped_gemm import ( 3 | m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, 4 | m_grouped_gemm_fp8_fp8_bf16_nt_masked 5 | ) 6 | from .utils import ( 7 | ceil_div, set_num_sms, get_num_sms, 8 | get_col_major_tma_aligned_tensor, 9 | get_m_alignment_for_contiguous_layout 10 | ) 11 | from .preprocess import ( 12 | preprocess_reorder_b, 13 | preprocess_reorder_b_grouped 14 | ) 15 | from .sideaware import ( 16 | sideaware_init, sideaware_enabled, sideaware_create_kernel, 17 | sideaware_torch_side_index, sideaware_gpu_side_index, sideaware_cpu_side_index, 18 | sideaware_info, sideaware_info_raw 19 | ) 
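# Illustrative usage sketch (not part of the package API surface): the shapes, the FP8 casts,
# and the decision to reorder B before the GEMM are assumptions for demonstration only.
#   import torch
#   from deep_gemm.jit_kernels import gemm_fp8_fp8_bf16_nt, preprocess_reorder_b
#   m, n, k = 4096, 4096, 7168
#   lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
#   lhs_scales = torch.rand(m, (k + 127) // 128, device='cuda', dtype=torch.float32)
#   rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
#   rhs_scales = torch.rand((n + 127) // 128, (k + 127) // 128, device='cuda', dtype=torch.float32)
#   reordered_rhs = torch.empty_like(rhs)
#   preprocess_reorder_b(rhs, reordered_rhs)  # smem-layout / L2-side-aware reorder of B (when enabled)
#   out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
#   gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (reordered_rhs, rhs_scales), out)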
-------------------------------------------------------------------------------- /deep_gemm/jit_kernels/gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from .tuner import jit_tuner 5 | from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout 6 | from .sideaware import sideaware_torch_side_index, sideaware_info, sideaware_enabled 7 | 8 | # C++ code templates 9 | includes = ('"deep_gemm/fp8_gemm.cuh"', ) 10 | template = """ 11 | using namespace deep_gemm; 12 | 13 | // Templated args from Python JIT call 14 | constexpr auto N = {N}, K = {K}; 15 | constexpr auto BLOCK_M = {BLOCK_M}; 16 | constexpr auto BLOCK_N = {BLOCK_N}; 17 | constexpr auto kNumStages = {NUM_STAGES}; 18 | constexpr auto kNumUnroll = {NUM_UNROLL}; 19 | constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; 20 | constexpr auto kL2HashBits = {L2_HASH_BITS}; 21 | constexpr auto kL2Optimization = {L2_OPTIMIZATION}; 22 | 23 | // Make a templated GEMM 24 | using GemmType = Gemm; 25 | 26 | // Launch kernel 27 | auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); 28 | auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); 29 | auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); 30 | auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); 31 | auto tma_d_padded_desc = GemmType::make_3d_tma_d_desc(out, m); 32 | GemmType::run(out, rhs, rhs_scales, nullptr, 33 | m, 34 | tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_d_padded_desc, 35 | stream, num_sms, smem_size, side_index); 36 | """ 37 | 38 | 39 | def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: 40 | if num_tma_multicast == 1: 41 | return True 42 | return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 43 | 44 | 45 | def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: 46 | smem_d = block_m * block_n * 2 47 | smem_a_per_stage = block_m * block_k 48 | smem_scales_a_per_stage = block_m * 4 49 | smem_b_per_stage = block_n * block_k 50 | smem_scales_b = ceil_div(k, block_k) * 4 * 2 51 | smem_barrier = (num_stages + 2) * 8 * 2 52 | 53 | # D reuses AB pipeline stage but potentially larger due to padding to avoid bank conflicts 54 | assert(smem_d % 128 == 0) # 128 because of other alignment considerations elsewhere 55 | PADDING_N = (block_n == 64 or block_n == 96 or block_n == 128) and 16 or 0 56 | smem_d_padded = block_m * (block_n + PADDING_N) * 2 57 | smem_ab_per_stage = max(smem_a_per_stage + smem_b_per_stage, smem_d_padded) 58 | 59 | smem_size = 0 60 | # we reuse one stage of A+B to store D instead of dedicated storage 61 | # smem_size += smem_d 62 | assert smem_d <= (smem_a_per_stage + smem_b_per_stage) 63 | smem_size += num_stages * smem_ab_per_stage 64 | smem_size += num_stages * smem_scales_a_per_stage 65 | smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 66 | smem_size += smem_barrier 67 | smem_size += 4096 # scratch for tile scheduling etc. 
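    # Worked example (hypothetical configuration, not a tuned recommendation):
    # with block_m=128, block_n=128, block_k=128, k=7168 and num_stages=5:
    #   smem_d_padded     = 128 * (128 + 16) * 2      = 36864   (PADDING_N == 16)
    #   smem_ab_per_stage = max(16384 + 16384, 36864) = 36864
    #   smem_size         = 5 * 36864 + 5 * 512 + 448 + 112 + 4096 = 191536 bytes
    # which fits within the 232448-byte SM90 shared memory capacity checked by the caller.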
68 | return smem_size 69 | 70 | 71 | def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, 72 | is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]: 73 | if not is_grouped_contiguous: 74 | # TODO: for some cases, smaller M block is better, add them into tuning space 75 | block_ms = (64 if m <= 64 else 128, ) 76 | else: 77 | block_ms = (get_m_alignment_for_contiguous_layout(), ) 78 | # TODO: add back other sizes if we switch to TMA tensor loads instead, multiples of 8 are slow in current approach 79 | block_ns = tuple((16, 32, 48, 64, 96, 128)) #range(16, 129, 8)) 80 | 81 | fix_wave_saturate = lambda x: num_sms if x == 0 else x 82 | get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) 83 | get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) 84 | 85 | # Decide block sizes by waves 86 | best_block_m, best_block_n = None, None 87 | for block_m in block_ms: 88 | for block_n in block_ns: 89 | success = False 90 | num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) 91 | if best_block_m is None or best_block_n is None: 92 | success = True 93 | elif num_waves < best_num_waves: 94 | success = True 95 | elif num_waves == best_num_waves: 96 | # Check last wave utilization 97 | util = get_last_wave_util(block_m, block_n) 98 | best_util = get_last_wave_util(best_block_m, best_block_n) 99 | success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) 100 | best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) 101 | assert best_block_m is not None and best_block_n is not None 102 | 103 | # Always pick the longest one 104 | # NOTES: for double B scales, the best number of stages may be reduced 105 | best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 106 | # TODO: what sizes to check? 107 | for num_stages in (9, 8, 7, 6, 5, 4) if 128 % best_block_n != 0 else (12, 10, 8, 7, 6, 5, 4): 108 | best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) 109 | if best_smem_size <= sm90_capacity: 110 | best_num_stages = num_stages 111 | break 112 | assert best_num_stages is not None 113 | 114 | # Decide the number of TMA multicast 115 | best_num_tma_multicast = 1 116 | if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and (num_groups == 1 or is_grouped_contiguous): 117 | best_num_tma_multicast = 2 118 | 119 | 120 | if False: 121 | print(f"best_block_m: {best_block_m}, best_block_n: {best_block_n}, best_num_stages: {best_num_stages}," 122 | f"best_smem_size: {best_smem_size}, best_num_tma_multicast: {best_num_tma_multicast}, m: {m}, n: {n}, k: {k}" 123 | f"==> Waves: {get_num_waves(best_block_m, best_block_n)}") 124 | 125 | 126 | return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size 127 | 128 | 129 | def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 130 | rhs: Tuple[torch.Tensor, torch.Tensor], 131 | out: torch.Tensor) -> None: 132 | """ 133 | Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 134 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 135 | RHS and RHS scaling factors are required to be transposed. 
136 | The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, 137 | this function will do a transposing with a set of slow PyTorch operations. 138 | 139 | Arguments: 140 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, 141 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. 142 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. 143 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. 144 | out: the BF16 output tensor of shape `[m, n]`, representing the result. 145 | """ 146 | lhs, lhs_scales = lhs 147 | rhs, rhs_scales = rhs 148 | m, k = lhs.shape 149 | n, k_ = rhs.shape 150 | m_, n_ = out.shape 151 | 152 | assert n % 64 == 0 and k % 128 == 0 153 | 154 | # Type and shape checks 155 | assert m == m_ and n == n_ and k == k_ 156 | assert n > 0 and k > 0 157 | assert lhs_scales.shape == (m, (k + 127) // 128) 158 | assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) 159 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 160 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 161 | assert out.dtype == torch.bfloat16 162 | assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() 163 | 164 | # LHS scales must be transposed for TMA load, but not for RHS scales 165 | # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels 166 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 167 | assert rhs_scales.is_contiguous() 168 | 169 | # Do nothing if `m` is zero 170 | if m == 0: 171 | return 172 | 173 | # Auto-tuning with compilation 174 | global includes, template 175 | num_sms = get_num_sms() 176 | block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) 177 | args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size, 178 | sideaware_torch_side_index()) 179 | runtime = jit_tuner.compile_and_tune( 180 | name='gemm_fp8_fp8_bf16_nt', 181 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 182 | 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 183 | 'L2_HASH_BITS': sideaware_info()["hash"], 'L2_OPTIMIZATION': sideaware_enabled()}, 184 | space=(), 185 | includes=includes, 186 | arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), 187 | ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), 188 | ('out', torch.bfloat16), ('m', int), 189 | ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int), 190 | ('side_index', torch.uint8)), 191 | template=template, 192 | args=args 193 | ) 194 | 195 | # Run the kernel 196 | runtime(*args) 197 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/m_grouped_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from .gemm import get_best_configs 5 | from .tuner import jit_tuner 6 | from .utils import get_col_major_tma_aligned_tensor, get_num_sms 7 | from .sideaware import sideaware_torch_side_index, sideaware_info, sideaware_enabled 8 | 9 | # C++ code templates 10 | includes = ('"deep_gemm/fp8_gemm.cuh"', ) 11 | template = """ 12 | using namespace deep_gemm; 13 | 14 | // Templated args from Python JIT call 15 | constexpr auto N = {N}, K = {K}; 16 | 
constexpr auto BLOCK_M = {BLOCK_M}; 17 | constexpr auto BLOCK_N = {BLOCK_N}; 18 | constexpr auto kNumStages = {NUM_STAGES}; 19 | constexpr auto kNumUnroll = {NUM_UNROLL}; 20 | constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; 21 | constexpr auto kL2HashBits = {L2_HASH_BITS}; 22 | constexpr auto kL2Optimization = {L2_OPTIMIZATION}; 23 | 24 | // Make a templated grouped GEMM 25 | using GemmType = Gemm; 26 | 27 | // Launch kernel 28 | auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); 29 | auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); 30 | auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); 31 | auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); 32 | auto tma_d_padded_desc = GemmType::make_3d_tma_d_desc(out, m); 33 | GemmType::run(out, rhs, rhs_scales, grouped_layout, 34 | m, 35 | tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_d_padded_desc, 36 | stream, num_sms, smem_size, side_index); 37 | """ 38 | 39 | 40 | def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], 41 | rhs: Tuple[torch.Tensor, torch.Tensor], 42 | out: torch.Tensor, m_indices: torch.Tensor) -> None: 43 | """ 44 | Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 45 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 46 | RHS and RHS scaling factors are required to be transposed. 47 | The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, 48 | this function will do a transposing with a set of slow PyTorch operations. 49 | On the M axis, inputs are grouped into several batches, of which batch sizes aligned to 50 | `get_m_alignment_for_contiguous_layout()` (128). 51 | 52 | Arguments: 53 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, 54 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. 55 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. 56 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 57 | out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. 58 | m_indices: a tensor of shape `[m_sum]` with type `torch.int`. 59 | `m_indices[i]` records the group which the j-th row of the LHS belong to, 60 | which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. 61 | Values of `m_indices` in every-m-alignment-block must also be the same. 62 | `-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block. 
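    Example (illustrative): with `num_groups = 2` and the 128-row alignment above,
        `m_indices = [0] * 128 + [1] * 256 + [-1] * 128` makes the first 128 rows of the LHS use `rhs[0]`,
        the next 256 rows use `rhs[1]`, and skips the final 128-row block entirely.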
63 | """ 64 | lhs, lhs_scales = lhs 65 | rhs, rhs_scales = rhs 66 | m, k = lhs.shape 67 | num_groups, n, k_ = rhs.shape 68 | m_, n_ = out.shape 69 | m__ = m_indices.numel() 70 | 71 | # Type and shape checks 72 | assert m == m_ == m__ and k == k_ and n == n_ 73 | assert lhs_scales.shape == (m, (k + 127) // 128) 74 | assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) 75 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 76 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 77 | assert out.dtype == torch.bfloat16 78 | assert m_indices.dtype == torch.int32 79 | assert lhs.is_contiguous() and rhs.is_contiguous() 80 | assert out.is_contiguous() and m_indices.is_contiguous() 81 | 82 | # LHS scales must be transposed for TMA load, but not for RHS scales 83 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 84 | assert rhs_scales.is_contiguous() 85 | 86 | # Do nothing if `m` is zero 87 | if m == 0: 88 | return 89 | 90 | # Auto-tuning with compilation 91 | global includes, template 92 | num_sms = get_num_sms() 93 | block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, 94 | is_grouped_contiguous=True) 95 | args = (lhs, lhs_scales, rhs, rhs_scales, out, 96 | m_indices, m, num_groups, 97 | torch.cuda.current_stream(), num_sms, smem_size, sideaware_torch_side_index()) 98 | runtime = jit_tuner.compile_and_tune( 99 | name='m_grouped_gemm_fp8_fp8_bf16_nt', 100 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 101 | 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous', 102 | 'L2_HASH_BITS': sideaware_info()["hash"], 'L2_OPTIMIZATION': sideaware_enabled()}, 103 | space=(), 104 | includes=includes, 105 | arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), 106 | ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), 107 | ('out', torch.bfloat16), 108 | ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), 109 | ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int), ('side_index', torch.uint8)), 110 | template=template, 111 | args=args 112 | ) 113 | 114 | # Run the kernel 115 | runtime(*args) 116 | 117 | 118 | def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], 119 | rhs: Tuple[torch.Tensor, torch.Tensor], 120 | out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: 121 | """ 122 | Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. 123 | LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. 124 | RHS and RHS scaling factors are required to be transposed. 125 | The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, 126 | this function will do a transposing with a set of slow PyTorch operations. 127 | Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch 128 | should be separately transposed. 129 | 130 | Arguments: 131 | lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, 132 | the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. 133 | rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. 
134 | the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. 135 | out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. 136 | masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute 137 | in the i-th group. 138 | expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, 139 | correctly setting this value may lead to better performance. 140 | """ 141 | lhs, lhs_scales = lhs 142 | rhs, rhs_scales = rhs 143 | num_groups, m, k = lhs.shape 144 | num_groups_, n, k_ = rhs.shape 145 | num_groups__, m_, n_ = out.shape 146 | num_groups___ = masked_m.numel() 147 | 148 | # Type and shape checks 149 | assert num_groups == num_groups_ == num_groups__ == num_groups___ 150 | assert m == m_ and n == n_ and k == k_ 151 | assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 152 | assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) 153 | assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) 154 | assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 155 | assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 156 | assert out.dtype == torch.bfloat16 157 | assert masked_m.dtype == torch.int32 158 | assert lhs.is_contiguous() and rhs.is_contiguous() 159 | assert out.is_contiguous() and masked_m.is_contiguous() 160 | 161 | # LHS scales must be transposed for TMA load, but not for RHS scales 162 | lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) 163 | assert rhs_scales.is_contiguous() 164 | 165 | # Auto-tuning with compilation 166 | global includes, template 167 | num_sms = get_num_sms() 168 | block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) 169 | 170 | # Extra checks for TMA store 171 | if num_groups > 1 and m > block_m: 172 | assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' 173 | 174 | args = (lhs, lhs_scales, rhs, rhs_scales, out, 175 | masked_m, m, 176 | torch.cuda.current_stream(), num_sms, smem_size, sideaware_torch_side_index()) 177 | runtime = jit_tuner.compile_and_tune( 178 | name='m_grouped_gemm_fp8_fp8_bf16_nt', 179 | keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 180 | 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked', 181 | 'L2_HASH_BITS': sideaware_info()["hash"], 'L2_OPTIMIZATION': sideaware_enabled()}, 182 | space=(), 183 | includes=includes, 184 | arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), 185 | ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), 186 | ('out', torch.bfloat16), 187 | ('grouped_layout', torch.int32), ('m', int), 188 | ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int), ('side_index', torch.uint8)), 189 | template=template, 190 | args=args 191 | ) 192 | 193 | # Run the kernel 194 | runtime(*args) 195 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from .tuner import jit_tuner 5 | from .utils import get_num_sms, ceil_div 6 | from .sideaware import sideaware_info, sideaware_enabled 7 | 8 | # C++ code templates 9 | includes = 
('"deep_gemm/reorder_b.cuh"', ) 10 | template = """ 11 | using namespace deep_gemm; 12 | 13 | // Templated args from Python JIT call 14 | constexpr auto kL2HashBits = {L2_HASH_BITS}; 15 | constexpr auto kL2Optimization = {L2_OPTIMIZATION}; 16 | 17 | // Make a templated type 18 | using ReorderType = ReorderB<128, GemmType::{GEMM_TYPE}, kL2HashBits, kL2Optimization>; 19 | 20 | // Launch kernel 21 | auto tma_b_desc = ReorderType::make_2d_tma_b_desc(b, n, k, num_groups); 22 | ReorderType::run(out_b, b, nullptr, tma_b_desc, n, k, stream, num_sms, num_groups); 23 | """ 24 | 25 | def preprocess_reorder_b(b: torch.Tensor, out_b: torch.Tensor) -> None: 26 | """ 27 | Reorder B tensor to match the smem layout and be L2 side aware(!) 28 | """ 29 | n, k = b.shape 30 | assert b.shape == out_b.shape 31 | assert b.is_contiguous() and out_b.is_contiguous() 32 | assert b.dtype == torch.float8_e4m3fn and out_b.dtype == torch.float8_e4m3fn 33 | assert k % 128 == 0 34 | 35 | # Auto-tuning with compilation 36 | global includes, template 37 | num_sms = get_num_sms() 38 | args = (b, out_b, n, k, torch.cuda.current_stream(), num_sms, 1) 39 | runtime = jit_tuner.compile_and_tune( 40 | name='reorder_b', 41 | keys={'BLOCK_K': 128, 'GEMM_TYPE': 'Normal', 42 | 'L2_HASH_BITS': sideaware_info()["hash"], 'L2_OPTIMIZATION': sideaware_enabled()}, 43 | space=(), 44 | includes=includes, 45 | arg_defs=(('b', torch.float8_e4m3fn), ('out_b', torch.float8_e4m3fn), 46 | ('n', int), ('k', int), ('stream', torch.cuda.Stream), ('num_sms', int), ('num_groups', int)), 47 | template=template, 48 | args=args 49 | ) 50 | 51 | # Run the kernel 52 | runtime(*args) 53 | 54 | def preprocess_reorder_b_grouped(b: torch.Tensor, out_b: torch.Tensor, is_masked: bool) -> None: 55 | """ 56 | Reorder B tensor to match the smem layout and be L2 side aware(!) 
57 | """ 58 | num_groups, n, k = b.shape 59 | assert b.shape == out_b.shape 60 | assert b.is_contiguous() and out_b.is_contiguous() 61 | assert b.dtype == torch.float8_e4m3fn and out_b.dtype == torch.float8_e4m3fn 62 | assert k % 128 == 0 63 | 64 | # Auto-tuning with compilation 65 | global includes, template 66 | num_sms = get_num_sms() 67 | args = (b, out_b, n, k, torch.cuda.current_stream(), num_sms, num_groups) 68 | runtime = jit_tuner.compile_and_tune( 69 | name='reorder_b', 70 | keys={'BLOCK_K': 128, 'GEMM_TYPE': 'GroupedContiguous' if not is_masked else 'GroupedMasked', 71 | 'L2_HASH_BITS': sideaware_info()["hash"], 'L2_OPTIMIZATION': sideaware_enabled()}, 72 | space=(), 73 | includes=includes, 74 | arg_defs=(('b', torch.float8_e4m3fn), ('out_b', torch.float8_e4m3fn), 75 | ('n', int), ('k', int), ('stream', torch.cuda.Stream), ('num_sms', int), ('num_groups', int)), 76 | template=template, 77 | args=args 78 | ) 79 | 80 | # Run the kernel 81 | runtime(*args) 82 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/sideaware.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # CUDA L2 Side Boost wrapper originally created for DeeperGEMM 3 | # Useful as an example of how it can be integrated with PyTorch 4 | # --------------------------------------------------------------------------- 5 | # Remember to compile sideaware.so in /deep_gemm/include/l2_torch_alloc: 6 | # nvcc -arch=native -Xcompiler -fPIC -shared sideaware.cu -o sideaware.so -lcuda -lnvrtc 7 | # --------------------------------------------------------------------------- 8 | # https://github.com/ademeure/cuda-side-boost 9 | # https://github.com/ademeure/DeeperGEMM 10 | # --------------------------------------------------------------------------- 11 | import torch 12 | import ctypes 13 | 14 | _lib = None 15 | _cpu_side_index = None 16 | _gpu_side_index = None 17 | _torch_side_index = None 18 | _info = (0, 0, 0, 0, 0) 19 | _info_str = {"num_sms": 0, "side0": 0, "side1": 0, "min": 0, "hash": 0} 20 | 21 | # ----------------------------------------------------------------------------- 22 | # Externally visible API (+torch.ops.sideaware.[memcpy/one_to_one/elementwise]) 23 | # ----------------------------------------------------------------------------- 24 | def sideaware_enabled(): 25 | return _lib and 1 or 0 26 | 27 | # Create custom elementwise kernels (returns id for sideaware_elementwise) 28 | def sideaware_create_kernel(header_code: bytes) -> int: 29 | return _lib.sideaware_create_kernel(header_code) 30 | 31 | # GPU SM side metadata (which SM is on which side, SMs per side, etc...) 
32 | def sideaware_torch_side_index(): 33 | global _torch_side_index 34 | if _torch_side_index is None: 35 | _torch_side_index = torch.zeros(1, dtype=torch.uint8, device="cuda") 36 | return _torch_side_index # torch.uint8 tensor of size num_sms 37 | def sideaware_gpu_side_index(): 38 | return _gpu_side_index # gpu buffer of size num_sms 39 | def sideaware_cpu_side_index(): 40 | return _cpu_side_index # cpu buffer of size num_sms 41 | def sideaware_info(): 42 | return _info_str # {"num_sms", "side0", "side1", "min", "hash"} 43 | def sideaware_info_raw(): 44 | return _info # (num_sms, side0, side1, min, hash) 45 | 46 | 47 | # Load sideaware.so library both directly and through CUDAPluggableAllocator 48 | def sideaware_init(path = 'sideaware.so'): 49 | sideaware_alloc = torch.cuda.memory.CUDAPluggableAllocator(path, 'sideaware_malloc_auto', 'sideaware_free_auto') 50 | torch.cuda.memory.change_current_allocator(sideaware_alloc) 51 | 52 | global _lib 53 | _lib = ctypes.CDLL(path) 54 | 55 | # Define C-style function signatures 56 | _lib.sideaware_create_kernel.argtypes = [ctypes.c_char_p] 57 | _lib.sideaware_create_kernel.restype = ctypes.c_int 58 | 59 | _lib.sideaware_sm_side_summary.argtypes = [] 60 | _lib.sideaware_sm_side_summary.restype = ctypes.POINTER(ctypes.c_int * 5) 61 | 62 | _lib.sideaware_fill_side_index.argtypes = [ctypes.c_void_p] 63 | _lib.sideaware_fill_side_index.restype = None 64 | _lib.sideaware_gpu_side_index.argtypes = [] 65 | _lib.sideaware_gpu_side_index.restype = ctypes.c_void_p 66 | _lib.sideaware_cpu_side_index.argtypes = [] 67 | _lib.sideaware_cpu_side_index.restype = ctypes.c_void_p 68 | 69 | _lib.sideaware_memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_void_p] 70 | _lib.sideaware_memcpy.restype = None 71 | _lib.sideaware_one_to_one.argtypes = [ctypes.c_int, ctypes.c_size_t, # kernel_id, num_bytes 72 | ctypes.c_void_p, ctypes.c_void_p, # out0, in0 73 | ctypes.c_int, ctypes.c_void_p] # device, stream 74 | _lib.sideaware_one_to_one.restype = None 75 | _lib.sideaware_elementwise.argtypes = [ctypes.c_int, ctypes.c_size_t, # kernel_id, num_bytes 76 | ctypes.c_void_p, ctypes.c_void_p, # out0, out1 77 | ctypes.c_void_p, ctypes.c_void_p, # out2, out3 78 | ctypes.c_void_p, ctypes.c_void_p, # in0, in1 79 | ctypes.c_void_p, ctypes.c_void_p, # in2, in3 80 | ctypes.c_void_p, ctypes.c_void_p, # sideband_ptr, sideband_value 81 | ctypes.c_int, ctypes.c_int, # parallel_chunks, forced_sm_per_side 82 | ctypes.c_int, ctypes.c_void_p] # device, stream 83 | _lib.sideaware_elementwise.restype = None 84 | 85 | # Define PyTorch custom operations for memcpy/one_to_one/elementwise 86 | def direct_register_custom_op(op_lib, op_name, op_func, mutates_args): 87 | schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) 88 | op_lib.define(op_name + schema_str) 89 | op_lib.impl(op_name, op_func, "CUDA") 90 | 91 | sideaware_lib = torch.library.Library("sideaware", "FRAGMENT") 92 | direct_register_custom_op(sideaware_lib, "memcpy", sideaware_memcpy, mutates_args=(["dst"])) 93 | direct_register_custom_op(sideaware_lib, "one_to_one", sideaware_one_to_one, mutates_args=(["dst"])) 94 | direct_register_custom_op(sideaware_lib, "elementwise", sideaware_elementwise, 95 | mutates_args=(["out0", "out1", "out2", "out3", "sideband_tensor"])) 96 | 97 | # Initialize sideaware metadata 98 | global _info, _info_str, _torch_side_index, _gpu_side_index, _cpu_side_index 99 | _info = tuple(_lib.sideaware_sm_side_summary().contents) 100 | _info_str = { 
"num_sms": _info[0], "side0": _info[1], "side1": _info[2], "min": _info[3], "hash": _info[4] } 101 | 102 | _torch_side_index = torch.zeros(_info_str["num_sms"], dtype=torch.uint8, device="cuda") 103 | _lib.sideaware_fill_side_index(_torch_side_index.data_ptr()) 104 | _gpu_side_index = _lib.sideaware_gpu_side_index() 105 | _cpu_side_index = _lib.sideaware_cpu_side_index() 106 | 107 | # Print metadata (shows we are done with initialization) 108 | print(f"L2 Side Aware metadata: {_info_str}") 109 | 110 | # ----------------------------------------------------------------------------- 111 | # Exposed via torch.ops.sideaware.[memcpy/one_to_one/elementwise]() only 112 | # ----------------------------------------------------------------------------- 113 | 114 | # Sideaware memcpy (i.e. "default kernel" when no custom kernel is provided via sideaware_create_kernel) 115 | def sideaware_memcpy(dst: torch.Tensor, src: torch.Tensor) -> None: 116 | # Validate inputs 117 | assert dst.device.type == "cuda" and src.device.type == "cuda", "Both tensors must be on CUDA" 118 | assert dst.dtype == src.dtype, "Source and destination must have the same dtype" 119 | assert dst.numel() >= src.numel(), "Destination tensor must be at least as large as source" 120 | 121 | # Get pointers and size 122 | dst_ptr = dst.data_ptr() 123 | src_ptr = src.data_ptr() 124 | num_bytes = src.numel() * src.element_size() 125 | 126 | # Make sure src and dst are contiguous and aligned 127 | assert src.is_contiguous(), "src must be contiguous" 128 | assert dst.is_contiguous(), "dst must be contiguous" 129 | assert (dst_ptr % 16 == 0) and (src_ptr % 16 == 0), "dst and src must be 16-byte aligned" 130 | 131 | device, stream = torch.cuda.current_device(), torch.cuda.current_stream() 132 | _lib.sideaware_memcpy(dst_ptr, src_ptr, num_bytes, device, stream) 133 | 134 | # Sideaware single-input / single-output elementwise API (simple version) 135 | def sideaware_one_to_one(kernel_id: int, dst: torch.Tensor, src: torch.Tensor) -> None: 136 | # Validate inputs 137 | src_bytes = dst is not None and dst.numel() * dst.element_size() or 0 138 | dst_bytes = dst is not None and dst.numel() * dst.element_size() or 0 139 | num_bytes = max(src_bytes, dst_bytes) 140 | assert num_bytes > 0 141 | 142 | # Make sure src and dst are contiguous and aligned 143 | dst_ptr = dst is not None and dst.data_ptr() or 0 144 | src_ptr = src_bytes and src.data_ptr() or 0 145 | assert src is None or (src.is_contiguous() and src.device.type == "cuda"), "src must be contiguous" 146 | assert dst is None or (dst.is_contiguous() and dst.device.type == "cuda"), "dst must be contiguous" 147 | assert dst_ptr % 16 == 0 and src_ptr % 16 == 0, "dst and src must be 16-byte aligned" 148 | 149 | device, stream = torch.cuda.current_device(), torch.cuda.current_stream() 150 | _lib.sideaware_one_to_one(kernel_id, num_bytes, dst_ptr, src_ptr, device, stream) 151 | 152 | # Sideaware multi-input / multi-output elementwise API (advanced version) 153 | def sideaware_elementwise(kernel_id: int, 154 | out0: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor, out3: torch.Tensor, 155 | in0: torch.Tensor, in1: torch.Tensor, in2: torch.Tensor, in3: torch.Tensor, 156 | sideband_tensor: torch.Tensor = None, sideband_value: int = 0, 157 | parallel_chunks: int = 0, forced_sm_per_side: int = 0) -> None: 158 | # Validate inputs 159 | src_bytes = out0 is not None and out0.numel() * out0.element_size() or 0 160 | dst_bytes = out0 is not None and out0.numel() * out0.element_size() or 0 161 | num_bytes = 
max(src_bytes, dst_bytes) 162 | assert num_bytes > 0 163 | 164 | # Make sure src and dst are contiguous and aligned 165 | out0_ptr = out0 is not None and out0.data_ptr() or 0 166 | out1_ptr = out1 is not None and out1.data_ptr() or 0 167 | out2_ptr = out2 is not None and out2.data_ptr() or 0 168 | out3_ptr = out3 is not None and out3.data_ptr() or 0 169 | in0_ptr = in0 is not None and in0.data_ptr() or 0 170 | in1_ptr = in1 is not None and in1.data_ptr() or 0 171 | in2_ptr = in2 is not None and in2.data_ptr() or 0 172 | in3_ptr = in3 is not None and in3.data_ptr() or 0 173 | sideband_ptr = sideband_tensor is not None and sideband_tensor.data_ptr() or 0 174 | 175 | assert in0 is None or (in0.is_contiguous() and in0.device.type == "cuda"), "in0 must be contiguous" 176 | assert in1 is None or (in1.is_contiguous() and in1.device.type == "cuda"), "in1 must be contiguous" 177 | assert in2 is None or (in2.is_contiguous() and in2.device.type == "cuda"), "in2 must be contiguous" 178 | assert in3 is None or (in3.is_contiguous() and in3.device.type == "cuda"), "in3 must be contiguous" 179 | assert out0 is None or (out0.is_contiguous() and out0.device.type == "cuda"), "out0 must be contiguous" 180 | assert out1 is None or (out1.is_contiguous() and out1.device.type == "cuda"), "out1 must be contiguous" 181 | assert out2 is None or (out2.is_contiguous() and out2.device.type == "cuda"), "out2 must be contiguous" 182 | assert out3 is None or (out3.is_contiguous() and out3.device.type == "cuda"), "out3 must be contiguous" 183 | assert out0_ptr % 16 == 0 and out1_ptr % 16 == 0 and out2_ptr % 16 == 0 and out3_ptr % 16 == 0, "16B alignment" 184 | assert in0_ptr % 16 == 0 and in1_ptr % 16 == 0 and in2_ptr % 16 == 0 and in3_ptr % 16 == 0, "16B alignment" 185 | 186 | device, stream = torch.cuda.current_device(), torch.cuda.current_stream() 187 | _lib.sideaware_elementwise(kernel_id, num_bytes, 188 | out0_ptr, out1_ptr, out2_ptr, out3_ptr, 189 | in0_ptr, in1_ptr, in2_ptr, in3_ptr, 190 | sideband_ptr, sideband_value, parallel_chunks, forced_sm_per_side, device, stream) 191 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/tuner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import torch 4 | from typing import Any, Dict 5 | 6 | from ..jit import build, cpp_format, generate, Runtime 7 | from .utils import ceil_div 8 | 9 | class JITTuner: 10 | def __init__(self) -> None: 11 | self.tuned = {} 12 | 13 | def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, 14 | includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: 15 | # NOTES: we always assume the space and template will not change 16 | # We also assume the GPU device will not be changed 17 | # NOTES: the function must have no accumulated side effects 18 | keys = {k: keys[k] for k in sorted(keys.keys())} 19 | signature = (name, f'{keys}') 20 | if signature in self.tuned: 21 | if os.getenv('DG_JIT_DEBUG', None): 22 | print(f'Using cached JIT kernel {name} with keys {keys}') 23 | return self.tuned[signature] 24 | 25 | if os.getenv('DG_JIT_DEBUG', None): 26 | print(f'Auto-tuning JIT kernel {name} with keys {keys}') 27 | 28 | # TODO: dynamic/automatic tuning of unroll factor 29 | # TODO: manual unrolling by using template.py to copy the code multiple times? 
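        # Worked example of the selection below (K is illustrative): K = 7168 gives
        # loop_iterations = ceil_div(7168, 256) - 1 = 27; starting from max_unroll = 15, the first
        # divisor of 27 found is 9, so NUM_UNROLL = 9 and the (>= 16 iterations, <= 4 unroll) fallback does not trigger.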
30 | # TODO: handle tail better, because BLOCK_K=8192 means 31 iterations which is prime :( 31 | # TODO: rewrite all this to be helpful with DOUBLE_PUMP mode (and/or copy-paste instead of unroll) 32 | if not "NUM_UNROLL" in keys and "K" in keys: 33 | # Find largest divisor of loop iteration count that's no greater than max_unroll 34 | max_unroll = 15 35 | loop_iterations = max(1, ceil_div(keys["K"], 256) - 1) # 1 per 2 BLOCK_K 36 | num_unroll = min(loop_iterations, max_unroll) 37 | while loop_iterations % num_unroll != 0: 38 | num_unroll -= 1 39 | if (loop_iterations >= 16 and num_unroll <= 4): 40 | num_unroll = 8 41 | keys["NUM_UNROLL"] = num_unroll 42 | 43 | assert signature not in self.tuned 44 | assert args is not None 45 | space = (dict(), ) if len(space) == 0 else space 46 | 47 | kernels = [] 48 | for tuned_keys in space: 49 | assert isinstance(tuned_keys, dict) 50 | full_keys = copy.deepcopy(keys) 51 | full_keys.update(tuned_keys) 52 | code = generate(includes, arg_defs, cpp_format(template, full_keys)) 53 | 54 | # Illegal build must raise errors 55 | kernels.append((build(name, arg_defs, code), tuned_keys)) 56 | 57 | best_runtime, best_time, best_keys = None, None, None 58 | for runtime, tuned_keys in kernels: 59 | if len(space) > 1: 60 | # Check kernel validity 61 | return_code = runtime(*args) 62 | if return_code != 0: 63 | # Pass illegal kernels, e.g. insufficient shared memory capacity 64 | if os.getenv('DG_JIT_DEBUG', None): 65 | print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') 66 | continue 67 | 68 | # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels 69 | start_event = torch.cuda.Event(enable_timing=True) 70 | end_event = torch.cuda.Event(enable_timing=True) 71 | torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() 72 | torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') 73 | start_event.record() 74 | for i in range(20): 75 | assert runtime(*args) == 0 76 | end_event.record() 77 | end_event.synchronize() 78 | elapsed_time = start_event.elapsed_time(end_event) 79 | else: 80 | elapsed_time = 0 81 | 82 | # Compare if better 83 | if best_time is None or elapsed_time < best_time: 84 | best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys 85 | if os.getenv('DG_JIT_DEBUG', None): 86 | print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') 87 | assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' 88 | 89 | # Cache the best runtime and return 90 | if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): 91 | print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') 92 | self.tuned[signature] = best_runtime 93 | return best_runtime 94 | 95 | 96 | jit_tuner = JITTuner() 97 | -------------------------------------------------------------------------------- /deep_gemm/jit_kernels/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | _num_sms = None 4 | 5 | 6 | def set_num_sms(num_sms: int) -> None: 7 | """ 8 | Set the maximum SM count for all GEMM kernels to use. 9 | 10 | Arguments: 11 | num_sms: the desired maximum SM count for all GEMM kernels to use. 
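        Example (illustrative): `set_num_sms(64)` caps every subsequent GEMM launch at 64 SMs;
        the value must not exceed the device's multiprocessor count.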
12 | """ 13 | global _num_sms 14 | assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count 15 | _num_sms = num_sms 16 | 17 | 18 | def get_num_sms() -> int: 19 | """ 20 | Get the current maximum limit of SM count for all GEMM kernels to use. 21 | If the count is never specified, the function will return the number of device SMs. 22 | 23 | Returns: 24 | Current maximum limit of SM count for all GEMM kernels to use. 25 | """ 26 | global _num_sms 27 | if _num_sms is None: 28 | _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count 29 | return _num_sms 30 | 31 | 32 | def ceil_div(x: int, y: int) -> int: 33 | """ 34 | Perform ceiling division of two integers. 35 | 36 | Args: 37 | x: the dividend. 38 | y: the divisor. 39 | 40 | Returns: 41 | The result of the ceiling division. 42 | """ 43 | return (x + y - 1) // y 44 | 45 | 46 | def get_m_alignment_for_contiguous_layout(): 47 | """ 48 | When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. 49 | Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well 50 | with GEMM block shape. 51 | 52 | Returns: 53 | Group-level alignment requirement for grouped contiguous layout, which is always 128. 54 | """ 55 | return 128 56 | 57 | 58 | def get_tma_aligned_size(x: int, element_size: int) -> int: 59 | """ 60 | Global memory address of TMA must be 16-byte aligned. 61 | Since we use column-major layout for the LHS scaling tensor, 62 | the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. 63 | 64 | Arguments: 65 | x: original M-axis shape of the LHS scaling tensor. 66 | element_size: element size of the LHS scaling tensor. 67 | 68 | Returns: 69 | M-axis shape of the LHS scaling tensor after padding. 70 | """ 71 | tma_alignment_bytes = 16 72 | assert tma_alignment_bytes % element_size == 0 73 | alignment = tma_alignment_bytes // element_size 74 | return ceil_div(x, alignment) * alignment 75 | 76 | 77 | def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: 78 | """ 79 | Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. 80 | If the input tensor is already column-major layout and 16-byte aligned along the M axis 81 | (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. 82 | 83 | Arguments: 84 | x: usually the LHS scaling tensor in GEMM. 85 | 86 | Returns: 87 | The LHS scaling tensor of TMA-aligned transposed format. 
88 | """ 89 | # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA 90 | assert x.dim() in (2, 3) 91 | remove_dim = False 92 | 93 | # TODO: this was fixed separately in the base repo, need to merge recent changes 94 | aligned_m = get_tma_aligned_size(x.shape[x.dim()-2], x.element_size()) 95 | if x.dim() == 2: 96 | if x.stride(0) == 1 and x.stride(1) == aligned_m: 97 | return x # fast path when already transposed (previously resulted in a memcpy) 98 | else: 99 | x, remove_dim = x.unsqueeze(0), True 100 | 101 | b, m, n = x.shape 102 | 103 | # The last kernel gives a column-major TMA aligned layout 104 | if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: 105 | return x.squeeze(0) if remove_dim else x 106 | 107 | # Normal layout requires transposing 108 | aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) 109 | aligned_x[:, :m, :] = x 110 | aligned_x = aligned_x[:, :m, :] 111 | return aligned_x.squeeze(0) if remove_dim else aligned_x 112 | -------------------------------------------------------------------------------- /deep_gemm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def bench(fn, num_warmups: int = 5, num_tests: int = 10, 9 | high_precision: bool = False): 10 | # Flush L2 cache with 256 MB data 11 | torch.cuda.synchronize() 12 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') 13 | cache.zero_() 14 | 15 | # Warmup 16 | for _ in range(num_warmups): 17 | fn() 18 | 19 | # Add a large kernel to eliminate the CPU launch overhead 20 | if high_precision: 21 | x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 22 | y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 23 | x @ y 24 | 25 | # Testing 26 | start_event = torch.cuda.Event(enable_timing=True) 27 | end_event = torch.cuda.Event(enable_timing=True) 28 | start_event.record() 29 | for i in range(num_tests): 30 | fn() 31 | end_event.record() 32 | torch.cuda.synchronize() 33 | 34 | return start_event.elapsed_time(end_event) / num_tests 35 | 36 | 37 | class empty_suppress: 38 | def __enter__(self): 39 | return self 40 | 41 | def __exit__(self, *_): 42 | pass 43 | 44 | 45 | class suppress_stdout_stderr: 46 | def __enter__(self): 47 | self.outnull_file = open(os.devnull, 'w') 48 | self.errnull_file = open(os.devnull, 'w') 49 | 50 | self.old_stdout_fileno_undup = sys.stdout.fileno() 51 | self.old_stderr_fileno_undup = sys.stderr.fileno() 52 | 53 | self.old_stdout_fileno = os.dup(sys.stdout.fileno()) 54 | self.old_stderr_fileno = os.dup(sys.stderr.fileno()) 55 | 56 | self.old_stdout = sys.stdout 57 | self.old_stderr = sys.stderr 58 | 59 | os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) 60 | os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) 61 | 62 | sys.stdout = self.outnull_file 63 | sys.stderr = self.errnull_file 64 | return self 65 | 66 | def __exit__(self, *_): 67 | sys.stdout = self.old_stdout 68 | sys.stderr = self.old_stderr 69 | 70 | os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) 71 | os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 72 | 73 | os.close(self.old_stdout_fileno) 74 | os.close(self.old_stderr_fileno) 75 | 76 | self.outnull_file.close() 77 | self.errnull_file.close() 78 | 79 | 80 | def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool 
= False, 81 | trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): 82 | # Conflict with Nsight Systems 83 | using_nsys = os.environ.get('DG_NSYS_PROFILING', False) 84 | 85 | # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle 86 | # this avoid thermal throttling while keeping DVFS at maximum clocks (slight perf gain vs sleep / more consistent) 87 | sleep_between_tests = 0.0 88 | flush_l2_size = int(8e9 // 4) 89 | l2_read_flusher = None 90 | if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False): 91 | flush_l2 = False 92 | elif os.environ.get('DG_BENCH_L2_FLUSH_READ_ONLY', False): 93 | l2_read_flusher = torch.ones(flush_l2_size, dtype=torch.int, device='cuda') 94 | if os.environ.get('DG_BENCH_POWER_LIMITED', False): 95 | # if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time 96 | # and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2) 97 | num_tests = 2000 98 | sleep_between_tests = 0.0 99 | flush_l2_size = int(80e6 // 4) 100 | sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False) 101 | if sleep_val: 102 | try: 103 | sleep_between_tests = float(sleep_val) 104 | except ValueError: 105 | pass # Keep default 106 | 107 | # For some auto-tuning kernels with prints 108 | fn() 109 | 110 | # Profile 111 | suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress 112 | with suppress(): 113 | schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None 114 | profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() 115 | with profiler: 116 | for i in range(2): 117 | # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead 118 | if barrier_comm_profiling: 119 | lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 120 | rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 121 | lhs @ rhs 122 | dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) 123 | for _ in range(num_tests): 124 | if sleep_between_tests > 0.0: 125 | time.sleep(sleep_between_tests) 126 | if flush_l2: 127 | if l2_read_flusher is not None: 128 | if l2_read_flusher.sum() == 0: 129 | print("Impossible!") 130 | #torch.cuda.synchronize() 131 | else: 132 | torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() 133 | fn() 134 | 135 | if not using_nsys: 136 | profiler.step() 137 | 138 | # Return 1 if using Nsight Systems 139 | if using_nsys: 140 | return 1 141 | 142 | # Parse the profiling table 143 | assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) 144 | is_tupled = isinstance(kernel_names, tuple) 145 | prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') 146 | kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names 147 | assert all([isinstance(name, str) for name in kernel_names]) 148 | for name in kernel_names: 149 | assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' 150 | 151 | # Save chrome traces 152 | if trace_path is not None: 153 | profiler.export_chrome_trace(trace_path) 154 | 155 | # Return average kernel times 156 | units = {'ms': 1e3, 'us': 1e6} 157 | kernel_times = [] 158 | for name in kernel_names: 159 | for line in 
prof_lines: 160 | if name in line: 161 | time_str = line.split()[-2] 162 | for unit, scale in units.items(): 163 | if unit in time_str: 164 | kernel_times.append(float(time_str.replace(unit, '')) / scale) 165 | break 166 | break 167 | return tuple(kernel_times) if is_tupled else kernel_times[0] 168 | 169 | 170 | def calc_diff(x, y): 171 | x, y = x.double(), y.double() 172 | denominator = (x * x + y * y).sum() 173 | sim = 2 * (x * y).sum() / denominator 174 | return 1 - sim 175 | 176 | 177 | def count_bytes(tensors): 178 | total = 0 179 | for t in tensors: 180 | if isinstance(t, tuple): 181 | total += count_bytes(t) 182 | else: 183 | total += t.numel() * t.element_size() 184 | return total 185 | -------------------------------------------------------------------------------- /figures/design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademeure/DeeperGEMM/4b05985655deee6112cf4c87fa51709ef699b012/figures/design.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | import shutil 4 | import subprocess 5 | from setuptools.command.build_py import build_py 6 | from setuptools.command.develop import develop 7 | 8 | current_dir = os.path.dirname(os.path.realpath(__file__)) 9 | jit_include_dirs = ('deep_gemm/include/deep_gemm', ) 10 | third_party_include_dirs = ( 11 | 'third-party/cutlass/include/cute', 12 | 'third-party/cutlass/include/cutlass', 13 | ) 14 | 15 | 16 | class PostDevelopCommand(develop): 17 | def run(self): 18 | develop.run(self) 19 | self.make_jit_include_symlinks() 20 | 21 | @staticmethod 22 | def make_jit_include_symlinks(): 23 | # Make symbolic links of third-party include directories 24 | for d in third_party_include_dirs: 25 | dirname = d.split('/')[-1] 26 | src_dir = f'{current_dir}/{d}' 27 | dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' 28 | assert os.path.exists(src_dir) 29 | if os.path.exists(dst_dir): 30 | assert os.path.islink(dst_dir) 31 | os.unlink(dst_dir) 32 | os.symlink(src_dir, dst_dir, target_is_directory=True) 33 | 34 | 35 | class CustomBuildPy(build_py): 36 | def run(self): 37 | # First, prepare the include directories 38 | self.prepare_includes() 39 | 40 | # Then run the regular build 41 | build_py.run(self) 42 | 43 | def prepare_includes(self): 44 | # Create temporary build directory instead of modifying package directory 45 | build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') 46 | os.makedirs(build_include_dir, exist_ok=True) 47 | 48 | # Copy third-party includes to the build directory 49 | for d in third_party_include_dirs: 50 | dirname = d.split('/')[-1] 51 | src_dir = os.path.join(current_dir, d) 52 | dst_dir = os.path.join(build_include_dir, dirname) 53 | 54 | # Remove existing directory if it exists 55 | if os.path.exists(dst_dir): 56 | shutil.rmtree(dst_dir) 57 | 58 | # Copy the directory 59 | shutil.copytree(src_dir, dst_dir) 60 | 61 | 62 | if __name__ == '__main__': 63 | # noinspection PyBroadException 64 | try: 65 | cmd = ['git', 'rev-parse', '--short', 'HEAD'] 66 | revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() 67 | except: 68 | revision = '' 69 | 70 | setuptools.setup( 71 | name='deep_gemm', 72 | version='1.0.0' + revision, 73 | packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], 74 | package_data={ 75 | 'deep_gemm': [ 76 | 'include/deep_gemm/**/*', 77 
| 'include/cute/**/*', 78 | 'include/cutlass/**/*', 79 | ] 80 | }, 81 | cmdclass={ 82 | 'develop': PostDevelopCommand, 83 | 'build_py': CustomBuildPy, 84 | }, 85 | ) 86 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from typing import Tuple 4 | import os 5 | 6 | import deep_gemm 7 | from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor 8 | 9 | 10 | def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 11 | assert x.dim() == 2 and x.size(1) % 128 == 0 12 | m, n = x.shape 13 | x_view = x.view(m, -1, 128) 14 | x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) 15 | return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) 16 | 17 | 18 | def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 19 | assert x.dim() == 2 20 | m, n = x.shape 21 | x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) 22 | x_padded[:m, :n] = x 23 | x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) 24 | x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) 25 | x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) 26 | return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) 27 | 28 | 29 | def construct(m: int, k: int, n: int) -> \ 30 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 31 | x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) 32 | y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) 33 | out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) 34 | ref_out = x @ y.t() 35 | 36 | x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) 37 | 38 | # If L2 side optimization 39 | y_fp8_reordered = torch.empty_like(y_fp8[0]) 40 | deep_gemm.preprocess_reorder_b(y_fp8[0], y_fp8_reordered) 41 | y_fp8_output = (y_fp8_reordered, y_fp8[1]) 42 | 43 | # Transpose earlier so that the testing will not trigger transposing kernels 44 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 45 | return x_fp8, y_fp8_output, out, ref_out 46 | 47 | 48 | def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \ 49 | Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 50 | x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) 51 | y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) 52 | out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) 53 | ref_out = torch.einsum('gmk,gnk->gmn', x, y) 54 | 55 | assert m % 4 == 0, f'TMA alignment error: {m}' 56 | x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) 57 | y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) 58 | for i in range(num_groups): 59 | x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) 60 | y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) 61 | 62 | # If L2 side optimization 63 | y_fp8_0_reordered = torch.empty_like(y_fp8[0]) 64 | deep_gemm.preprocess_reorder_b_grouped(y_fp8[0], 
y_fp8_0_reordered, is_masked) 65 | y_fp8_output = (y_fp8_0_reordered, y_fp8[1]) 66 | 67 | # For non-masked input, we must merge the group and M dims 68 | if not is_masked: 69 | x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) 70 | out, ref_out = out.view(-1, n), ref_out.view(-1, n) 71 | 72 | # Transpose earlier so that the testing will not trigger transposing kernels 73 | x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) 74 | return x_fp8, y_fp8_output, out, ref_out 75 | 76 | 77 | def test_gemm() -> None: 78 | print('Testing GEMM:') 79 | for m in (64, 128, 4096): 80 | for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: 81 | x_fp8, y_fp8, out, ref_out = construct(m, k, n) 82 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 83 | diff = calc_diff(out, ref_out) 84 | # Increased tolerance for 256-wide scaling (since we didn't fix the scaling factors yet) 85 | assert diff < 0.05, f'{m=}, {k=}, {n=}, {diff:.5f}' 86 | 87 | # noinspection PyShadowingNames 88 | def test_func(): 89 | deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) 90 | 91 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 92 | print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' 93 | f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' 94 | f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') 95 | print() 96 | 97 | 98 | def test_m_grouped_gemm_contiguous() -> None: 99 | print('Testing grouped contiguous GEMM:') 100 | 101 | for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): 102 | # TODO: make a stronger test 103 | x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) 104 | m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) 105 | m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) 106 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 107 | diff = calc_diff(out, ref_out) 108 | assert diff < 0.05, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}' 109 | 110 | x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) 111 | m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) 112 | m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) 113 | 114 | # noinspection PyShadowingNames 115 | def test_func(): 116 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) 117 | 118 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 119 | print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 120 | f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' 121 | f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') 122 | print() 123 | 124 | 125 | def test_m_grouped_gemm_masked() -> None: 126 | print('Testing grouped masked GEMM:') 127 | 128 | for num_groups, m in ((1, 1024), (2, 512), (4, 256)): 129 | for k, n in ((7168, 4096), (2048, 7168), ): 130 | # Test correctness 131 | masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) 132 | for i in range(10): 133 | x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) 134 | masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) 135 | for j in range(num_groups): 136 | masked_m[j] = random.choice(masked_m_candidates) 137 | 
expected_m = min(int(masked_m.float().mean()) + 1, m) 138 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) 139 | for j in range(num_groups): 140 | diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) 141 | assert diff < 0.05, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' 142 | 143 | x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) 144 | masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m 145 | 146 | # noinspection PyShadowingNames 147 | def test_func(): 148 | deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m) 149 | 150 | # Test performance with fixed shapes 151 | t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) 152 | print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' 153 | f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' 154 | f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') 155 | print() 156 | 157 | 158 | if __name__ == '__main__': 159 | # Use our custom L2 Side Aware memory allocator (optimization is disabled if not initialized) 160 | if not os.environ.get('DG_DISABLE_L2_OPTIMIZATION', False): 161 | deep_gemm.sideaware_init('./deep_gemm/include/l2_torch_alloc/sideaware.so') 162 | 163 | torch.backends.cuda.matmul.allow_tf32 = True 164 | torch.backends.cudnn.allow_tf32 = True 165 | torch.manual_seed(0) 166 | random.seed(0) 167 | 168 | print('Library path:') 169 | print(f' > {deep_gemm.__path__}\n') 170 | 171 | test_gemm() 172 | test_m_grouped_gemm_contiguous() 173 | test_m_grouped_gemm_masked() 174 | -------------------------------------------------------------------------------- /tests/test_jit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import Any 4 | 5 | from deep_gemm import jit 6 | 7 | 8 | class Capture: 9 | def __init__(self) -> None: 10 | self.read_fd = None 11 | self.write_fd = None 12 | self.saved_stdout = None 13 | self.captured = None 14 | 15 | def __enter__(self) -> Any: 16 | self.read_fd, self.write_fd = os.pipe() 17 | self.saved_stdout = os.dup(1) 18 | os.dup2(self.write_fd, 1) 19 | return self 20 | 21 | def __exit__(self, exc_type, exc_val, exc_tb) -> None: 22 | os.dup2(self.saved_stdout, 1) 23 | os.close(self.write_fd) 24 | with os.fdopen(self.read_fd, 'r') as f: 25 | self.captured = f.read() 26 | 27 | def capture(self) -> str: 28 | return self.captured 29 | 30 | 31 | if __name__ == '__main__': 32 | # Runtime 33 | print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n') 34 | 35 | # Templates 36 | print('Generated code:') 37 | args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16), 38 | ('enable_double_streams', bool), ('stream', torch.cuda.Stream)) 39 | body = "\n" 40 | body += 'std::cout << reinterpret_cast(lhs) << std::endl;\n' 41 | body += 'std::cout << reinterpret_cast(rhs) << std::endl;\n' 42 | body += 'std::cout << reinterpret_cast(scale) << std::endl;\n' 43 | body += 'std::cout << reinterpret_cast(out) << std::endl;\n' 44 | body += 'std::cout << enable_double_streams << std::endl;\n' 45 | body += 'std::cout << reinterpret_cast(stream) << std::endl;\n' 46 | code = jit.generate((), args, body) 47 | print(code) 48 | 49 | # Build 50 | print('Building ...') 51 | func = jit.build('test_func', args, code) 52 | 53 | # Test correctness 54 | 
print('Running ...') 55 | fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda') 56 | fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda') 57 | bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda') 58 | with Capture() as capture: 59 | assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0 60 | output = capture.capture() 61 | ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n' 62 | assert output == ref_output, f'{output=}, {ref_output=}' 63 | 64 | print('JIT test passed') 65 | --------------------------------------------------------------------------------
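
## Quick usage sketch (non-grouped FP8 GEMM)

The tests above double as the only usage documentation for the FP8 GEMM entry points, so here is a minimal sketch distilled from `tests/test_core.py`: quantize a BF16 problem to FP8 with per-token (LHS) and 128x128 block (RHS) scaling, reorder B the way the tests do, make the LHS scales column-major and TMA-aligned, then run `gemm_fp8_fp8_bf16_nt`. The shapes and the simplified cast helpers below are illustrative assumptions (adapted from the test file), not part of the library API.

```python
import torch
import deep_gemm

# Simplified quantization helpers, adapted from tests/test_core.py
# (assumes k and n are multiples of 128 so no padding is needed)
def per_token_cast_to_fp8(x: torch.Tensor):
    m, k = x.shape
    x_view = x.view(m, -1, 128)
    amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, k), (amax / 448.0).view(m, -1)

def per_block_cast_to_fp8(x: torch.Tensor):
    n, k = x.shape
    x_view = x.view(n // 128, 128, k // 128, 128)
    amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    return (x_view * (448.0 / amax)).to(torch.float8_e4m3fn).view(n, k), (amax / 448.0).view(n // 128, k // 128)

m, k, n = 4096, 7168, 4096  # illustrative shape
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)  # LHS, row-major
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)  # RHS, row-major (the "nt" layout)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

x_fp8, x_scales = per_token_cast_to_fp8(x)
y_fp8, y_scales = per_block_cast_to_fp8(y)

# Optional: enable L2 side awareness first (tests/test_core.py loads the prebuilt library like this)
# deep_gemm.sideaware_init('./deep_gemm/include/l2_torch_alloc/sideaware.so')

# Reorder B tiles as the tests do, then make the LHS scales column-major and TMA-aligned
y_fp8_reordered = torch.empty_like(y_fp8)
deep_gemm.preprocess_reorder_b(y_fp8, y_fp8_reordered)
x_scales = deep_gemm.get_col_major_tma_aligned_tensor(x_scales)

deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scales), (y_fp8_reordered, y_scales), out)
print('diff vs BF16 reference:', deep_gemm.calc_diff(out, x @ y.t()))
```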
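For timing, the tests use `bench_kineto` from `deep_gemm/utils.py` rather than `bench`. Its behaviour is steered by environment variables that are only visible in the code: `DG_NSYS_PROFILING` switches to an Nsight-Systems-friendly path (and makes the function return 1), `DG_BENCH_DISABLE_L2_FLUSH` and `DG_BENCH_L2_FLUSH_READ_ONLY` change how L2 is flushed between runs, `DG_BENCH_POWER_LIMITED` runs 2000 back-to-back iterations with a smaller flush, and `DG_BENCH_SLEEP_BETWEEN_TESTS` adds a sleep (in seconds) before each run. A minimal sketch, assuming the tensors from the example above are still in scope; the knob values are illustrative:

```python
import os
import deep_gemm

# Knobs read by bench_kineto at call time (illustrative values)
os.environ['DG_BENCH_L2_FLUSH_READ_ONLY'] = '1'       # flush L2 by reading a large (8 GB) buffer instead of a memset
os.environ['DG_BENCH_SLEEP_BETWEEN_TESTS'] = '0.005'  # 5 ms pause before each timed run

def run():
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scales), (y_fp8_reordered, y_scales), out)

# 'fp8_gemm' is the kernel-name substring matched against the Kineto profiling table, as in the tests;
# trace_path additionally dumps a Chrome trace of the profiled iterations.
t = deep_gemm.bench_kineto(run, 'fp8_gemm', suppress_kineto_output=True, trace_path='fp8_gemm_trace.json')
print(f'{t * 1e6:.0f} us/iter, {2 * m * n * k / t / 1e12:.0f} TFLOPS')
```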