├── images ├── perf.png ├── pingpong.png ├── tma_swizzle.png ├── wgmma_layout.png ├── stmatrix_script.png ├── tma_swizzle_old.png └── stmatrix_layout_funny.png ├── Makefile ├── test.py ├── layout.py ├── setup.py ├── l2.py ├── stmatrix.py ├── LICENSE ├── benchmark.sh ├── op.cpp ├── denoise-h100.sh ├── .clang-format ├── maxreg.cu ├── main.cu ├── README.md ├── benchmark.py ├── gemm.cu ├── stmatrix.cu └── pingpong.cu /images/perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/perf.png -------------------------------------------------------------------------------- /images/pingpong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/pingpong.png -------------------------------------------------------------------------------- /images/tma_swizzle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/tma_swizzle.png -------------------------------------------------------------------------------- /images/wgmma_layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/wgmma_layout.png -------------------------------------------------------------------------------- /images/stmatrix_script.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/stmatrix_script.png -------------------------------------------------------------------------------- /images/tma_swizzle_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/tma_swizzle_old.png -------------------------------------------------------------------------------- /images/stmatrix_layout_funny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bertmaher/simplegemm/HEAD/images/stmatrix_layout_funny.png -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NVCC_FLAGS = -std=c++17 -O3 -DNDEBUG -w 2 | NVCC_FLAGS += --expt-relaxed-constexpr --expt-extended-lambda -Xcompiler=-fPIE -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing -arch=sm_90a 3 | NVCC_LDFLAGS = -lcublas -lcuda # --keep # -lineinfo 4 | 5 | gemm: main.cu gemm.cu pingpong.cu stmatrix.cu 6 | nvcc $(NVCC_FLAGS) $(NVCC_LDFLAGS) $< -o $@ 7 | 8 | maxreg: maxreg.cu 9 | nvcc $(NVCC_FLAGS) $(NVCC_LDFLAGS) $^ -o $@ 10 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.ops.load_library("gemm.so") 4 | 5 | torch.set_default_device("cuda") 6 | 7 | m, n, k = 128, 256, 64 8 | 9 | a = torch.arange(m * k).reshape(m, k).bfloat16() 10 | b = torch.eye(k, n).bfloat16().T.contiguous().T 11 | c = torch.ops.gemm.pingpong(a, b) 12 | cref = torch.mm(a, b) 13 | 14 | print(b.size(), b.stride()) 15 | print(a) 16 | print(b) 17 | print(cref) 18 | print(c) 19 | 20 | print(torch.allclose(c, cref, atol=0, rtol=0)) 21 | 
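# Note: the .T.contiguous().T dance gives b a K-contiguous (column-major) layout
# while keeping its logical (k, n) shape, which appears to be the "TN" layout the
# kernel expects for B (benchmark.py builds its B operand the same way). Since b
# is an identity padded with zero columns, c should match cref bit-exactly, hence
# atol=0, rtol=0.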
-------------------------------------------------------------------------------- /layout.py: -------------------------------------------------------------------------------- 1 | for tid in range(0, 128): 2 | warp = tid // 32 3 | lane = tid % 32 4 | row_base = warp * 16 5 | row_off = lane // 4 6 | row = row_base + row_off 7 | 8 | col = tid % 4 * 2 9 | rcs = [] 10 | for n in range(0, 256, 8): 11 | rcs.extend([ 12 | (row, n + col), 13 | (row, n + col + 1), 14 | (row + 8, n + col), 15 | (row + 8, n + col + 1), 16 | ]) 17 | print(f"{tid:3d}: " + " ".join(f"({r:3d}, {c:3d})" for r, c in rcs)) 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | 6 | setup( 7 | name="gemm", 8 | ext_modules=[ 9 | CUDAExtension( 10 | "gemm", 11 | [ 12 | "op.cpp", 13 | "gemm.cu", 14 | "pingpong.cu", 15 | "stmatrix.cu", 16 | ], 17 | extra_compile_args=["-lineinfo"], 18 | extra_link_args=["-lcuda"], 19 | ) 20 | ], 21 | cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, 22 | ) 23 | -------------------------------------------------------------------------------- /l2.py: -------------------------------------------------------------------------------- 1 | sms = 132 2 | m_blocks = 66 3 | n_blocks = 72 4 | 5 | all_blocks = set() 6 | for m in range(m_blocks): 7 | for n in range(n_blocks): 8 | all_blocks.add((m, n)) 9 | 10 | my_blocks = set() 11 | for sm in range(0, 132): 12 | for bid in range(sm, m_blocks * n_blocks, sms): 13 | 14 | """ 15 | m = sm // 2 #bid % m_blocks 16 | n = bid // m_blocks + (sm % 2) 17 | """ 18 | 19 | m = (bid // 2) % m_blocks 20 | n = (bid // 2) // m_blocks * 2 + bid % 2 21 | 22 | print(f"{sm} {m} {n}") 23 | my_blocks.add((m, n)) 24 | if (m, n) not in all_blocks: 25 | print(f"egad: {m}, {n}") 26 | 27 | for m in range(m_blocks): 28 | for n in range(n_blocks): 29 | if (m, n) not in my_blocks: 30 | print(f"oh no: {m}, {n}") 31 | -------------------------------------------------------------------------------- /stmatrix.py: -------------------------------------------------------------------------------- 1 | addrs = set() 2 | 3 | INST_M = 64 4 | for tid in range(128): 5 | warp = tid // 32 6 | lane = tid % 32 7 | base_x1_row = warp * 16 8 | base_x4_row = base_x1_row + (lane // 8 % 2) * 8 9 | base_x4_col = lane % 8 + lane // 16 * 8 10 | 11 | base_addr = base_x4_row + INST_M * base_x4_col 12 | bank = base_addr // 2 % 32 13 | 14 | padded_base_addr = base_x4_row + (INST_M + 8) * base_x4_col 15 | padded_bank = padded_base_addr // 2 % 32 16 | 17 | swizzle_addr = base_addr ^ ((lane & 7) << 3) 18 | swizzle_bank = swizzle_addr // 2 % 32 19 | 20 | addrs.add(swizzle_addr) 21 | print(f"{tid:3d}: ({base_x4_row:3d}, {base_x4_col:3d}): {base_addr:5d} {bank:5d} | {padded_base_addr:5d} {padded_bank:5d} | {swizzle_addr:5d} {swizzle_bank:5d}") 22 | 23 | print(len(addrs)) 24 | print(sorted(addrs)) 25 | for x, y in zip(sorted(addrs), range(0, 1024, 8)): 26 | print(x, y, "" if x == y else "FAIL") 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Bertrand Maher 4 | 5 | Permission is hereby granted, free of charge, to any person 
obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TORCHINDUCTOR_CACHE_DIR=/tmp/pingpong_matmul_experiments_20250310 4 | export TORCHINDUCTOR_CUTLASS_DIR=$HOME/local/cutlass 5 | export TORCHINDUCTOR_CUTLASS_ALLOWLIST='128x128x64_1x1x1.*pingpong_epi_tma' 6 | export TORCHINDUCTOR_CUTLASS_DENYLIST='stream_k' 7 | export TORCHINDUCTOR_CUTLASS_INSTANTIATION_LEVEL=0201 8 | export USE_IR_LOC=ttgir 9 | 10 | DATE=$(date +%s) 11 | export TRITON_DUMP_DIR=$(realpath "dump.$DATE") 12 | 13 | RUN_COMMAND="python benchmark.py" 14 | 15 | if false; then 16 | export TRITON_OVERRIDE_DIR=$(realpath "override.$DATE") 17 | 18 | echo $TRITON_DUMP_DIR 19 | echo $TRITON_OVERRIDE_DIR 20 | 21 | rm -rf $TRITON_DUMP_DIR $TRITON_OVERRIDE_DIR 22 | 23 | TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 ./denoise-h100.sh $RUN_COMMAND 24 | cp -r $TRITON_DUMP_DIR $TRITON_OVERRIDE_DIR 25 | TTGIR_PATH=$(find $TRITON_OVERRIDE_DIR -name 'matmul_persistent_tma_ws_pingpong_kernel.ttgir') 26 | find $TRITON_OVERRIDE_DIR -type f -delete 27 | cp matmul_persistent_tma_ws_pingpong_kernel.ttgir $TTGIR_PATH 28 | fi 29 | 30 | export BENCHMARK_CUTLASS=1 31 | TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 ./denoise-h100.sh $RUN_COMMAND 32 | -------------------------------------------------------------------------------- /op.cpp: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | #include "ATen/ATen.h" // @manual 4 | #include "torch/extension.h" // @manual 5 | 6 | #include 7 | 8 | using bf16 = __nv_bfloat16; 9 | 10 | void run_gemm(void* A, void* B, void* C, int M, int N, int K); 11 | void run_pingpong(void* A, void* B, void* C, int M, int N, int K); 12 | void run_stmatrix_gemm(void* A, void* B, void* C, int M, int N, int K); 13 | 14 | at::Tensor gemm(at::Tensor a, at::Tensor b) { 15 | // a (m x k), b (k x n) 16 | auto c = a.new_empty({b.size(1), a.size(0)}).transpose(0, 1); 17 | run_gemm( 18 | a.data_ptr(), 19 | b.data_ptr(), 20 | c.data_ptr(), 21 | a.size(0), 22 | b.size(1), 23 | a.size(1)); 24 | return c; 25 | } 26 | 27 | at::Tensor pingpong(at::Tensor a, at::Tensor b) { 28 | // a (m x k), b (k x n) 29 | auto c = a.new_empty({b.size(1), a.size(0)}).transpose(0, 1); 30 | run_pingpong( 31 | a.data_ptr(), 32 | b.data_ptr(), 33 | c.data_ptr(), 34 | a.size(0), 35 | b.size(1), 36 | a.size(1)); 37 | return c; 38 | } 39 | 40 | at::Tensor stmatrix_gemm(at::Tensor a, at::Tensor b) { 41 | // a (m x k), b (k x n) 42 | auto c = a.new_empty({b.size(1), a.size(0)}).transpose(0, 1); 43 | run_stmatrix_gemm( 44 | a.data_ptr(), 45 | b.data_ptr(), 46 | c.data_ptr(), 47 | a.size(0), 48 | b.size(1), 49 | a.size(1)); 50 | return c; 51 | } 52 | 53 | TORCH_LIBRARY(gemm, m) { 54 | m.def("gemm", &gemm); 55 | m.def("pingpong", &pingpong); 56 | m.def("stmatrix_gemm", &stmatrix_gemm); 57 | } 58 | -------------------------------------------------------------------------------- /denoise-h100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 3 | 4 | # There's a whole presentation about stable benchmarking here: 5 | # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9956-best-practices-when-benchmarking-cuda-applications_V2.pdf 6 | 7 | export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=4}" 8 | 9 | # 1335, 1980 10 | # Lock GPU clocks 11 | ( 12 | sudo nvidia-smi -i "$CUDA_VISIBLE_DEVICES" -pm 1 # persistent mode 13 | sudo nvidia-smi --power-limit=650 -i "$CUDA_VISIBLE_DEVICES" # lock to 650 W 14 | sudo nvidia-smi -lgc 1980 -i "$CUDA_VISIBLE_DEVICES" # lock to 1980 MHz. The max on H100 is 1980 MHz, but power throttling can still occur 15 | ) >/dev/null 16 | 17 | # TODO: On my devgpu, device 6 is apparently attached to NUMA node 3. How did 18 | # I discover this? 19 | # 20 | # `nvidia-smi -i 6 -pm 1` prints the PCI bus ID (00000000:C6:00.0) 21 | # 22 | # You can also get this from `nvidia-smi -x -q` and looking for minor_number 23 | # and pci_bus_id 24 | # 25 | # Then, `cat /sys/bus/pci/devices/0000:c6:00.0/numa_node` prints 3 26 | # is it always the case that device N is on numa node N/2? :shrug: 27 | # 28 | # Maybe automate this process or figure out if it always holds? 29 | # 30 | # ... 
Or you can just `nvidia-smi topo -mp` and it will just print out exactly 31 | # what you want, like this: 32 | 33 | # GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 mlx5_0 mlx5_1 mlx5_2 mlx5_3 CPU Affinity NUMA Affinity 34 | # GPU0 X PXB SYS SYS SYS SYS SYS SYS NODE SYS SYS SYS 0-23,96-119 0 35 | # GPU6 SYS SYS SYS SYS SYS SYS X PXB SYS SYS SYS NODE 72-95,168-191 3 36 | 37 | numactl -m 0 -c 0 "$@" 38 | 39 | # Unlock GPU clock 40 | ( 41 | sudo nvidia-smi -rgc -i "$CUDA_VISIBLE_DEVICES" 42 | sudo nvidia-smi --power-limit=500 -i "$CUDA_VISIBLE_DEVICES" # lock to 500 W 43 | ) >/dev/null 44 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: AlwaysBreak 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: true 7 | AlignOperands: false 8 | AlignTrailingComments: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: true 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BraceWrapping: 21 | AfterClass: false 22 | AfterControlStatement: false 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterObjCDeclaration: false 27 | AfterStruct: false 28 | AfterUnion: false 29 | BeforeCatch: false 30 | BeforeElse: false 31 | IndentBraces: false 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeBraces: Attach 34 | BreakBeforeTernaryOperators: true 35 | BreakConstructorInitializersBeforeComma: false 36 | BreakAfterJavaFieldAnnotations: false 37 | BreakStringLiterals: false 38 | ColumnLimit: 80 39 | CommentPragmas: '^ IWYU pragma:' 40 | CompactNamespaces: false 41 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 42 | ConstructorInitializerIndentWidth: 4 43 | ContinuationIndentWidth: 4 44 | Cpp11BracedListStyle: true 45 | DerivePointerAlignment: false 46 | DisableFormat: false 47 | ForEachMacros: 48 | - FOR_EACH_RANGE 49 | - FOR_EACH 50 | IncludeCategories: 51 | - Regex: '^<.*\.h(pp)?>' 52 | Priority: 1 53 | - Regex: '^<.*' 54 | Priority: 2 55 | - Regex: '.*' 56 | Priority: 3 57 | IndentCaseLabels: true 58 | IndentWidth: 2 59 | IndentWrappedFunctionNames: false 60 | KeepEmptyLinesAtTheStartOfBlocks: false 61 | MacroBlockBegin: '' 62 | MacroBlockEnd: '' 63 | Macros: 64 | - >- 65 | PyObject_HEAD_INIT(type)={ 66 | /* this is not exactly match with PyObject_HEAD_INIT in Python source code 67 | * but it is enough for clang-format */ 68 | { 0xFFFFFFFF }, 69 | (type) 70 | }, 71 | - >- 72 | PyVarObject_HEAD_INIT(type, size)={ 73 | { 74 | /* manually expand PyObject_HEAD_INIT(type) above 75 | * because clang-format do not support recursive expansion */ 76 | { 0xFFFFFFFF }, 77 | (type) 78 | }, 79 | (size) 80 | }, 81 | MaxEmptyLinesToKeep: 1 82 | NamespaceIndentation: None 83 | PenaltyBreakBeforeFirstCallParameter: 1 84 | PenaltyBreakComment: 300 85 | PenaltyBreakFirstLessLess: 120 86 | PenaltyBreakString: 1000 87 | PenaltyExcessCharacter: 1000000 88 | PenaltyReturnTypeOnItsOwnLine: 2000000 89 | PointerAlignment: Left 90 | ReflowComments: true 91 | SortIncludes: true 92 | 
SpaceAfterCStyleCast: false 93 | SpaceBeforeAssignmentOperators: true 94 | SpaceBeforeParens: ControlStatements 95 | SpaceInEmptyParentheses: false 96 | SpacesBeforeTrailingComments: 1 97 | SpacesInAngles: false 98 | SpacesInContainerLiterals: true 99 | SpacesInCStyleCastParentheses: false 100 | SpacesInParentheses: false 101 | SpacesInSquareBrackets: false 102 | Standard: c++17 103 | StatementMacros: 104 | - C10_DEFINE_bool 105 | - C10_DEFINE_int 106 | - C10_DEFINE_int32 107 | - C10_DEFINE_int64 108 | - C10_DEFINE_string 109 | - C10_DEFINE_REGISTRY_WITHOUT_WARNING 110 | - C10_REGISTER_CREATOR 111 | - DEFINE_BINARY 112 | - PyObject_HEAD 113 | - PyObject_VAR_HEAD 114 | - PyException_HEAD 115 | - TORCH_DECLARE_bool 116 | 117 | TabWidth: 8 118 | UseTab: Never 119 | --- 120 | Language: ObjC 121 | ColumnLimit: 120 122 | AlignAfterOpenBracket: Align 123 | ObjCBlockIndentWidth: 2 124 | ObjCSpaceAfterProperty: false 125 | ObjCSpaceBeforeProtocolList: false 126 | ... 127 | -------------------------------------------------------------------------------- /maxreg.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using barrier = cuda::barrier; 19 | 20 | void checkCudaErrors(cudaError_t error, const char* file, int line) { 21 | if (error != cudaSuccess) { 22 | fprintf( 23 | stderr, 24 | "CUDA error at %s:%d: %s\n", 25 | file, 26 | line, 27 | cudaGetErrorString(error)); 28 | exit(EXIT_FAILURE); 29 | } 30 | } 31 | 32 | #define check(err) checkCudaErrors(err, __FILE__, __LINE__) 33 | 34 | 35 | template 36 | __device__ void warpgroup_reg_alloc() { 37 | asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount)); 38 | } 39 | 40 | __device__ static void __forceinline__ 41 | init_barrier(uint64_t* bar, int thread_count, int transaction_count) { 42 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 43 | asm volatile( 44 | "mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(bar_ptr), 45 | "r"(thread_count + transaction_count) : "memory"); 46 | } 47 | 48 | __device__ static void __forceinline__ wait_barrier(uint64_t* bar, int phase) { 49 | uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); 50 | asm volatile( 51 | "{\n" 52 | ".reg .pred P1;\n" 53 | "LAB_WAIT:\n" 54 | "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" 55 | "@P1 bra.uni DONE;\n" 56 | "bra.uni LAB_WAIT;\n" 57 | "DONE:\n" 58 | "}\n" ::"r"(mbar_ptr), 59 | "r"(phase):"memory"); 60 | } 61 | 62 | __device__ static void __forceinline__ 63 | arrive_barrier(uint64_t* bar, int count) { 64 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 65 | asm volatile( 66 | "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" ::"r"( 67 | bar_ptr), 68 | "r"(count) 69 | : "memory"); 70 | } 71 | 72 | __global__ __launch_bounds__(384) void dummy1() { 73 | __shared__ __align__(8) uint64_t bar, bar2; 74 | //__shared__ barrier bar; 75 | int tid = threadIdx.x; 76 | int wg = tid / 128; 77 | int wgtid = tid % 128; 78 | 79 | if (tid == 0) { 80 | init_barrier(&bar, 0, 2); 81 | init_barrier(&bar2, 0, 1); 82 | } 83 | __syncthreads(); 84 | if (wg == 0) { 85 | int phase = 0; 86 | if (wgtid == 0) { 87 | //printf("producer %d\n", wg); 88 | wait_barrier(&bar, phase); 89 | //printf("producer %d 1 done\n", wg); 90 | wait_barrier(&bar, phase ^ 1); 91 | 
//arrive_barrier(&bar2, 1); 92 | //wait_barrier(&bar, phase ^ 1); 93 | //printf("producer %d 2 done\n", wg); 94 | } 95 | } else { 96 | int phase = 0; 97 | if (wgtid == 0) { 98 | //printf("consumer %d\n", wg); 99 | arrive_barrier(&bar, 1); 100 | } 101 | //asm volatile("bar.sync %0, 128;" :: "r"(wg) : "memory"); 102 | if (wgtid < 2) { 103 | arrive_barrier(&bar, 1); 104 | //wait_barrier(&bar2, 1); 105 | //arrive_barrier(&bar, 1); 106 | //printf("consumer %d done\n", wg); 107 | } 108 | } 109 | } 110 | 111 | __global__ __launch_bounds__(384) void dummy() { 112 | __shared__ barrier bar; 113 | int tid = threadIdx.x; 114 | int wg = tid / 128; 115 | int wgtid = tid % 128; 116 | 117 | if (tid == 0) { 118 | init(&bar, 3); 119 | } 120 | __syncthreads(); 121 | 122 | if (wg == 0) { 123 | int phase = 0; 124 | asm volatile("{\n//test 1\n}\n" ::: "memory"); 125 | if (wgtid == 0) { 126 | bar.wait(bar.arrive()); 127 | bar.wait(bar.arrive()); 128 | } 129 | asm volatile("{\n//test 2\n}\n" ::: "memory"); 130 | } else { 131 | int phase = 0; 132 | asm volatile("{\n//test 3\n}\n" ::: "memory"); 133 | if (wgtid == 0) { 134 | bar.arrive(); 135 | bar.arrive(); 136 | } 137 | asm volatile("{\n//test 4\n}\n" ::: "memory"); 138 | } 139 | } 140 | 141 | __global__ __launch_bounds__(384) void dummy() { 142 | __shared__ barrier bar; 143 | int tid = threadIdx.x; 144 | int wg = tid / 128; 145 | int wgtid = tid % 128; 146 | 147 | if (tid == 0) { 148 | init(&bar, 3); 149 | } 150 | __syncthreads(); 151 | 152 | if (wg == 0) { 153 | int phase = 0; 154 | asm volatile("{\n//test 1\n}\n" ::: "memory"); 155 | if (wgtid == 0) { 156 | bar.wait(bar.arrive()); 157 | bar.wait(bar.arrive()); 158 | } 159 | asm volatile("{\n//test 2\n}\n" ::: "memory"); 160 | } else { 161 | int phase = 0; 162 | asm volatile("{\n//test 3\n}\n" ::: "memory"); 163 | if (wgtid == 0) { 164 | bar.arrive(); 165 | bar.arrive(); 166 | } 167 | asm volatile("{\n//test 4\n}\n" ::: "memory"); 168 | } 169 | } 170 | 171 | int main() { 172 | fprintf(stderr, "GO!\n"); 173 | dummy<<<1, 384>>>(); 174 | check(cudaDeviceSynchronize()); 175 | fprintf(stderr, "DONE!\n"); 176 | return 0; 177 | } 178 | -------------------------------------------------------------------------------- /main.cu: -------------------------------------------------------------------------------- 1 | #include "pingpong.cu" 2 | //#include "stmatrix.cu" 3 | 4 | __global__ void testFill(bf16* X, int M, int N, int parity) { 5 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 6 | int m_idx = idx % M; 7 | int n_idx = idx / M; 8 | if (m_idx >= M || n_idx >= N) 9 | return; 10 | if (parity < 0) { 11 | X[idx] = (m_idx == n_idx) ? 1.0 : 0.0; 12 | } else { 13 | X[idx] = idx; 14 | } 15 | 16 | // int v = (idx % 8 - 4); 17 | // //v = (v >= 0) ? 
v + 1 : v; 18 | // //X[idx] = (bf16)(v * parity); 19 | // X[idx] = (float)(clock() % 8) / 8.0 - 0.5; 20 | } 21 | 22 | cublasHandle_t cublas_handle; 23 | void runCublasGemmBF16(int M, int N, int K, bf16* A, bf16* B, bf16* C) { 24 | float alpha = 1, beta = 0; 25 | // C(column major) = A(row major) * B(column major) 26 | cublasStatus_t status = cublasGemmEx( 27 | cublas_handle, 28 | CUBLAS_OP_T, 29 | CUBLAS_OP_N, 30 | M, 31 | N, 32 | K, 33 | &alpha, 34 | A, 35 | CUDA_R_16BF, 36 | K, 37 | B, 38 | CUDA_R_16BF, 39 | K, 40 | &beta, 41 | C, 42 | CUDA_R_16BF, 43 | M, 44 | CUBLAS_COMPUTE_32F, 45 | CUBLAS_GEMM_DEFAULT); 46 | 47 | if (status != CUBLAS_STATUS_SUCCESS) { 48 | fprintf(stderr, "CUBLAS error: %d\n", status); 49 | exit(EXIT_FAILURE); 50 | } 51 | } 52 | 53 | __global__ __launch_bounds__( 54 | 1024) void naive_gemm(bf16* A, bf16* B, bf16* C, int M, int N, int K) { 55 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (idx < M * N) { 57 | int m_idx = idx % M; 58 | int n_idx = idx / M; 59 | float sum = 0.0; 60 | for (int k = 0; k < K; k++) { 61 | sum += __bfloat162float(A[m_idx * K + k]) * 62 | __bfloat162float(B[k + n_idx * K]); 63 | } 64 | C[m_idx + n_idx * M] = __float2bfloat16(sum); 65 | } 66 | } 67 | 68 | void run_naive_gemm(bf16* A, bf16* B, bf16* C, int M, int N, int K) { 69 | naive_gemm<<>>(A, B, C, M, N, K); 70 | } 71 | 72 | template 73 | void randomize_matrix(Gen& generator, bf16 *hM, bf16 *dM, int N) { 74 | std::normal_distribution distribution(0, 1); 75 | for (int i = 0; i < N; i++) { 76 | hM[i] = distribution(generator); 77 | } 78 | check(cudaMemcpy(dM, hM, sizeof(bf16) * N, cudaMemcpyHostToDevice)); 79 | } 80 | 81 | void arange(bf16 *hM, bf16* dM, int M, int N) { 82 | for (int m = 0; m < M; m++) { 83 | for (int n = 0; n < N; n++) { 84 | hM[m * N + n] = m * N + n; 85 | } 86 | } 87 | check(cudaMemcpy(dM, hM, sizeof(bf16) * M * N, cudaMemcpyHostToDevice)); 88 | } 89 | 90 | void identity(bf16 *hM, bf16* dM, int M, int N) { 91 | for (int m = 0; m < M; m++) { 92 | for (int n = 0; n < N; n++) { 93 | hM[m + n * M] = (m == n) ? 1.0f : 0.0f; 94 | } 95 | } 96 | check(cudaMemcpy(dM, hM, sizeof(bf16) * M * N, cudaMemcpyHostToDevice)); 97 | } 98 | 99 | void print_matrix(bf16* hM, bf16* dM, int M, int N, bool rowmajor) { 100 | check(cudaMemcpy(hM, dM, sizeof(bf16) * M * N, cudaMemcpyDeviceToHost)); 101 | auto strideM = rowmajor ? N : 1; 102 | auto strideN = rowmajor ? 
1 : M; 103 | for (int i = 0; i < min(M, 128); i++) { 104 | for (int j = 0; j < min(N, 128); j++) { 105 | printf(" %6.2f", __bfloat162float(hM[i * strideM + j * strideN])); 106 | } 107 | printf(" ...\n"); 108 | } 109 | printf("...\n\n"); 110 | } 111 | 112 | int main() { 113 | // int m = 6 * 11 * 128; 114 | // int n = 6 * 12 * 128; 115 | // int k = 512; 116 | 117 | // m = k = 8; 118 | // n = 16; 119 | 120 | int m = 6 * 11 * 128; 121 | int n = 3 * 12 * 256; 122 | int k = 1024; 123 | 124 | m = 6 * 11 * 128; 125 | n = 6 * 12 * 128; 126 | k = 640; 127 | 128 | // m = 8 * 128; 129 | // n = 8 * 256; 130 | // k = 64; 131 | //m = n = k = 8192; 132 | //int max = 8192; 133 | int max = 16384; 134 | int numel = max * max; 135 | 136 | // Allocate matrices 137 | __nv_bfloat16* A; 138 | __nv_bfloat16* B; 139 | __nv_bfloat16* C; 140 | __nv_bfloat16* Cref; 141 | 142 | check(cudaMalloc((void**)&A, sizeof(bf16) * max * max)); 143 | check(cudaMalloc((void**)&B, sizeof(bf16) * max * max)); 144 | check(cudaMalloc((void**)&C, sizeof(bf16) * max * max)); 145 | check(cudaMalloc((void**)&Cref, sizeof(bf16) * max * max)); 146 | 147 | bf16* hM = (bf16*)malloc(sizeof(bf16) * numel); 148 | 149 | // Fill with test data. 150 | //testFill<<>>(A, m, k, 1); 151 | //testFill<<>>(B, k, n, -1); 152 | std::default_random_engine gen(1337); 153 | randomize_matrix(gen, hM, A, numel); 154 | randomize_matrix(gen, hM, B, numel); 155 | //arange(hM, A, m, k); 156 | //identity(hM, B, k, n); 157 | //randomize_matrix(gen, hM, C, numel); 158 | check(cudaMemset(C, 0, sizeof(bf16) * numel)); 159 | check(cudaGetLastError()); 160 | 161 | // Generate cuBLAS reference. 162 | cublasCreate(&cublas_handle); 163 | runCublasGemmBF16(m, n, k, A, B, Cref); 164 | 165 | // Run test kernel. 166 | printf("about to run gemm\n"); 167 | run_pingpong(A, B, C, m, n, k); 168 | 169 | // Print a slab of matrix for sanity. 170 | printf("A:\n"); print_matrix(hM, A, m, k, true); 171 | printf("B:\n"); print_matrix(hM, B, k, n, false); 172 | printf("C:\n"); print_matrix(hM, C, m, n, false); 173 | printf("Cref:\n"); print_matrix(hM, Cref, m, n, false); 174 | 175 | // Test against cuBLAS reference. 176 | bf16* hostC = nullptr; 177 | bf16* hostCref = nullptr; 178 | if (true) { 179 | hostC = (bf16*)malloc(sizeof(bf16) * m * n); 180 | hostCref = (bf16*)malloc(sizeof(bf16) * m * n); 181 | 182 | check(cudaMemcpy(hostC, C, sizeof(bf16) * m * n, cudaMemcpyDeviceToHost)); 183 | check(cudaMemcpy( 184 | hostCref, Cref, sizeof(bf16) * m * n, cudaMemcpyDeviceToHost)); 185 | 186 | for (int i = 0; i < m * n; i++) { 187 | float cv = __bfloat162float(hostC[i]); 188 | float crefv = __bfloat162float(hostCref[i]); 189 | if (std::abs(cv - crefv) > 1e-5) { 190 | fprintf( 191 | stderr, 192 | "Failed tolerance check at idx (%d, %d), C=%f, Cref=%f\n", 193 | i / n, i % n, 194 | cv, 195 | crefv); 196 | exit(EXIT_FAILURE); 197 | } 198 | } 199 | } 200 | 201 | auto benchmark = false; 202 | if (benchmark) { 203 | // Benchmark test kernel. 
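    // The events bracket `repeat_times` back-to-back launches, so `ms` is the
    // total elapsed time; `flops` below also includes the repeat count, so
    // flops / ms * 1e-9 works out to average TFLOPS over all launches.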
204 | cudaEvent_t start; 205 | cudaEvent_t stop; 206 | check(cudaEventCreate(&start)); 207 | check(cudaEventCreate(&stop)); 208 | 209 | int repeat_times = 1000; 210 | float ms = 0.0f; 211 | check(cudaEventRecord(start)); 212 | for (int j = 0; j < repeat_times; j++) { 213 | run_pingpong(A, B, C, m, n, k); 214 | } 215 | check(cudaEventRecord(stop)); 216 | check(cudaEventSynchronize(start)); 217 | check(cudaEventSynchronize(stop)); 218 | check(cudaEventElapsedTime(&ms, start, stop)); 219 | 220 | long flops = 2ll * m * n * k * repeat_times; 221 | printf("TFLOPS: %.1f\n", flops / ms * 1e-9); 222 | } 223 | 224 | // Free resources. 225 | cudaFree(A); 226 | cudaFree(B); 227 | cudaFree(C); 228 | cudaFree(Cref); 229 | free(hM); 230 | free(hostC); 231 | free(hostCref); 232 | return 0; 233 | } 234 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pingpong GEMM from scratch 2 | 3 | I've been really excited to learn the lowest-level details of GPU matrix 4 | multiplication recently, so I was really inspired to read Pranjal Shankhdhar's 5 | fantastic blog post [Outperforming cuBLAS on 6 | H100](https://cudaforfun.substack.com/p/outperforming-cublas-on-h100-a-worklog), 7 | which implements a fast gemm from first principles in CUDA, and actually 8 | outperforms cuBLAS. 9 | 10 | In a similar vein, I wanted to understand the 11 | [pingpong](https://github.com/NVIDIA/cutlass/blob/main/media/docs/efficient_gemm.md#hopper-warp-specialization) 12 | gemm algorithm in detail. So, I used https://github.com/pranjalssh/fast.cu as 13 | a starting point, and wrote this kernel to see if I could match CUTLASS's 14 | pingpong implementation myself, using hand-written CUDA. 15 | 16 | You can run a quick check of the kernel with: 17 | ``` 18 | make gemm && ./gemm 19 | ``` 20 | 21 | And run a sweep through a bunch of different shapes with: 22 | ``` 23 | # You need a nightly build of pytorch 24 | # pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 25 | TORCH_CUDA_ARCH_LIST=9.0a+PTX python setup.py develop && python benchmark.py 26 | ``` 27 | 28 | # Experimental Setup 29 | 30 | My goal was to implement a pingpong gemm kernel that matches the performance of 31 | an equivalent CUTLASS kernel, as a way to check that I understand all the 32 | optimizations that go into that kernel. Specifically I was targeting 33 | `cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x128x64_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma`; 34 | since I suspect some folks may not be familiar with CUTLASS's naming 35 | conventions, I'll break down the features I'm targeting: 36 | 37 | * A warp-specialized, pingpong "mainloop", which trades the tensor core 38 | pipeline between two consumer warpgroups to hide the latency of the epilogue. 39 | This [blog post by Less Wright and Adnan 40 | Hoque](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/) is a very 41 | good overview of this scheduling technique; I've reproduced their explanatory 42 | graphic below. 43 | * A tile size of M=128, N=128, K=64, which is the largest tile size that can be 44 | supported by this technique use an equivalent pipeline depth (which happens 45 | to be 6 stages) 46 | * ***DISABLE*** threadblock clusters! I'm putting this in bold and italics 47 | because I want to be super clear that I'm not comparing against the very best 48 | CUTLASS kernel. 
I avoided cluster launch for a sort of lame reason; my day 49 | job is working on the Triton compiler, which doesn't seem to support cluster 50 | launch (or pingpong). So, I wanted to focus on one thing at a time. 51 | * use TMA to write back the results in the epilogue (which requires storing the 52 | matmul result from registers to shared memory before initiating the TMA 53 | transfer) 54 | 55 | ![Ping-pong gemm illustration](images/pingpong.png "Ping-pong gemm illustration") 56 | 57 | I focused on shapes with M = 6 * 11 * 128 = 8448 and N = 6 * 12 * 128 = 9216. 58 | The reason for this oddly specific shape is that our kernel uses 128x128 output 59 | tiles; the H100 has 132 SMs (which is 11 * 12), so this shape will cause the 60 | kernel to iterate completely 36 times on each SM. This ensures that we're 61 | avoiding tile and wave quantization effects and focusing just on the main-loop 62 | code quality. From there I swept over K values from 64 to 2048. 63 | 64 | # Results 65 | 66 | I think I more or less succeeded. My kernel slightly outperforms CUTLASS for 67 | small K and large K, and slightly underperfoms at medium K, with a geomean 68 | improvement of 2%. Both implementations are faster than the default cuBLAS 69 | 12.4 implementation (yes, I know I should switch to 12.8), and all those are 70 | faster than my Triton pingpong implementation. 71 | 72 | It's ridiculously tempting to keep digging in to figure out what's going on in 73 | the middle (and it's actually quite easy to do these experiments in CUDA! I 74 | tried three ideas in the course of writing this article!) but I think I've 75 | basically hit my goal. 76 | 77 | ![Results](images/perf.png "Results") 78 | 79 | I also gained a much more detailed understanding of the pingpong 80 | algorithm itself! There are a few aspects in particular that aren't 81 | totally obvious at first glance: 82 | 83 | * The input pipeline reads A and B for each consumer separately, serially, but 84 | into the same queue. This allows the best pipelining and load balancing, but 85 | requires a bit of subtlety in the barrier management. 86 | * The output tile is a 128x128 buffer in shared memory that is used by both 87 | consumers. There isn't enough space to have a separate output tile for each 88 | consumer, so we need to share a tile and use a mutex. This means the mutex 89 | isn't just for managing contention amongst the tensor cores (as I originally 90 | thought) but is actually necessary for correctness. 91 | 92 | # Walkthrough 93 | 94 | For those interested in the details of kernel optimization, I thought I'd walk 95 | through my general approach. This isn't as detailed as the "Outperforming 96 | cuBLAS" worklog, unfortunately, because (1) I didn't organize my experiments 97 | cleanly enough, and (2) I don't have time to go back and redo them to pretend I 98 | did 😛. But in broad strokes I think the explanation is useful, at least as a 99 | guide to some matmul implementation tricks: 100 | 101 | ## Stage 1: Reproducing prior work 102 | 103 | I started with [one of the 104 | kernels](https://github.com/pranjalssh/fast.cu/blob/main/examples/matmul/matmul_7.cuh) 105 | from "Outperforming cuBLAS." Specifically I used Kernel 7, which includes 106 | using TMA for loads and WGMMA, as well as warp specialization, and large tiles 107 | (128x256). But it's before he added threadblock clusters (Triton doesn't 108 | support cluster launch) and Hilbert curves for L2 locality (smart but a bit out 109 | of codegen scope). 
I didn't include TMA stores at first, since that came after 110 | clusters; instead I started with direct use of st.global to write results to 111 | gmem. 112 | 113 | I implemented everything "from scratch", but following the above source as a 114 | very detailed recipe. Why'd I bother? I'm a little embarrassed by this part 115 | since it seems like a waste of time, but I think physically typing everything 116 | out forced me to really engage with the code. Along the way I did things like 117 | trivial refactors or renaming variables, which kind of forces me to check the 118 | semantics, since I can write down what I think the code is saying, but still 119 | have a reference for whether I got it right. 120 | 121 | For example, I actually carefully thought through how the writeback from MMA 122 | register layout to GMEM works; the layout seems complicated but I actually got 123 | this part right pretty quickly just by looking at the diagram and thinking real 124 | hard. 125 | 126 | ![WGMMA layout](images/wgmma_layout.png "WGMMA layout") 127 | 128 | I got to performance parity pretty quickly, although I did have one really dumb 129 | bug, in which I mistyped the number of registers in setmaxnregs.inc and 130 | deadlocked the kernel. It took me an hour after the kids went to bed to figure 131 | that one out 😬 132 | 133 | ## Stage 2: Applying pingpong 134 | 135 | The above kernel implements a "cooperative" warp specialized kernel where two 136 | warpgroups both work on pieces of the same 128x256 output tile. (A subtle note 137 | here: in the cooperative kernel, each warpgroup has its own 64xK slice of M and 138 | a shared Kx256 slice of N, hence cooperative. In pingpong, the two consumers 139 | don't share any input data at all -- even if they happen to be working on 140 | adjacent tiles! This seems wasteful, but it is necessary to let each warpgroup 141 | run independently and thus hide the epilogue). 142 | 143 | I had a pretty good idea from my earlier study of how to implement pingpong, so 144 | I burned down the following list of transformations: 145 | * Reduce the input tile size such that A blocks were 128x64 and B blocks were 146 | 64x128. 147 | * Make each consumer handle a 128x128 tile instead of 64x128. This had mildly 148 | trickiness, since I needed to use two groups of m64n128k16 wgmma instructions 149 | instead of one group of m64n256k16, but it wasn't that bad. 150 | * Insert ping-pong barrier to mutex the mainloop and writeback. This was 151 | moderately tricky, because I had only a hand-wavy understanding of the 152 | mbarrier lifecycle, and I needed to really solidify that. In particular, 153 | advancing the barrier by the number of K blocks (instead of by 1) is not 154 | trivial! 155 | * Make both consumers share the 128x128 output tile (under mutex), so that I 156 | could bump the pipeline stages up to 6 for maximum latency hiding. 157 | 158 | All this took a bit less than a day, and I had perf at about 80% of Cutlass. 159 | 160 | ## Stage 3: Optimize 161 | 162 | I tried a few minor optimizations, like peeling the mainloop and using 163 | `wgmma.commit_group 1` to keep the tensor core pipeline full, but those didn't 164 | help much. The elephant in the room was that I wasn't using the TMA for the 165 | writeback epilogue, so I was just doing a bunch of uncoalesced stores. 166 | 167 | Time to add TMA. 
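As context for why the epilogue matters so much: the TMA writeback needs a shared-memory staging buffer for the output tile, and the smem budget is already nearly full with 6 input stages. Here's a rough back-of-the-envelope sketch (assuming bf16 tiles of the sizes described above and roughly 228 KiB of shared memory per H100 SM, ignoring barrier storage):

```
BYTES = 2                        # bf16
STAGES = 6
A_TILE = 128 * 64 * BYTES        # 16 KiB per stage
B_TILE = 64 * 128 * BYTES        # 16 KiB per stage
C_TILE = 128 * 128 * BYTES       # 32 KiB, shared by both consumers

inputs = STAGES * (A_TILE + B_TILE)    # 192 KiB
total = inputs + C_TILE                # 224 KiB
print(f"{inputs // 1024} KiB inputs + {C_TILE // 1024} KiB output "
      f"= {total // 1024} KiB of ~228 KiB per SM")
```

That's why a second 128x128 output tile, or the padded stmatrix layout that comes up below, simply doesn't fit.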
168 | 169 | Implementing a baseline use of TMA itself was not terrible; I just re-routed my 170 | uncoalesced global stores to a shared memory buffer, and then did a very naive 171 | TMA copy to write to global memory. This was pretty easy since it just moves 172 | around rectangular tensors like normal people expect. 173 | 174 | Perf improved some (5-10%) but that wasn't enough. It was time to tackle bank 175 | conflicts; `ncu` pretty clearly flagged my `st.shared` and complained about 176 | excessive wavefronts. 177 | 178 | First I wanted to switch from normal stores to `stmatrix`, since it knows how 179 | to unpack the complicated wgmma register layout into something nice. This was 180 | actually the hardest part. 181 | 182 | ![Funny comment about stmatrix](images/stmatrix_layout_funny.png "Funny comment about stmatrix layout") 183 | 184 | In retrospect, I think it's somewhat simpler than it sounds, but maybe it's 185 | just because I was finally able to get it done, and my brain has been busily 186 | purging the painful memories of getting it done. Should I even try to explain? 187 | 188 | In brief: 189 | * A single stmatrix instruction knows how to store the four 8x8 submatrices for 190 | one warp in the above wgmma diagram. 191 | * The matrices are stored by saving contiguous "packets"/rows of 8 elements. 192 | * The addresses for the 8 rows of the first submatrix come from the first 8 193 | threads (thread 0-7); the 8 rows of the second submatrix from the next 8 194 | threads (8-15); and so on. I found this to be the key; I had to stop 195 | thinking so hard about the relationship between where the data elements were 196 | and what addresses were being stored to; I just had to stop worrying, trust 197 | that the data was in the right registers, and instead worry about getting the 198 | lane-to-address mapping, and everything else worked out. 199 | 200 | That took a really, really long time. Especially because I was trying to get 201 | it done while sitting on a bench at a trampoline park while my daughter was at 202 | a friend's birthday party. 203 | 204 | But perf was still bad. Ugh. Checking `ncu`, I still had a bunch of bank 205 | conflicts. I didn't really understand this, though; in my naive imagination 206 | the stores of all the matrix data should be contiguous and life should be good. 207 | Yet life was not good. In the end I finally wrote a really simple little 208 | python script to simulate all the stmatrix address generation and analyze it 209 | for bank conflicts. Sure enough; the rectangular matrix pattern means that 210 | each warp is writing to conflicting banks. Maybe you can guess what's going on 211 | in this screenshot of my script (hint, the pipe separated columns are the 212 | "naive" layout, an 8-byte padded layout, and a swizzled layout. 213 | 214 | ![Output of stmatrix script](images/stmatrix_script.png "Output of stmatrix script") 215 | 216 | It was time to swizzle. 217 | 218 | No, wait, in desperation, I went to check if Pranjal had solved this problem 219 | for me! Kernel 12 also uses stmatrix, and avoids bank conflicts by padding 220 | each row with 8 extra bytes. I could just do that! Or I could, if the H100 221 | had more shared memory. Instead, I got a kernel launch failure. 222 | 223 | It was time to face down the infamous 128B swizzle diagram. 
224 | 225 | ![TMA swizzle diagram from PTX 8.5](images/tma_swizzle_old.png "TMA swizzle diagram from PTX 8.5") 226 | 227 | Ok, I still don't know what the deal with that diagram was, but NVIDIA 228 | has greatly improved it in the PTX 8.7 ISA manual: 229 | 230 | ![TMA swizzle diagram from PTX 8.7](images/tma_swizzle.png "TMA swizzle diagram from PTX 8.7") 231 | 232 | But honestly? The docs only help so much. Even the description in terms of 233 | CuTE layouts is honestly not intuitive enough for me (thought it is rigorous, 234 | formal, and clever!). This isn't really a knock on the docs, it's just that 235 | it's hard to describe this kind of stuff verbally. 236 | 237 | You know what actually made it super obvious? Programming. I filled a shared 238 | memory buffer with consecutive integers -- basically the smem equivalent of 239 | `torch.arange(64*128).bfloat16().reshape(64, 128)`, and then TMA-transferred that 240 | to GMEM with 128B swizzling, `cudaMemcpy`ed it back to the host, and printed it 241 | out. This actually made it crystal clear! I wrote the swizzle function 242 | correctly on my first try 😄. 243 | 244 | I finished this at about 11:00 pm on Saturday and ran a benchmark, intending to 245 | go to bed no matter what, so I almost didn't realize that at this point I'd 246 | actually matched CUTLASS on a decent number of shapes. I decided to declare 247 | victory, and create this writeup. 248 | 249 | # Random Tips 250 | 251 | This is by far the most complex CUDA kernel I've ever written, so I learned a few useful things along the way: 252 | * Bind your kernel to a pytorch operator. It makes it really easy to run 253 | benchmarks or feed it test data. Although using Python adds some iteration 254 | overhead, so I also had a simple main.cu that ran some basic checks for me. 255 | * Print out the contents of shared memory. You can just call printf on the 256 | device. Sometimes that gets ugly, but I found it pretty useful. 257 | * I did a lot of debugging by writing "check patterns": arange(), eye(), 258 | etc. so that I could eyeball the outputs and see what had gone off the rails. 259 | Messed up offsets, transposition, etc. is pretty obvious when your data has 260 | structure, versus just being a sea of not-the-right-random-numbers 261 | * I also wrote several simple python scripts to iterate on indexing math and 262 | stuff. Even though I can print the same stuff from C++ or from kernels, if I 263 | just wanted to build intuition for the WGMMA register layout or something, it 264 | was more productive to iterate on a trivial python script than on the kernel 265 | itself. I've put my scripts in the GitHub repo. 
266 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import sys 4 | import numpy as np 5 | 6 | import torch 7 | import triton 8 | import triton.language as tl 9 | #import triton.intraprof as proton 10 | 11 | #from matmul_persistent_tma_ws_cooperative import matmul_persistent_tma_ws_cooperative 12 | 13 | SLOTS = 3*64 14 | 15 | torch._dynamo.config.recompile_limit = 1000 16 | torch._inductor.config.max_autotune_gemm_backends = "CUTLASS" 17 | torch._inductor.config.max_autotune_gemm_search_space = "EXHAUSTIVE" 18 | torch._inductor.config.cuda.cutlass_dir = f"{os.environ['HOME']}/local/cutlass" 19 | torch._inductor.config.cuda.cutlass_op_allowlist_regex = "128x128x64_1x1x1.*pingpong_epi_tma" 20 | torch._inductor.config.cuda.cutlass_instantiation_level = "0201" 21 | 22 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 23 | 24 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 25 | 26 | HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) 27 | 28 | if HAS_TMA_DESC: 29 | print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", ) 30 | else: 31 | print("TMA benchmarks will be running without grid constant TMA descriptor.", ) 32 | 33 | 34 | class TmaAutoTuneHelper: 35 | 36 | # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 37 | class KernelParamWrapper: 38 | 39 | def __init__(self, desc): 40 | self.desc = desc 41 | 42 | def tma_desc_cpu_ptr(self): 43 | return self.desc.data_ptr() 44 | 45 | TMA_SIZE = 128 46 | 47 | def __init__(self): 48 | self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor) 49 | self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor) 50 | if HAS_TMA_DESC: 51 | self.descriptors = {} 52 | else: 53 | self.cuda_descriptors = {} 54 | 55 | # Call this method outside of the lambda function for grid size 56 | def init_tma_descriptor(self, name): 57 | if HAS_TMA_DESC: 58 | self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8) 59 | else: 60 | self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8) 61 | 62 | # Call this method inside the lambda function for grid size 63 | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): 64 | if HAS_TMA_DESC: 65 | desc_x = self.descriptors[name] 66 | assert desc_x.data_ptr() % 64 == 0 67 | self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr()) 68 | else: 69 | desc_x = self.cuda_descriptors[name] 70 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 71 | self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr()) 72 | desc_x.copy_(buf_x, non_blocking=True) 73 | 74 | # Call this method inside the lambda function for grid size 75 | def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): 76 | if HAS_TMA_DESC: 77 | desc_x = self.descriptors[name] 78 | assert desc_x.data_ptr() % 64 == 0 79 | self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) 80 | else: 81 | desc_x = self.cuda_descriptors[name] 82 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 83 | self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, 
block_dim0, element_size, buf_x.data_ptr()) 84 | desc_x.copy_(buf_x, non_blocking=True) 85 | 86 | def get_tma_descriptor_kernel_param(self, name): 87 | if HAS_TMA_DESC: 88 | assert self.descriptors[name] is not None 89 | return self.KernelParamWrapper(self.descriptors[name]) 90 | else: 91 | assert self.cuda_descriptors[name] is not None 92 | return self.cuda_descriptors[name] 93 | 94 | 95 | 96 | """ 97 | @triton.autotune( 98 | configs=[ 99 | triton.Config( 100 | { 101 | "BLOCK_SIZE_M": 128, 102 | "BLOCK_SIZE_N": 256, 103 | "BLOCK_SIZE_K": 64, 104 | "GROUP_SIZE_M": 8, 105 | "NUM_CONSUMER_GROUPS": 2, 106 | }, 107 | num_stages=2, 108 | num_warps=4, 109 | num_consumer_groups=2, 110 | num_buffers_warp_spec=3, 111 | ), 112 | # triton.Config( 113 | # { 114 | # "BLOCK_SIZE_M": 64, 115 | # "BLOCK_SIZE_N": 64, 116 | # "BLOCK_SIZE_K": 128, 117 | # "GROUP_SIZE_M": 8, 118 | # "NUM_CONSUMER_GROUPS": 1, 119 | # }, 120 | # num_stages=3, 121 | # num_warps=4, 122 | # num_consumer_groups=0, # disable warp specialization 123 | # num_buffers_warp_spec=3, 124 | # ), 125 | ], 126 | key=["M", "N", "K"], 127 | use_cuda_graph=True, 128 | ) 129 | """ 130 | 131 | @triton.jit 132 | def matmul_persistent_tma_ws_pingpong_kernel( 133 | a_ptr, 134 | b_ptr, 135 | c_ptr, 136 | M, 137 | N, 138 | K, 139 | #profile_mem, 140 | BLOCK_SIZE_M: tl.constexpr = 128, 141 | BLOCK_SIZE_N: tl.constexpr = 128, 142 | BLOCK_SIZE_K: tl.constexpr = 64, # 143 | GROUP_SIZE_M: tl.constexpr = 8, # 144 | NUM_CONSUMER_GROUPS: tl.constexpr= 1, 145 | ): 146 | 147 | num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) 148 | for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): 149 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 150 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 151 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 152 | group_id = pid // num_pid_in_group 153 | first_pid_m = group_id * GROUP_SIZE_M 154 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 155 | pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 156 | pid_n = (pid % num_pid_in_group) // group_size_m 157 | 158 | offs_am = pid_m * BLOCK_SIZE_M 159 | offs_bn = pid_n * BLOCK_SIZE_N 160 | 161 | offs_k0 = 0 162 | acc0 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 163 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 164 | a0 = tl._experimental_descriptor_load( 165 | a_ptr, 166 | [offs_am, offs_k0], 167 | [BLOCK_SIZE_M, BLOCK_SIZE_K], 168 | tl.bfloat16, 169 | ) 170 | b0 = tl._experimental_descriptor_load(b_ptr, [offs_bn, offs_k0], [BLOCK_SIZE_N, BLOCK_SIZE_K], tl.bfloat16) 171 | acc0 = tl.dot(a0, b0.T, acc0) 172 | offs_k0 += BLOCK_SIZE_K 173 | 174 | c0 = acc0.to(tl.bfloat16) 175 | tl._experimental_descriptor_store(c_ptr, c0, [offs_am, offs_bn]) 176 | 177 | 178 | # %% 179 | # We can now create a convenience wrapper function that only takes two input tensors, 180 | # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. 181 | 182 | BIN = None 183 | 184 | 185 | def matmul_persistent_tma_ws_pingpong(a, b, dump_chrome_trace=False): 186 | # Check constraints. 187 | assert a.shape[1] == b.shape[0], "Incompatible dimensions" 188 | assert a.dtype == b.dtype, "Incompatible dtypes" 189 | NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count 190 | #NUM_SMS=1 191 | M, K = a.shape 192 | K, N = b.shape 193 | dtype = a.dtype 194 | # Allocates output. 
195 | c = torch.empty((M, N), device=a.device, dtype=dtype) 196 | 197 | desc_helper = TmaAutoTuneHelper() 198 | desc_helper.init_tma_descriptor("a") 199 | desc_helper.init_tma_descriptor("b") 200 | desc_helper.init_tma_descriptor("c") 201 | 202 | def grid(META): 203 | nonlocal desc_helper 204 | desc_helper.fill_2d_tma_descriptor( 205 | "a", 206 | a.data_ptr(), 207 | M, 208 | K, 209 | META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], 210 | META["BLOCK_SIZE_K"], 211 | a.element_size(), 212 | ) 213 | 214 | desc_helper.fill_2d_tma_descriptor( 215 | "b", 216 | b.data_ptr(), 217 | N, 218 | K, 219 | META["BLOCK_SIZE_N"], 220 | META["BLOCK_SIZE_K"], 221 | b.element_size(), 222 | ) 223 | desc_helper.fill_2d_tma_descriptor( 224 | "c", 225 | c.data_ptr(), 226 | M, 227 | N, 228 | META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], 229 | META["BLOCK_SIZE_N"], 230 | c.element_size(), 231 | ) 232 | return (min( 233 | NUM_SMS, 234 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), 235 | ), ) 236 | 237 | desc_a = desc_helper.get_tma_descriptor_kernel_param("a") 238 | desc_b = desc_helper.get_tma_descriptor_kernel_param("b") 239 | desc_c = desc_helper.get_tma_descriptor_kernel_param("c") 240 | 241 | global BIN 242 | 243 | def gen_meta(**kwargs): 244 | return kwargs 245 | 246 | meta = gen_meta( 247 | BLOCK_SIZE_M=128, 248 | BLOCK_SIZE_N=128, 249 | BLOCK_SIZE_K=64, 250 | GROUP_SIZE_M=8, 251 | NUM_CONSUMER_GROUPS=1, 252 | num_stages=6, 253 | num_warps=4, 254 | num_consumer_groups=1, 255 | num_buffers_warp_spec=6, 256 | ) 257 | launch_grid = grid(meta) 258 | # if dump_chrome_trace: 259 | # pconfig = proton.IntraKernelConfig(num_warps=12, proton_slots=SLOTS) 260 | # proton_grid = proton.const_grid(launch_grid, autotune_configs=[], func_args={}, 261 | # num_stages=6, 262 | # num_consumer_groups=1, 263 | # num_buffers_warp_spec=6, 264 | # num_warps=4, 265 | # ) 266 | # profile_size = proton.intra_kernel_memsize(np.prod(proton_grid), pconfig) 267 | # profile_mem = torch.empty(profile_size, device="cuda", dtype=torch.uint32) 268 | # else: 269 | # profile_mem = torch.empty(1, device="cuda", dtype=torch.uint32) 270 | BIN = matmul_persistent_tma_ws_pingpong_kernel[launch_grid]( 271 | desc_a, desc_b, desc_c, # 272 | M, N, K, # 273 | #profile_mem, 274 | **meta, 275 | #proton_slots=SLOTS, 276 | ) 277 | #if dump_chrome_trace: 278 | #if True: 279 | if dump_chrome_trace: 280 | #print(profile_mem.view(-1, 4)) 281 | proton.dump_chrome_trace(NUM_SMS, pconfig, profile_mem, "chrome_trace.json", BIN) 282 | return c 283 | 284 | 285 | def aten_matmul(a, b): 286 | return a.mm(b) 287 | 288 | 289 | @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) 290 | def cutlass_matmul(a, b): 291 | return a.mm(b) 292 | 293 | torch.ops.load_library("gemm.so") 294 | 295 | def custom_gemm(a, b): 296 | return torch.ops.gemm.gemm(a, b) 297 | 298 | def custom_pingpong(a, b): 299 | return torch.ops.gemm.pingpong(a, b) 300 | 301 | def custom_stmatrix_gemm(a, b): 302 | return torch.ops.gemm.stmatrix_gemm(a, b) 303 | 304 | test_impls = [ 305 | aten_matmul, 306 | cutlass_matmul, 307 | #custom_gemm, 308 | custom_pingpong, 309 | #custom_stmatrix_gemm, 310 | #matmul_persistent_tma_ws_pingpong, 311 | ] 312 | 313 | impl_map = {fn.__name__: fn for fn in test_impls} 314 | 315 | 316 | def test(): 317 | torch.manual_seed(0) 318 | m = 4 * 11 * 64 319 | n = 3 * 12 * 256 320 | #m, n = 2 * 128, 128 321 | k = 64 * 4 322 | a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) 323 | b = torch.randn((n, k), device="cuda", 
dtype=torch.bfloat16).T 324 | torch_output = torch.matmul(a, b) 325 | rtol = 0 326 | for fn in test_impls: 327 | if "cutlass" in fn.__name__: 328 | continue 329 | triton_output = fn(a, b) 330 | torch.cuda.synchronize() 331 | if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): 332 | print(f" Torch matches {fn.__name__}") 333 | else: 334 | print(f" Torch DOES NOT match {fn.__name__}") 335 | print("torch output:") 336 | print(torch_output) 337 | print("triton output:") 338 | print(triton_output) 339 | 340 | 341 | #x_vals = [(8192, 8192, k) for k in range(128, 1280 + 1, 128)] 342 | #x_vals = [(6 * 11 * 128, 3 * 12 * 256, k) for k in range(640, 640 + 1, 128)] 343 | #x_vals = [(4 * 11 * 128, 2 * 12 * 256, k) for k in range(640, 640 + 1, 128)] 344 | #x_vals = [(4 * 11 * 128, 2 * 12 * 256, k) for k in range(128, 2048 + 1, 128)] 345 | 346 | #x_vals = [(6 * 11 * 128, 3 * 12 * 256, k) for k in range(128, 2048 + 1, 128)] 347 | #x_vals = [(6 * 11 * 128, 3 * 12 * 256, k) for k in range(640, 640 + 1, 128)] 348 | x_vals = [ 349 | (8192, 8192, 8192), 350 | ] 351 | x_vals = [ 352 | (6 * 11 * 128, 6 * 12 * 128, 64 * k) 353 | for k in range(1, 32) 354 | ] 355 | 356 | #[ 357 | # (6 * 11 * 128, 6 * 12 * 128, 640), 358 | # (6 * 11 * 128, 6 * 12 * 128, 1280), 359 | #] 360 | configs = [] 361 | configs.append( 362 | triton.testing.Benchmark( 363 | x_names=["K"], # Argument names to use as an x-axis for the plot 364 | x_vals=[64 * k for k in range(1, 32)], 365 | line_arg="provider", # Argument name whose value corresponds to a different line in the plot 366 | # Possible values for `line_arg` 367 | # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. 368 | line_vals=[fn.__name__ for fn in test_impls], 369 | line_names=[ 370 | "Torch (cuBLAS)", 371 | "Cutlass (no clusters)", 372 | "Custom CUDA", 373 | ], 374 | # styles=[("red", "-"), ("green", "-"), ("blue", "-")], 375 | ylabel="TFLOPS", # Label name for the y-axis 376 | plot_name="pingpong-gemm-performance-bf16", 377 | args={"M": 6 * 11 * 128, "N": 6 * 12 * 128}, 378 | )) 379 | 380 | 381 | @triton.testing.perf_report(configs) 382 | def benchmark(M, N, K, provider): 383 | a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) 384 | b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16).T 385 | quantiles = [0.5, 0.2, 0.8] 386 | fn = impl_map[provider] 387 | #ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(lambda: fn(a, b), quantiles=quantiles) 388 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: fn(a, b), quantiles=quantiles) 389 | #if provider == "matmul_ws_automatic": 390 | # print(getattr(matmul_persistent_tma_ws_cooperative_kernel, "best_config", "not autotune")) 391 | # print(BIN.asm["ttgir"]) 392 | perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) 393 | return perf(ms), perf(max_ms), perf(min_ms) 394 | #return ms, max_ms, min_ms 395 | 396 | 397 | def prof(M, N, K, provider="matmul_persistent_tma_ws_pingpong"): 398 | a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) 399 | b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16).T 400 | #kwargs = {"dump_chrome_trace": True} if provider is "matmul_ws_automatic" else {} 401 | impl_map[provider](a, b) 402 | 403 | 404 | def trace(): 405 | M, N, K = 4 * 11 * 128, 4 * 12 * 128, 640 406 | a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) 407 | b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16).T 408 | for _ in range(3): 409 | matmul_ws_automatic(a, b) 410 | 411 | g = torch.cuda.CUDAGraph() 412 | with 
torch.cuda.graph(g): 413 | for _ in range(10): 414 | matmul_ws_automatic(a, b) 415 | 416 | torch.cuda.synchronize() 417 | from torch.profiler import profile 418 | with profile() as p: 419 | g.replay() 420 | torch.cuda.synchronize() 421 | p.export_chrome_trace("prof.json") 422 | 423 | #test() 424 | benchmark.run(show_plots=True, print_data=True, save_path=".") 425 | #prof(6 * 11 * 128, 6 * 12 * 128, 1280, provider="cutlass_matmul") 426 | #prof(6 * 11 * 128, 6 * 12 * 128, 1280, provider="custom_pingpong") 427 | 428 | print("OK") 429 | -------------------------------------------------------------------------------- /gemm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace { 10 | 11 | using bf16 = __nv_bfloat16; 12 | 13 | void checkCudaErrors(cudaError_t error, const char* file, int line) { 14 | if (error != cudaSuccess) { 15 | fprintf( 16 | stderr, 17 | "CUDA error at %s:%d: %s\n", 18 | file, 19 | line, 20 | cudaGetErrorString(error)); 21 | exit(EXIT_FAILURE); 22 | } 23 | } 24 | 25 | #define check(err) checkCudaErrors(err, __FILE__, __LINE__) 26 | 27 | __host__ __device__ int cdiv(int m, int n) { 28 | return (m + n - 1) / n; 29 | } 30 | 31 | 32 | template 33 | void tmaPrint(T s[]) { 34 | for (int i = 0; i < 3; i++) { 35 | std::cout << " " << s[i]; 36 | } 37 | std::cout << "\n"; 38 | } 39 | 40 | __device__ inline bf16 f2bf(float v) { 41 | return __float2bfloat16(v); 42 | } 43 | 44 | __host__ static inline CUtensorMap create_tma_desc( 45 | bf16* gmem, 46 | uint32_t M, 47 | uint32_t N, 48 | uint32_t BLOCK_M, 49 | uint32_t BLOCK_N) { 50 | CUtensorMap tma_desc; 51 | // TODO: Check these requirements against the HW spec. 52 | assert(BLOCK_N >= 64); 53 | assert(N % 64 == 0); 54 | 55 | // TODO: cdiv? 56 | // TODO" why the 64 inner dim? 
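  // Likely answer to the TODO above: CU_TENSOR_MAP_SWIZZLE_128B swizzles within
  // 128-byte rows, and 128 bytes of bf16 is exactly 64 elements, so the
  // (rows x cols) matrix is presented to TMA as a 3-D tensor
  // (64, rows, cols / 64) whose innermost box dimension is one full swizzle
  // pattern.  The literal 64s below are all this same quantity.
  static_assert(128 / sizeof(bf16) == 64, "one 128B swizzle row holds 64 bf16 elements"); // illustrative check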
57 | uint64_t shape[] = {64, M, N / 64}; 58 | uint64_t stride[] = {sizeof(bf16) * N, 64 * sizeof(bf16)}; 59 | uint32_t box_shape[] = {64, BLOCK_M, BLOCK_N / 64}; 60 | uint32_t box_stride[] = {1, 1, 1}; 61 | 62 | // tmaPrint(shape); 63 | // tmaPrint(stride); 64 | // tmaPrint(box_shape); 65 | // tmaPrint(box_stride); 66 | 67 | auto result = cuTensorMapEncodeTiled( 68 | &tma_desc, 69 | CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 70 | 3, 71 | gmem, 72 | shape, 73 | stride, 74 | box_shape, 75 | box_stride, 76 | CU_TENSOR_MAP_INTERLEAVE_NONE, 77 | CU_TENSOR_MAP_SWIZZLE_128B, 78 | CU_TENSOR_MAP_L2_PROMOTION_NONE, 79 | CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 80 | 81 | if (result != CUDA_SUCCESS) { 82 | fprintf(stderr, "TMA desc creation failed\n"); 83 | exit(EXIT_FAILURE); 84 | } 85 | 86 | return tma_desc; 87 | } 88 | 89 | __device__ uint64_t matrix_descriptor_encode(uint64_t x) { 90 | return (x & 0x3ffff) >> 4; 91 | } 92 | 93 | __device__ uint64_t make_smem_desc(bf16* ptr) { 94 | constexpr uint64_t leading_dim_byte_offset = 16; 95 | constexpr uint64_t stride_dim_byte_offset = 1024; 96 | constexpr uint64_t swizzle_128b = 1ull; 97 | uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); 98 | return matrix_descriptor_encode(addr) | 99 | (matrix_descriptor_encode(leading_dim_byte_offset) << 16) | 100 | (matrix_descriptor_encode(stride_dim_byte_offset) << 32) | 101 | (swizzle_128b << 62); 102 | } 103 | 104 | template 105 | __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) { 106 | uint64_t desc_a = make_smem_desc(&sA[0]); 107 | uint64_t desc_b = make_smem_desc(&sB[0]); 108 | // if (threadIdx.x == 128) { 109 | 110 | // printf("%llx\n", desc_a); 111 | 112 | // printf("%llx\n", desc_b); 113 | // } 114 | 115 | #if 1 116 | asm volatile( 117 | "{\n" 118 | "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " 119 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 120 | " %8, %9, %10, %11, %12, %13, %14, %15, " 121 | " %16, %17, %18, %19, %20, %21, %22, %23, " 122 | " %24, %25, %26, %27, %28, %29, %30, %31, " 123 | " %32, %33, %34, %35, %36, %37, %38, %39, " 124 | " %40, %41, %42, %43, %44, %45, %46, %47, " 125 | " %48, %49, %50, %51, %52, %53, %54, %55, " 126 | " %56, %57, %58, %59, %60, %61, %62, %63, " 127 | " %64, %65, %66, %67, %68, %69, %70, %71, " 128 | " %72, %73, %74, %75, %76, %77, %78, %79, " 129 | " %80, %81, %82, %83, %84, %85, %86, %87, " 130 | " %88, %89, %90, %91, %92, %93, %94, %95, " 131 | " %96, %97, %98, %99, %100, %101, %102, %103, " 132 | " %104, %105, %106, %107, %108, %109, %110, %111, " 133 | " %112, %113, %114, %115, %116, %117, %118, %119, " 134 | " %120, %121, %122, %123, %124, %125, %126, %127}," 135 | " %128," 136 | " %129," 137 | " %130, %131, %132, %133, %134;\n" 138 | "}\n" 139 | : "+f"(d[0][0]), 140 | "+f"(d[0][1]), 141 | "+f"(d[0][2]), 142 | "+f"(d[0][3]), 143 | "+f"(d[0][4]), 144 | "+f"(d[0][5]), 145 | "+f"(d[0][6]), 146 | "+f"(d[0][7]), 147 | "+f"(d[1][0]), 148 | "+f"(d[1][1]), 149 | "+f"(d[1][2]), 150 | "+f"(d[1][3]), 151 | "+f"(d[1][4]), 152 | "+f"(d[1][5]), 153 | "+f"(d[1][6]), 154 | "+f"(d[1][7]), 155 | "+f"(d[2][0]), 156 | "+f"(d[2][1]), 157 | "+f"(d[2][2]), 158 | "+f"(d[2][3]), 159 | "+f"(d[2][4]), 160 | "+f"(d[2][5]), 161 | "+f"(d[2][6]), 162 | "+f"(d[2][7]), 163 | "+f"(d[3][0]), 164 | "+f"(d[3][1]), 165 | "+f"(d[3][2]), 166 | "+f"(d[3][3]), 167 | "+f"(d[3][4]), 168 | "+f"(d[3][5]), 169 | "+f"(d[3][6]), 170 | "+f"(d[3][7]), 171 | "+f"(d[4][0]), 172 | "+f"(d[4][1]), 173 | "+f"(d[4][2]), 174 | "+f"(d[4][3]), 175 | "+f"(d[4][4]), 176 | "+f"(d[4][5]), 177 | 
"+f"(d[4][6]), 178 | "+f"(d[4][7]), 179 | "+f"(d[5][0]), 180 | "+f"(d[5][1]), 181 | "+f"(d[5][2]), 182 | "+f"(d[5][3]), 183 | "+f"(d[5][4]), 184 | "+f"(d[5][5]), 185 | "+f"(d[5][6]), 186 | "+f"(d[5][7]), 187 | "+f"(d[6][0]), 188 | "+f"(d[6][1]), 189 | "+f"(d[6][2]), 190 | "+f"(d[6][3]), 191 | "+f"(d[6][4]), 192 | "+f"(d[6][5]), 193 | "+f"(d[6][6]), 194 | "+f"(d[6][7]), 195 | "+f"(d[7][0]), 196 | "+f"(d[7][1]), 197 | "+f"(d[7][2]), 198 | "+f"(d[7][3]), 199 | "+f"(d[7][4]), 200 | "+f"(d[7][5]), 201 | "+f"(d[7][6]), 202 | "+f"(d[7][7]), 203 | "+f"(d[8][0]), 204 | "+f"(d[8][1]), 205 | "+f"(d[8][2]), 206 | "+f"(d[8][3]), 207 | "+f"(d[8][4]), 208 | "+f"(d[8][5]), 209 | "+f"(d[8][6]), 210 | "+f"(d[8][7]), 211 | "+f"(d[9][0]), 212 | "+f"(d[9][1]), 213 | "+f"(d[9][2]), 214 | "+f"(d[9][3]), 215 | "+f"(d[9][4]), 216 | "+f"(d[9][5]), 217 | "+f"(d[9][6]), 218 | "+f"(d[9][7]), 219 | "+f"(d[10][0]), 220 | "+f"(d[10][1]), 221 | "+f"(d[10][2]), 222 | "+f"(d[10][3]), 223 | "+f"(d[10][4]), 224 | "+f"(d[10][5]), 225 | "+f"(d[10][6]), 226 | "+f"(d[10][7]), 227 | "+f"(d[11][0]), 228 | "+f"(d[11][1]), 229 | "+f"(d[11][2]), 230 | "+f"(d[11][3]), 231 | "+f"(d[11][4]), 232 | "+f"(d[11][5]), 233 | "+f"(d[11][6]), 234 | "+f"(d[11][7]), 235 | "+f"(d[12][0]), 236 | "+f"(d[12][1]), 237 | "+f"(d[12][2]), 238 | "+f"(d[12][3]), 239 | "+f"(d[12][4]), 240 | "+f"(d[12][5]), 241 | "+f"(d[12][6]), 242 | "+f"(d[12][7]), 243 | "+f"(d[13][0]), 244 | "+f"(d[13][1]), 245 | "+f"(d[13][2]), 246 | "+f"(d[13][3]), 247 | "+f"(d[13][4]), 248 | "+f"(d[13][5]), 249 | "+f"(d[13][6]), 250 | "+f"(d[13][7]), 251 | "+f"(d[14][0]), 252 | "+f"(d[14][1]), 253 | "+f"(d[14][2]), 254 | "+f"(d[14][3]), 255 | "+f"(d[14][4]), 256 | "+f"(d[14][5]), 257 | "+f"(d[14][6]), 258 | "+f"(d[14][7]), 259 | "+f"(d[15][0]), 260 | "+f"(d[15][1]), 261 | "+f"(d[15][2]), 262 | "+f"(d[15][3]), 263 | "+f"(d[15][4]), 264 | "+f"(d[15][5]), 265 | "+f"(d[15][6]), 266 | "+f"(d[15][7]) 267 | : "l"(desc_a), 268 | "l"(desc_b), 269 | "n"(int32_t(ScaleD)), 270 | "n"(int32_t(ScaleA)), 271 | "n"(int32_t(ScaleB)), 272 | "n"(int32_t(TransA)), 273 | "n"(int32_t(TransB))); 274 | #endif 275 | } 276 | 277 | __device__ void wgmma_commit_group() { 278 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 279 | } 280 | 281 | template 282 | __device__ void wgmma_wait_group() { 283 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); 284 | } 285 | 286 | __device__ void wgmma_fence() { 287 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 288 | } 289 | 290 | template 291 | __device__ static __forceinline__ void setmaxnreg_inc() { 292 | asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(REGS)); 293 | } 294 | 295 | template 296 | __device__ static void __forceinline__ setmaxnreg_dec() { 297 | asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(REGS)); 298 | } 299 | 300 | __device__ static void __forceinline__ 301 | init_barrier(uint64_t* bar, int thread_count, int transaction_count) { 302 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 303 | asm volatile( 304 | "mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(bar_ptr), 305 | "r"(thread_count + transaction_count)); 306 | } 307 | 308 | __device__ static void __forceinline__ wait_barrier(uint64_t* bar, int phase) { 309 | uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); 310 | asm volatile( 311 | "{\n" 312 | ".reg .pred P1;\n" 313 | "LAB_WAIT:\n" 314 | "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" 315 | "@P1 bra.uni DONE;\n" 316 | "bra.uni 
LAB_WAIT;\n" 317 | "DONE:\n" 318 | "}\n" ::"r"(mbar_ptr), 319 | "r"(phase)); 320 | } 321 | 322 | __device__ static void __forceinline__ 323 | arrive_barrier(uint64_t* bar, int count) { 324 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 325 | asm volatile( 326 | "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" ::"r"( 327 | bar_ptr), 328 | "r"(count) 329 | : "memory"); 330 | } 331 | 332 | __device__ static void __forceinline__ 333 | expect_bytes(uint64_t* bar, uint32_t bytes) { 334 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 335 | asm volatile( 336 | "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" ::"r"(bar_ptr), 337 | "r"(bytes)); 338 | } 339 | 340 | __device__ static void __forceinline__ tma_load( 341 | bf16* dst, 342 | void const* const src_tma_desc, 343 | uint64_t* bar, 344 | int n, 345 | int m) { 346 | uint64_t tma_ptr = reinterpret_cast(src_tma_desc); 347 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 348 | uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); 349 | asm volatile( 350 | "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" 351 | " [%0], [%1, {%3, %4, %5}], [%2];" 352 | :: 353 | "r"(dst_ptr), 354 | "l"(tma_ptr), 355 | "r"(bar_ptr), 356 | "n"(0), 357 | "r"(m), 358 | "r"(n / 64) 359 | : "memory"); 360 | } 361 | 362 | constexpr int BLOCK_M = 128; 363 | constexpr int BLOCK_N = 256; 364 | constexpr int BLOCK_K = 64; 365 | constexpr int NUM_SMS = 132; 366 | constexpr int STAGES = 3; 367 | 368 | constexpr int WARPGROUP_SIZE = 128; 369 | constexpr int WARPGROUPS = 3; 370 | constexpr int NUM_THREADS = WARPGROUPS * WARPGROUP_SIZE; 371 | 372 | struct SharedStorage { 373 | alignas(128) bf16 A[BLOCK_M * BLOCK_K * STAGES]; 374 | alignas(128) bf16 B[BLOCK_K * BLOCK_N * STAGES]; 375 | }; 376 | 377 | __global__ __launch_bounds__(NUM_THREADS) void gemm( 378 | const __grid_constant__ CUtensorMap A, 379 | const __grid_constant__ CUtensorMap B, 380 | bf16* C, 381 | int M, 382 | int N, 383 | int K) { 384 | // Producer buffers for A and B. 385 | extern __shared__ __align__(128) uint8_t dynamic_smem[]; 386 | SharedStorage& smem = *reinterpret_cast(dynamic_smem); 387 | 388 | // Barriers. 389 | __shared__ __align__(8) uint64_t prod[STAGES]; 390 | __shared__ __align__(8) uint64_t cons[STAGES]; 391 | 392 | int tid = threadIdx.x; 393 | int wgid = tid / WARPGROUP_SIZE; 394 | int wg_tid = tid % WARPGROUP_SIZE; 395 | 396 | // Init barriers. 397 | if (tid == 0) { 398 | for (int i = 0; i < STAGES; i++) { 399 | init_barrier(&prod[i], 0, 1); 400 | init_barrier(&cons[i], 0, WARPGROUPS - 1); 401 | } 402 | } 403 | __syncthreads(); 404 | 405 | auto m_blocks = cdiv(M, BLOCK_M); 406 | auto n_blocks = cdiv(N, BLOCK_N); 407 | auto k_blocks = cdiv(K, BLOCK_K); 408 | 409 | if (wgid == 0) { 410 | // Producer warpgroup. 411 | setmaxnreg_dec<40>(); 412 | // Mainloop. 413 | 414 | //int m = 0, n = 0; 415 | if (wg_tid == 0) { 416 | int phase = 0; 417 | int stage = 0; 418 | for (auto bid = blockIdx.x; bid < m_blocks * n_blocks; bid += gridDim.x) { 419 | auto m = bid / n_blocks; 420 | auto n = bid % n_blocks; 421 | 422 | for (int k = 0; k < k_blocks ; k++) { 423 | // Wait for consumer. 424 | // TODO: stage and phase update. 425 | wait_barrier(&cons[stage], phase); 426 | // Set expect bytes for TMA. 427 | expect_bytes( 428 | &prod[stage], sizeof(bf16) * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)); 429 | // Load A. 
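          // Note on completion: expect_bytes above armed prod[stage] with the
          // combined size of one A tile plus one B tile, and each TMA load
          // below counts its arriving bytes against that total, so the
          // consumers' wait_barrier(&prod[stage], ...) releases only once both
          // tiles are fully resident in shared memory.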
430 | // TODO: use proper stage 431 | tma_load(&smem.A[stage * BLOCK_K * BLOCK_M], &A, &prod[stage], k * BLOCK_K, m * BLOCK_M); 432 | // Load B. 433 | tma_load(&smem.B[stage * BLOCK_K * BLOCK_N], &B, &prod[stage], k * BLOCK_K, n * BLOCK_N); 434 | stage++; 435 | if (stage == STAGES) { 436 | stage = 0; 437 | phase ^= 1; 438 | } 439 | } 440 | } 441 | } 442 | } else { 443 | // Consumer warpgroup. 444 | setmaxnreg_inc<232>(); 445 | 446 | int stage = 0; 447 | int phase = 0; 448 | if (wg_tid == 0) { 449 | for (int i = 0; i < STAGES; i++) { 450 | arrive_barrier(&cons[i], 1); 451 | } 452 | } 453 | for (auto bid = blockIdx.x; bid < m_blocks * n_blocks; bid += gridDim.x) { 454 | auto m = bid / n_blocks; 455 | auto n = bid % n_blocks; 456 | 457 | float acc[16][8]; 458 | memset(acc, 0, sizeof(acc)); 459 | // Mainloop. 460 | for (int k = 0; k < k_blocks; k++) { 461 | // Wait for producer. 462 | wait_barrier(&prod[stage], phase); 463 | 464 | wgmma_fence(); 465 | 466 | #pragma unroll 467 | for (int mma_k = 0; mma_k < BLOCK_K; mma_k += 16) { 468 | wgmma256<1, 1, 1, 0, 0>( 469 | acc, 470 | &smem.A[stage * BLOCK_M * BLOCK_K + mma_k + (wgid - 1) * BLOCK_K * (BLOCK_M / 2)], 471 | &smem.B[stage * BLOCK_N * BLOCK_K + mma_k]); 472 | } 473 | 474 | wgmma_commit_group(); 475 | wgmma_wait_group<0>(); 476 | 477 | // Arrive at consumer. 478 | if (wg_tid == 0) { 479 | arrive_barrier(&cons[stage], 1); 480 | } 481 | stage++; 482 | if (stage == STAGES) { 483 | stage = 0; 484 | phase ^= 1; 485 | } 486 | } 487 | // Write back to gmem. 488 | auto warp = wg_tid / 32; 489 | auto lane = wg_tid % 32; 490 | auto row = warp * 16 + lane / 4; 491 | auto col = (wg_tid % 4) * 2; 492 | 493 | row += (wgid - 1) * 64; 494 | auto C_BLOCK = &C[m * BLOCK_M + n * BLOCK_N * M]; 495 | 496 | //printf("%d %d %d\n", tid - 128, row, col); 497 | for (int inst_n = 0; inst_n < 256; inst_n += 16) { 498 | #define Cidx(r, c) C_BLOCK[(r) + ((c) * M)] 499 | // clang-format off 500 | // printf("%d %d %d %f\n", 501 | // tid, 502 | // row, 503 | // col, 504 | // acc[n][0]); 505 | Cidx(row, inst_n + col ) = f2bf(acc[inst_n / 16][0]); 506 | Cidx(row, inst_n + col + 1) = f2bf(acc[inst_n / 16][1]); 507 | Cidx(row + 8, inst_n + col ) = f2bf(acc[inst_n / 16][2]); 508 | Cidx(row + 8, inst_n + col + 1) = f2bf(acc[inst_n / 16][3]); 509 | Cidx(row, inst_n + col + 8) = f2bf(acc[inst_n / 16][4]); 510 | Cidx(row, inst_n + col + 9) = f2bf(acc[inst_n / 16][5]); 511 | Cidx(row + 8, inst_n + col + 8) = f2bf(acc[inst_n / 16][6]); 512 | Cidx(row + 8, inst_n + col + 9) = f2bf(acc[inst_n / 16][7]); 513 | // clang-format on 514 | } 515 | } 516 | 517 | // auto row = (wg_tid / 32) * 2 + wg_tid / 4; 518 | // if (tid == 128) { 519 | // for (int i = 0; i < 16; i++) { 520 | // for (int j = 0; j < 8; j++) { 521 | // printf(" %6.2f", acc[i][j]); 522 | // } 523 | // printf("\n"); 524 | // } 525 | // printf("\n"); 526 | // } 527 | } 528 | // __syncthreads(); 529 | // if (tid == 128) { 530 | // printf("smem.A:\n"); 531 | // for (int i = 0; i < BLOCK_M; i++) { 532 | // for (int j = 0; j < BLOCK_K; j++) { 533 | // printf(" %6.2f", __bfloat162float(smem.A[i * BLOCK_K + j])); 534 | // } 535 | // printf("\n"); 536 | // } 537 | // printf("\n"); 538 | // printf("smem.B:\n"); 539 | // for (int i = 0; i < BLOCK_K; i++) { 540 | // for (int j = 0; j < BLOCK_N; j++) { 541 | // printf(" %6.2f", __bfloat162float(smem.B[i + j * BLOCK_K])); 542 | // } 543 | // printf("\n"); 544 | // } 545 | // printf("\n"); 546 | // } 547 | } 548 | } // namespace 549 | 550 | void run_gemm(bf16* A, bf16* B, bf16* C, int M, 
int N, int K) { 551 | // Compute necessary shared memory for buffers. 552 | size_t smem_size = sizeof(SharedStorage); 553 | check(cudaFuncSetAttribute( 554 | gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 555 | 556 | // Set up TMA descriptors 557 | auto descA = create_tma_desc(A, M, K, BLOCK_M, BLOCK_K); 558 | auto descB = create_tma_desc(B, N, K, BLOCK_N, BLOCK_K); 559 | 560 | // Launch kernel! 561 | gemm<<>>(descA, descB, C, M, N, K); 562 | } 563 | 564 | void run_gemm(void* A, void* B, void* C, int M, int N, int K) { 565 | run_gemm((bf16*) A, (bf16*)B, (bf16*)C, M, N, K); 566 | } 567 | -------------------------------------------------------------------------------- /stmatrix.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace { 10 | 11 | using bf16 = __nv_bfloat16; 12 | 13 | void checkCudaErrors(cudaError_t error, const char* file, int line) { 14 | if (error != cudaSuccess) { 15 | fprintf( 16 | stderr, 17 | "CUDA error at %s:%d: %s\n", 18 | file, 19 | line, 20 | cudaGetErrorString(error)); 21 | exit(EXIT_FAILURE); 22 | } 23 | } 24 | 25 | #define check(err) checkCudaErrors(err, __FILE__, __LINE__) 26 | 27 | __host__ __device__ int cdiv(int m, int n) { 28 | return (m + n - 1) / n; 29 | } 30 | 31 | 32 | template 33 | void tmaPrint(T s[]) { 34 | for (int i = 0; i < 3; i++) { 35 | std::cout << " " << s[i]; 36 | } 37 | std::cout << "\n"; 38 | } 39 | 40 | __device__ inline bf16 f2bf(float v) { 41 | return __float2bfloat16(v); 42 | } 43 | 44 | __host__ static inline CUtensorMap create_tma_desc( 45 | bf16* gmem, 46 | uint32_t M, 47 | uint32_t N, 48 | uint32_t BLOCK_M, 49 | uint32_t BLOCK_N, 50 | bool swizzle, 51 | bool padding) { 52 | CUtensorMap tma_desc; 53 | // TODO: Check these requirements against the HW spec. 54 | assert(BLOCK_N >= 64); 55 | assert(N % 64 == 0); 56 | 57 | // TODO: cdiv? 58 | // TODO" why the 64 inner dim? 59 | uint64_t shape[5] = {64, M, N / 64}; 60 | uint64_t stride[5] = {sizeof(bf16) * N, 64 * sizeof(bf16)}; 61 | uint32_t box_shape[5] = {padding ? 72 : 64, BLOCK_M, BLOCK_N / 64}; 62 | uint32_t box_stride[5] = {1, 1, 1}; 63 | 64 | //for (int i = 0; i < 5; i++) 65 | // tmaPrint(shape); 66 | // tmaPrint(stride); 67 | // tmaPrint(box_shape); 68 | // tmaPrint(box_stride); 69 | 70 | auto result = cuTensorMapEncodeTiled( 71 | &tma_desc, 72 | CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 73 | 3, 74 | gmem, 75 | shape, 76 | stride, 77 | box_shape, 78 | box_stride, 79 | CU_TENSOR_MAP_INTERLEAVE_NONE, 80 | swizzle ? 
CU_TENSOR_MAP_SWIZZLE_128B : CU_TENSOR_MAP_SWIZZLE_NONE, 81 | CU_TENSOR_MAP_L2_PROMOTION_NONE, 82 | CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 83 | 84 | if (result != CUDA_SUCCESS) { 85 | fprintf(stderr, "TMA desc creation failed\n"); 86 | exit(EXIT_FAILURE); 87 | } 88 | 89 | return tma_desc; 90 | } 91 | 92 | __device__ uint64_t matrix_descriptor_encode(uint64_t x) { 93 | return (x & 0x3ffff) >> 4; 94 | } 95 | 96 | __device__ uint64_t make_smem_desc(bf16* ptr) { 97 | constexpr uint64_t leading_dim_byte_offset = 16; 98 | constexpr uint64_t stride_dim_byte_offset = 1024; 99 | constexpr uint64_t swizzle_128b = 1ull; 100 | uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); 101 | return matrix_descriptor_encode(addr) | 102 | (matrix_descriptor_encode(leading_dim_byte_offset) << 16) | 103 | (matrix_descriptor_encode(stride_dim_byte_offset) << 32) | 104 | (swizzle_128b << 62); 105 | } 106 | 107 | template 108 | __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) { 109 | uint64_t desc_a = make_smem_desc(&sA[0]); 110 | uint64_t desc_b = make_smem_desc(&sB[0]); 111 | // if (threadIdx.x == 128) { 112 | 113 | // printf("%llx\n", desc_a); 114 | 115 | // printf("%llx\n", desc_b); 116 | // } 117 | 118 | #if 1 119 | asm volatile( 120 | "{\n" 121 | "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " 122 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 123 | " %8, %9, %10, %11, %12, %13, %14, %15, " 124 | " %16, %17, %18, %19, %20, %21, %22, %23, " 125 | " %24, %25, %26, %27, %28, %29, %30, %31, " 126 | " %32, %33, %34, %35, %36, %37, %38, %39, " 127 | " %40, %41, %42, %43, %44, %45, %46, %47, " 128 | " %48, %49, %50, %51, %52, %53, %54, %55, " 129 | " %56, %57, %58, %59, %60, %61, %62, %63, " 130 | " %64, %65, %66, %67, %68, %69, %70, %71, " 131 | " %72, %73, %74, %75, %76, %77, %78, %79, " 132 | " %80, %81, %82, %83, %84, %85, %86, %87, " 133 | " %88, %89, %90, %91, %92, %93, %94, %95, " 134 | " %96, %97, %98, %99, %100, %101, %102, %103, " 135 | " %104, %105, %106, %107, %108, %109, %110, %111, " 136 | " %112, %113, %114, %115, %116, %117, %118, %119, " 137 | " %120, %121, %122, %123, %124, %125, %126, %127}," 138 | " %128," 139 | " %129," 140 | " %130, %131, %132, %133, %134;\n" 141 | "}\n" 142 | : "+f"(d[0][0]), 143 | "+f"(d[0][1]), 144 | "+f"(d[0][2]), 145 | "+f"(d[0][3]), 146 | "+f"(d[0][4]), 147 | "+f"(d[0][5]), 148 | "+f"(d[0][6]), 149 | "+f"(d[0][7]), 150 | "+f"(d[1][0]), 151 | "+f"(d[1][1]), 152 | "+f"(d[1][2]), 153 | "+f"(d[1][3]), 154 | "+f"(d[1][4]), 155 | "+f"(d[1][5]), 156 | "+f"(d[1][6]), 157 | "+f"(d[1][7]), 158 | "+f"(d[2][0]), 159 | "+f"(d[2][1]), 160 | "+f"(d[2][2]), 161 | "+f"(d[2][3]), 162 | "+f"(d[2][4]), 163 | "+f"(d[2][5]), 164 | "+f"(d[2][6]), 165 | "+f"(d[2][7]), 166 | "+f"(d[3][0]), 167 | "+f"(d[3][1]), 168 | "+f"(d[3][2]), 169 | "+f"(d[3][3]), 170 | "+f"(d[3][4]), 171 | "+f"(d[3][5]), 172 | "+f"(d[3][6]), 173 | "+f"(d[3][7]), 174 | "+f"(d[4][0]), 175 | "+f"(d[4][1]), 176 | "+f"(d[4][2]), 177 | "+f"(d[4][3]), 178 | "+f"(d[4][4]), 179 | "+f"(d[4][5]), 180 | "+f"(d[4][6]), 181 | "+f"(d[4][7]), 182 | "+f"(d[5][0]), 183 | "+f"(d[5][1]), 184 | "+f"(d[5][2]), 185 | "+f"(d[5][3]), 186 | "+f"(d[5][4]), 187 | "+f"(d[5][5]), 188 | "+f"(d[5][6]), 189 | "+f"(d[5][7]), 190 | "+f"(d[6][0]), 191 | "+f"(d[6][1]), 192 | "+f"(d[6][2]), 193 | "+f"(d[6][3]), 194 | "+f"(d[6][4]), 195 | "+f"(d[6][5]), 196 | "+f"(d[6][6]), 197 | "+f"(d[6][7]), 198 | "+f"(d[7][0]), 199 | "+f"(d[7][1]), 200 | "+f"(d[7][2]), 201 | "+f"(d[7][3]), 202 | "+f"(d[7][4]), 203 | "+f"(d[7][5]), 204 | 
"+f"(d[7][6]), 205 | "+f"(d[7][7]), 206 | "+f"(d[8][0]), 207 | "+f"(d[8][1]), 208 | "+f"(d[8][2]), 209 | "+f"(d[8][3]), 210 | "+f"(d[8][4]), 211 | "+f"(d[8][5]), 212 | "+f"(d[8][6]), 213 | "+f"(d[8][7]), 214 | "+f"(d[9][0]), 215 | "+f"(d[9][1]), 216 | "+f"(d[9][2]), 217 | "+f"(d[9][3]), 218 | "+f"(d[9][4]), 219 | "+f"(d[9][5]), 220 | "+f"(d[9][6]), 221 | "+f"(d[9][7]), 222 | "+f"(d[10][0]), 223 | "+f"(d[10][1]), 224 | "+f"(d[10][2]), 225 | "+f"(d[10][3]), 226 | "+f"(d[10][4]), 227 | "+f"(d[10][5]), 228 | "+f"(d[10][6]), 229 | "+f"(d[10][7]), 230 | "+f"(d[11][0]), 231 | "+f"(d[11][1]), 232 | "+f"(d[11][2]), 233 | "+f"(d[11][3]), 234 | "+f"(d[11][4]), 235 | "+f"(d[11][5]), 236 | "+f"(d[11][6]), 237 | "+f"(d[11][7]), 238 | "+f"(d[12][0]), 239 | "+f"(d[12][1]), 240 | "+f"(d[12][2]), 241 | "+f"(d[12][3]), 242 | "+f"(d[12][4]), 243 | "+f"(d[12][5]), 244 | "+f"(d[12][6]), 245 | "+f"(d[12][7]), 246 | "+f"(d[13][0]), 247 | "+f"(d[13][1]), 248 | "+f"(d[13][2]), 249 | "+f"(d[13][3]), 250 | "+f"(d[13][4]), 251 | "+f"(d[13][5]), 252 | "+f"(d[13][6]), 253 | "+f"(d[13][7]), 254 | "+f"(d[14][0]), 255 | "+f"(d[14][1]), 256 | "+f"(d[14][2]), 257 | "+f"(d[14][3]), 258 | "+f"(d[14][4]), 259 | "+f"(d[14][5]), 260 | "+f"(d[14][6]), 261 | "+f"(d[14][7]), 262 | "+f"(d[15][0]), 263 | "+f"(d[15][1]), 264 | "+f"(d[15][2]), 265 | "+f"(d[15][3]), 266 | "+f"(d[15][4]), 267 | "+f"(d[15][5]), 268 | "+f"(d[15][6]), 269 | "+f"(d[15][7]) 270 | : "l"(desc_a), 271 | "l"(desc_b), 272 | "n"(int32_t(ScaleD)), 273 | "n"(int32_t(ScaleA)), 274 | "n"(int32_t(ScaleB)), 275 | "n"(int32_t(TransA)), 276 | "n"(int32_t(TransB))); 277 | #endif 278 | } 279 | 280 | __device__ void wgmma_commit_group() { 281 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 282 | } 283 | 284 | template 285 | __device__ void wgmma_wait_group() { 286 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); 287 | } 288 | 289 | __device__ void wgmma_fence() { 290 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 291 | } 292 | 293 | template 294 | __device__ static __forceinline__ void setmaxnreg_inc() { 295 | asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(REGS)); 296 | } 297 | 298 | template 299 | __device__ static void __forceinline__ setmaxnreg_dec() { 300 | asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(REGS)); 301 | } 302 | 303 | __device__ static void __forceinline__ 304 | init_barrier(uint64_t* bar, int thread_count, int transaction_count) { 305 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 306 | asm volatile( 307 | "mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(bar_ptr), 308 | "r"(thread_count + transaction_count)); 309 | } 310 | 311 | __device__ static void __forceinline__ wait_barrier(uint64_t* bar, int phase) { 312 | uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); 313 | asm volatile( 314 | "{\n" 315 | ".reg .pred P1;\n" 316 | "LAB_WAIT:\n" 317 | "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" 318 | "@P1 bra.uni DONE;\n" 319 | "bra.uni LAB_WAIT;\n" 320 | "DONE:\n" 321 | "}\n" ::"r"(mbar_ptr), 322 | "r"(phase)); 323 | } 324 | 325 | __device__ static void __forceinline__ 326 | arrive_barrier(uint64_t* bar, int count) { 327 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 328 | asm volatile( 329 | "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" ::"r"( 330 | bar_ptr), 331 | "r"(count) 332 | : "memory"); 333 | } 334 | 335 | __device__ static void __forceinline__ 336 | expect_bytes(uint64_t* bar, uint32_t bytes) 
{ 337 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 338 | asm volatile( 339 | "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" ::"r"(bar_ptr), 340 | "r"(bytes)); 341 | } 342 | 343 | __device__ static void __forceinline__ tma_load( 344 | bf16* dst, 345 | void const* const src_tma_desc, 346 | uint64_t* bar, 347 | int n, 348 | int m) { 349 | uint64_t tma_ptr = reinterpret_cast(src_tma_desc); 350 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 351 | uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); 352 | asm volatile( 353 | "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" 354 | " [%0], [%1, {%3, %4, %5}], [%2];" 355 | :: 356 | "r"(dst_ptr), 357 | "l"(tma_ptr), 358 | "r"(bar_ptr), 359 | "n"(0), 360 | "r"(m), 361 | "r"(n / 64) 362 | : "memory"); 363 | } 364 | 365 | __device__ static void tma_store(void const* dst_tma_desc, bf16* src, int N, int M) { 366 | uint64_t tma_ptr = reinterpret_cast(dst_tma_desc); 367 | uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); 368 | asm volatile( 369 | "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group" 370 | " [%0, {%2, %3, %4}], [%1];" 371 | :: "l"(tma_ptr), "r"(src_ptr), "n"(0), "r"(M), "r"(N / 64) 372 | : "memory"); 373 | } 374 | 375 | template 376 | __device__ static void tma_wait_group() { 377 | asm volatile("cp.async.bulk.wait_group %0;" :: "n"(N)); 378 | } 379 | 380 | __device__ static void tma_commit_group() { 381 | asm volatile("cp.async.bulk.commit_group;"); 382 | } 383 | 384 | __device__ static void stmatrix(bf16* smem_ptr, bf16 src[8]) { 385 | uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); 386 | uint32_t* d = reinterpret_cast(src); 387 | asm volatile( 388 | "stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%0], {%1, %2, %3, %4};" 389 | :: "r"(smem), "r"(d[0]), "r"(d[1]), "r"(d[2]), "r"(d[3])); 390 | } 391 | 392 | constexpr int INST_N = 256; 393 | constexpr int BLOCK_M = 128; 394 | constexpr int BLOCK_N = 256; 395 | constexpr int BLOCK_K = 64; 396 | constexpr int NUM_SMS = 132; 397 | constexpr int STAGES = 3; 398 | 399 | constexpr int WARPGROUP_SIZE = 128; 400 | constexpr int NUM_CONSUMERS = 2; 401 | constexpr int WARPGROUPS = 1 + NUM_CONSUMERS; 402 | constexpr int NUM_THREADS = WARPGROUPS * WARPGROUP_SIZE; 403 | 404 | struct SharedStorage { 405 | alignas(128) bf16 A[BLOCK_M * BLOCK_K * STAGES]; 406 | alignas(128) bf16 B[BLOCK_K * BLOCK_N * STAGES]; 407 | //alignas(128) bf16 C[BLOCK_N * (BLOCK_M + (BLOCK_M / 64) * 8)]; // padding of 8 elements per consumer 408 | //alignas(128) bf16 C[BLOCK_N * BLOCK_M]; // padding of 8 elements per consumer 409 | alignas(128) bf16 C[BLOCK_N * (BLOCK_M + NUM_CONSUMERS * 8)]; // padding of 8 elements per consumer 410 | }; 411 | 412 | __global__ __launch_bounds__(NUM_THREADS) void gemm( 413 | const __grid_constant__ CUtensorMap A, 414 | const __grid_constant__ CUtensorMap B, 415 | const __grid_constant__ CUtensorMap C, 416 | int M, 417 | int N, 418 | int K) { 419 | // Producer buffers for A and B. 420 | extern __shared__ __align__(128) uint8_t dynamic_smem[]; 421 | SharedStorage& smem = *reinterpret_cast(dynamic_smem); 422 | 423 | // Barriers. 424 | __shared__ __align__(8) uint64_t prod[STAGES]; 425 | __shared__ __align__(8) uint64_t cons[STAGES]; 426 | 427 | int tid = threadIdx.x; 428 | int wgid = tid / WARPGROUP_SIZE; 429 | int wg_tid = tid % WARPGROUP_SIZE; 430 | 431 | // Init barriers. 
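  // prod[i] is armed with a count of 1: the producer's single expect_tx
  // arrival, after which the barrier still waits for that stage's TMA
  // transaction bytes to land.  cons[i] is armed with WARPGROUPS - 1 arrivals,
  // i.e. one arrive from thread 0 of each consumer warpgroup, signalling that
  // the stage's A/B tiles may be overwritten.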
432 | if (tid == 0) { 433 | for (int i = 0; i < STAGES; i++) { 434 | init_barrier(&prod[i], 0, 1); 435 | init_barrier(&cons[i], 0, WARPGROUPS - 1); 436 | } 437 | } 438 | __syncthreads(); 439 | 440 | auto m_blocks = cdiv(M, BLOCK_M); 441 | auto n_blocks = cdiv(N, BLOCK_N); 442 | auto k_blocks = cdiv(K, BLOCK_K); 443 | 444 | if (wgid == 0) { 445 | // Producer warpgroup. 446 | setmaxnreg_dec<40>(); 447 | // Mainloop. 448 | 449 | //int m = 0, n = 0; 450 | if (wg_tid == 0) { 451 | int phase = 0; 452 | int stage = 0; 453 | for (auto bid = blockIdx.x; bid < m_blocks * n_blocks; bid += gridDim.x) { 454 | auto m = bid / n_blocks; 455 | auto n = bid % n_blocks; 456 | 457 | for (int k = 0; k < k_blocks ; k++) { 458 | // Wait for consumer. 459 | // TODO: stage and phase update. 460 | wait_barrier(&cons[stage], phase); 461 | // Set expect bytes for TMA. 462 | expect_bytes( 463 | &prod[stage], sizeof(bf16) * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)); 464 | // Load A. 465 | // TODO: use proper stage 466 | tma_load(&smem.A[stage * BLOCK_K * BLOCK_M], &A, &prod[stage], k * BLOCK_K, m * BLOCK_M); 467 | // Load B. 468 | tma_load(&smem.B[stage * BLOCK_K * BLOCK_N], &B, &prod[stage], k * BLOCK_K, n * BLOCK_N); 469 | stage++; 470 | if (stage == STAGES) { 471 | stage = 0; 472 | phase ^= 1; 473 | } 474 | } 475 | } 476 | } 477 | } else { 478 | // Consumer warpgroup. 479 | setmaxnreg_inc<232>(); 480 | 481 | int stage = 0; 482 | int phase = 0; 483 | if (wg_tid == 0) { 484 | for (int i = 0; i < STAGES; i++) { 485 | arrive_barrier(&cons[i], 1); 486 | } 487 | } 488 | for (auto bid = blockIdx.x; bid < m_blocks * n_blocks; bid += gridDim.x) { 489 | auto m = bid / n_blocks; 490 | auto n = bid % n_blocks; 491 | 492 | float acc[16][8]; 493 | memset(acc, 0, sizeof(acc)); 494 | // Mainloop. 495 | for (int k = 0; k < k_blocks; k++) { 496 | // Wait for producer. 497 | wait_barrier(&prod[stage], phase); 498 | 499 | wgmma_fence(); 500 | 501 | #pragma unroll 502 | for (int mma_k = 0; mma_k < BLOCK_K; mma_k += 16) { 503 | wgmma256<1, 1, 1, 0, 0>( 504 | acc, 505 | &smem.A[stage * BLOCK_M * BLOCK_K + mma_k + (wgid - 1) * BLOCK_K * (BLOCK_M / 2)], 506 | &smem.B[stage * BLOCK_N * BLOCK_K + mma_k]); 507 | } 508 | 509 | wgmma_commit_group(); 510 | wgmma_wait_group<0>(); 511 | 512 | // Arrive at consumer. 513 | if (wg_tid == 0) { 514 | arrive_barrier(&cons[stage], 1); 515 | } 516 | stage++; 517 | if (stage == STAGES) { 518 | stage = 0; 519 | phase ^= 1; 520 | } 521 | } 522 | 523 | constexpr int BLOCK_M_WG = BLOCK_M / NUM_CONSUMERS; 524 | constexpr int BLOCK_M_WG_PAD = BLOCK_M / NUM_CONSUMERS + 8; 525 | auto cid = wgid - 1; 526 | auto lane = wg_tid % 32; 527 | auto warp = wg_tid / 32; 528 | bf16* block_sC = smem.C + cid * BLOCK_M_WG_PAD * BLOCK_N; 529 | auto tid_offset = warp * 16 + (lane % 8) * BLOCK_M_WG_PAD; 530 | tid_offset += (lane / 16) * BLOCK_M_WG_PAD * 8 + (lane & 8); 531 | uint32_t base_addr = static_cast(__cvta_generic_to_shared(block_sC)) + tid_offset * sizeof(bf16); 532 | 533 | asm volatile("cp.async.bulk.wait_group 0;"); 534 | 535 | // Write back to gmem. 
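      // Layout note: the C tile is staged in smem M-major with its leading
      // dimension padded from 64 to BLOCK_M_WG_PAD = 72 elements (16 extra
      // bytes per column), so the eight lanes of each warp issuing stmatrix
      // below hit different banks instead of every column starting in the
      // same one.  The C descriptor is built with padding = true so TMA reads
      // the padded columns back out; the 8-element tail of each column falls
      // outside the tensor's 64-wide inner dimension, so (as I read the
      // padding flag) the bulk tensor store drops it rather than writing it
      // to gmem.  stmatrix.py in this repo prints the per-thread addresses
      // and bank assignments for both the padded and the swizzled variant.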
536 | bf16 acc_bf16[8]; 537 | int* acc_ptr = (int*)acc_bf16; 538 | for (int inst_n = 0; inst_n < INST_N; inst_n += 16) { 539 | uint32_t addr = base_addr + inst_n * BLOCK_M_WG_PAD * sizeof(bf16); 540 | for (int i = 0; i < 8; i++) { 541 | acc_bf16[i] = f2bf(acc[inst_n / 16][i]); 542 | } 543 | asm volatile( 544 | "stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%0], " 545 | "{%1, %2, %3, %4};" 546 | :: "r"(addr), "r"(acc_ptr[0]), "r"(acc_ptr[1]), "r"(acc_ptr[2]), "r"(acc_ptr[3])); 547 | } 548 | //asm volatile("bar.sync %0, 128;" :: "r"(cid + 2) : "memory"); 549 | asm volatile("fence.proxy.async.shared::cta;" ::: "memory"); 550 | 551 | if (wg_tid == 0) { 552 | tma_store(&C, block_sC, m * BLOCK_M + cid * BLOCK_M_WG, n * BLOCK_N); 553 | asm volatile("cp.async.bulk.commit_group;"); 554 | } 555 | } 556 | } 557 | } 558 | } // namespace 559 | 560 | void run_stmatrix_gemm(bf16* A, bf16* B, bf16* C, int M, int N, int K) { 561 | // Compute necessary shared memory for buffers. 562 | size_t smem_size = sizeof(SharedStorage); 563 | check(cudaFuncSetAttribute( 564 | gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 565 | 566 | // Set up TMA descriptors 567 | auto descA = create_tma_desc(A, M, K, BLOCK_M, BLOCK_K, true, false); 568 | auto descB = create_tma_desc(B, N, K, BLOCK_N, BLOCK_K, true, false); 569 | auto descC = create_tma_desc(C, N, M, BLOCK_N, BLOCK_M / NUM_CONSUMERS, false, true); 570 | 571 | // Launch kernel! 572 | gemm<<>>(descA, descB, descC, M, N, K); 573 | } 574 | 575 | void run_stmatrix_gemm(void* A, void* B, void* C, int M, int N, int K) { 576 | run_stmatrix_gemm((bf16*) A, (bf16*)B, (bf16*)C, M, N, K); 577 | } 578 | -------------------------------------------------------------------------------- /pingpong.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace { 10 | 11 | using bf16 = __nv_bfloat16; 12 | 13 | void checkCudaErrors(cudaError_t error, const char* file, int line) { 14 | if (error != cudaSuccess) { 15 | fprintf( 16 | stderr, 17 | "CUDA error at %s:%d: %s\n", 18 | file, 19 | line, 20 | cudaGetErrorString(error)); 21 | exit(EXIT_FAILURE); 22 | } 23 | } 24 | 25 | #define check(err) checkCudaErrors(err, __FILE__, __LINE__) 26 | 27 | __host__ __device__ int cdiv(int m, int n) { 28 | return (m + n - 1) / n; 29 | } 30 | 31 | __device__ inline bf16 f2bf(float v) { 32 | return __float2bfloat16(v); 33 | } 34 | 35 | __host__ static inline CUtensorMap create_tma_desc( 36 | bf16* gmem, 37 | uint32_t M, 38 | uint32_t N, 39 | uint32_t BLOCK_M, 40 | uint32_t BLOCK_N) { 41 | CUtensorMap tma_desc; 42 | assert(BLOCK_N >= 64); 43 | assert(N % 64 == 0); 44 | 45 | uint64_t shape[] = {64, M, N / 64}; 46 | uint64_t stride[] = {sizeof(bf16) * N, 64 * sizeof(bf16)}; 47 | uint32_t box_shape[] = {64, BLOCK_M, BLOCK_N / 64}; 48 | uint32_t box_stride[] = {1, 1, 1}; 49 | 50 | auto result = cuTensorMapEncodeTiled( 51 | &tma_desc, 52 | CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 53 | 3, 54 | gmem, 55 | shape, 56 | stride, 57 | box_shape, 58 | box_stride, 59 | CU_TENSOR_MAP_INTERLEAVE_NONE, 60 | CU_TENSOR_MAP_SWIZZLE_128B, 61 | CU_TENSOR_MAP_L2_PROMOTION_NONE, 62 | CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 63 | 64 | if (result != CUDA_SUCCESS) { 65 | fprintf(stderr, "TMA desc creation failed\n"); 66 | exit(EXIT_FAILURE); 67 | } 68 | 69 | return tma_desc; 70 | } 71 | 72 | __device__ uint64_t matrix_descriptor_encode(uint64_t x) { 73 | return (x & 0x3ffff) >> 4; 74 | } 75 
| 76 | __device__ uint64_t make_smem_desc(bf16* ptr) { 77 | constexpr uint64_t leading_dim_byte_offset = 16; 78 | constexpr uint64_t stride_dim_byte_offset = 1024; 79 | constexpr uint64_t swizzle_128b = 1ull; 80 | uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); 81 | return matrix_descriptor_encode(addr) | 82 | (matrix_descriptor_encode(leading_dim_byte_offset) << 16) | 83 | (matrix_descriptor_encode(stride_dim_byte_offset) << 32) | 84 | (swizzle_128b << 62); 85 | } 86 | 87 | template 88 | __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) { 89 | uint64_t desc_a = make_smem_desc(&sA[0]); 90 | uint64_t desc_b = make_smem_desc(&sB[0]); 91 | // clang-format off 92 | asm volatile( 93 | "{\n" 94 | "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " 95 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 96 | " %8, %9, %10, %11, %12, %13, %14, %15, " 97 | " %16, %17, %18, %19, %20, %21, %22, %23, " 98 | " %24, %25, %26, %27, %28, %29, %30, %31, " 99 | " %32, %33, %34, %35, %36, %37, %38, %39, " 100 | " %40, %41, %42, %43, %44, %45, %46, %47, " 101 | " %48, %49, %50, %51, %52, %53, %54, %55, " 102 | " %56, %57, %58, %59, %60, %61, %62, %63, " 103 | " %64, %65, %66, %67, %68, %69, %70, %71, " 104 | " %72, %73, %74, %75, %76, %77, %78, %79, " 105 | " %80, %81, %82, %83, %84, %85, %86, %87, " 106 | " %88, %89, %90, %91, %92, %93, %94, %95, " 107 | " %96, %97, %98, %99, %100, %101, %102, %103, " 108 | " %104, %105, %106, %107, %108, %109, %110, %111, " 109 | " %112, %113, %114, %115, %116, %117, %118, %119, " 110 | " %120, %121, %122, %123, %124, %125, %126, %127}," 111 | " %128," 112 | " %129," 113 | " %130, %131, %132, %133, %134;\n" 114 | "}\n" 115 | : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), 116 | "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), 117 | "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), 118 | "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), 119 | "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), 120 | "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), 121 | "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), 122 | "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]), 123 | "+f"(d[8][0]), "+f"(d[8][1]), "+f"(d[8][2]), "+f"(d[8][3]), "+f"(d[8][4]), "+f"(d[8][5]), "+f"(d[8][6]), "+f"(d[8][7]), 124 | "+f"(d[9][0]), "+f"(d[9][1]), "+f"(d[9][2]), "+f"(d[9][3]), "+f"(d[9][4]), "+f"(d[9][5]), "+f"(d[9][6]), "+f"(d[9][7]), 125 | "+f"(d[10][0]), "+f"(d[10][1]), "+f"(d[10][2]), "+f"(d[10][3]), "+f"(d[10][4]), "+f"(d[10][5]), "+f"(d[10][6]), "+f"(d[10][7]), 126 | "+f"(d[11][0]), "+f"(d[11][1]), "+f"(d[11][2]), "+f"(d[11][3]), "+f"(d[11][4]), "+f"(d[11][5]), "+f"(d[11][6]), "+f"(d[11][7]), 127 | "+f"(d[12][0]), "+f"(d[12][1]), "+f"(d[12][2]), "+f"(d[12][3]), "+f"(d[12][4]), "+f"(d[12][5]), "+f"(d[12][6]), "+f"(d[12][7]), 128 | "+f"(d[13][0]), "+f"(d[13][1]), "+f"(d[13][2]), "+f"(d[13][3]), "+f"(d[13][4]), "+f"(d[13][5]), "+f"(d[13][6]), "+f"(d[13][7]), 129 | "+f"(d[14][0]), "+f"(d[14][1]), "+f"(d[14][2]), 
"+f"(d[14][3]), "+f"(d[14][4]), "+f"(d[14][5]), "+f"(d[14][6]), "+f"(d[14][7]), 130 | "+f"(d[15][0]), "+f"(d[15][1]), "+f"(d[15][2]), "+f"(d[15][3]), "+f"(d[15][4]), "+f"(d[15][5]), "+f"(d[15][6]), "+f"(d[15][7]) 131 | : "l"(desc_a), 132 | "l"(desc_b), 133 | "n"(int32_t(ScaleD)), 134 | "n"(int32_t(ScaleA)), 135 | "n"(int32_t(ScaleB)), 136 | "n"(int32_t(TransA)), 137 | "n"(int32_t(TransB))); 138 | // clang-format on 139 | } 140 | 141 | template 142 | __device__ __forceinline__ void wgmma128(float d[8][8], bf16* sA, bf16* sB) { 143 | uint64_t desc_a = make_smem_desc(&sA[0]); 144 | uint64_t desc_b = make_smem_desc(&sB[0]); 145 | // clang-format off 146 | asm volatile( 147 | "{\n" 148 | "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " 149 | "{%0, %1, %2, %3, %4, %5, %6, %7, " 150 | " %8, %9, %10, %11, %12, %13, %14, %15, " 151 | " %16, %17, %18, %19, %20, %21, %22, %23, " 152 | " %24, %25, %26, %27, %28, %29, %30, %31, " 153 | " %32, %33, %34, %35, %36, %37, %38, %39, " 154 | " %40, %41, %42, %43, %44, %45, %46, %47, " 155 | " %48, %49, %50, %51, %52, %53, %54, %55, " 156 | " %56, %57, %58, %59, %60, %61, %62, %63}, " 157 | " %64," 158 | " %65," 159 | " %66, %67, %68, %69, %70;\n" 160 | "}\n" 161 | : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), 162 | "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), 163 | "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), 164 | "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), 165 | "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), 166 | "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), 167 | "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), 168 | "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) 169 | : "l"(desc_a), 170 | "l"(desc_b), 171 | "n"(int32_t(ScaleD)), 172 | "n"(int32_t(ScaleA)), 173 | "n"(int32_t(ScaleB)), 174 | "n"(int32_t(TransA)), 175 | "n"(int32_t(TransB))); 176 | // clang-format on 177 | } 178 | 179 | __device__ void wgmma_commit_group() { 180 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 181 | } 182 | 183 | template 184 | __device__ void wgmma_wait_group() { 185 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); 186 | } 187 | 188 | __device__ void wgmma_fence() { 189 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 190 | } 191 | 192 | template 193 | __device__ static __forceinline__ void setmaxnreg_inc() { 194 | asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(REGS)); 195 | } 196 | 197 | template 198 | __device__ static void __forceinline__ setmaxnreg_dec() { 199 | asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(REGS)); 200 | } 201 | 202 | __device__ static void __forceinline__ 203 | init_barrier(uint64_t* bar, int thread_count, int transaction_count) { 204 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 205 | asm volatile( 206 | "mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(bar_ptr), 207 | "r"(thread_count + transaction_count)); 208 | } 209 | 210 | __device__ static void 
__forceinline__ wait_barrier(uint64_t* bar, int phase) { 211 | uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); 212 | asm volatile( 213 | "{\n" 214 | ".reg .pred P1;\n" 215 | "LAB_WAIT:\n" 216 | "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" 217 | "@P1 bra.uni DONE;\n" 218 | "bra.uni LAB_WAIT;\n" 219 | "DONE:\n" 220 | "}\n" ::"r"(mbar_ptr), 221 | "r"(phase)); 222 | } 223 | 224 | __device__ static void __forceinline__ 225 | arrive_barrier(uint64_t* bar, int count) { 226 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 227 | asm volatile( 228 | "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" ::"r"( 229 | bar_ptr), 230 | "r"(count) 231 | : "memory"); 232 | } 233 | 234 | __device__ static void __forceinline__ 235 | expect_bytes(uint64_t* bar, uint32_t bytes) { 236 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 237 | asm volatile( 238 | "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" ::"r"(bar_ptr), 239 | "r"(bytes)); 240 | } 241 | 242 | __device__ static void __forceinline__ tma_load( 243 | bf16* dst, 244 | void const* const src_tma_desc, 245 | uint64_t* bar, 246 | int n, 247 | int m) { 248 | uint64_t tma_ptr = reinterpret_cast(src_tma_desc); 249 | uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); 250 | uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); 251 | asm volatile( 252 | "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" 253 | " [%0], [%1, {%3, %4, %5}], [%2];" ::"r"(dst_ptr), 254 | "l"(tma_ptr), 255 | "r"(bar_ptr), 256 | "n"(0), 257 | "r"(m), 258 | "r"(n / 64) 259 | : "memory"); 260 | } 261 | 262 | __device__ static void tma_store( 263 | void const* dst_tma_desc, 264 | bf16* src, 265 | int N, 266 | int M) { 267 | uint64_t tma_ptr = reinterpret_cast(dst_tma_desc); 268 | uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); 269 | asm volatile( 270 | "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group" 271 | " [%0, {%2, %3, %4}], [%1];" ::"l"(tma_ptr), 272 | "r"(src_ptr), 273 | "n"(0), 274 | "r"(M), 275 | "r"(N / 64) 276 | : "memory"); 277 | } 278 | 279 | template 280 | __device__ static void tma_wait_group() { 281 | asm volatile("cp.async.bulk.wait_group %0;" ::"n"(N)); 282 | } 283 | 284 | __device__ static void tma_commit_group() { 285 | asm volatile("cp.async.bulk.commit_group;"); 286 | } 287 | 288 | __device__ static void stmatrix(bf16* smem_ptr, bf16 src[8]) { 289 | uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); 290 | uint32_t* d = reinterpret_cast(src); 291 | asm volatile( 292 | "stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%0], {%1, %2, %3, %4};" :: 293 | "r"(smem), 294 | "r"(d[0]), 295 | "r"(d[1]), 296 | "r"(d[2]), 297 | "r"(d[3])); 298 | } 299 | 300 | __device__ static void fence_async_proxy() { 301 | asm volatile("fence.proxy.async.shared::cta;"); 302 | } 303 | 304 | __device__ static void __forceinline__ fence_memory(float regs[2][8][8]) { 305 | for (int i = 0; i < 2; i++) { 306 | for (int j = 0; j < 8; j++) { 307 | for (int k = 0; k < 8; k++) { 308 | asm volatile("" : "+f"(regs[i][j][k])::"memory"); 309 | } 310 | } 311 | } 312 | } 313 | 314 | constexpr int BLOCK_M = 128; 315 | constexpr int BLOCK_N = 128; 316 | constexpr int BLOCK_K = 64; 317 | constexpr int NUM_SMS = 132; 318 | constexpr int STAGES = 6; 319 | 320 | constexpr int WG_M = 128; 321 | constexpr int INST_M = 64; 322 | 323 | constexpr int WARPGROUP_SIZE = 128; 324 | constexpr int NUM_CONSUMERS = 2; 325 | constexpr int WARPGROUPS = 1 
+ NUM_CONSUMERS; 326 | constexpr int NUM_THREADS = WARPGROUPS * WARPGROUP_SIZE; 327 | 328 | struct SharedStorage { 329 | alignas(256) bf16 A[BLOCK_M * BLOCK_K * STAGES]; 330 | alignas(256) bf16 B[BLOCK_K * BLOCK_N * STAGES]; 331 | alignas(256) bf16 C[BLOCK_M * BLOCK_N] __attribute__((aligned(256))); 332 | }; 333 | 334 | __device__ static inline void stage_next(int& stage, int& phase) { 335 | stage++; 336 | if (stage == STAGES) { 337 | stage = 0; 338 | phase ^= 1; 339 | } 340 | } 341 | 342 | __device__ static inline void stage_advance(int& stage, int& phase, int steps) { 343 | phase = phase ^ (((stage + steps) / STAGES) & 1); 344 | stage = (stage + steps) % STAGES; 345 | } 346 | 347 | __global__ __launch_bounds__(NUM_THREADS) void gemm( 348 | const __grid_constant__ CUtensorMap A, 349 | const __grid_constant__ CUtensorMap B, 350 | const __grid_constant__ CUtensorMap C, 351 | int M, 352 | int N, 353 | int K) { 354 | // Producer buffers for A and B. 355 | extern __shared__ __align__(128) uint8_t dynamic_smem[]; 356 | SharedStorage& smem = *reinterpret_cast(dynamic_smem); 357 | 358 | // Barriers. 359 | __shared__ __align__(8) uint64_t prod[STAGES]; 360 | __shared__ __align__(8) uint64_t cons[STAGES]; 361 | __shared__ __align__(8) uint64_t pingpong[2][NUM_CONSUMERS]; 362 | 363 | int tid = threadIdx.x; 364 | int wgid = tid / WARPGROUP_SIZE; 365 | int wg_tid = tid % WARPGROUP_SIZE; 366 | 367 | // Init barriers. 368 | if (tid == 0) { 369 | for (int i = 0; i < STAGES; i++) { 370 | init_barrier(&prod[i], 0, 1); 371 | init_barrier(&cons[i], 0, 1); 372 | } 373 | for (int i = 0; i < NUM_CONSUMERS; i++) { 374 | init_barrier(&pingpong[0][i], 0, 1); 375 | init_barrier(&pingpong[1][i], 0, 1); 376 | } 377 | } 378 | __syncthreads(); 379 | 380 | auto m_blocks = cdiv(M, BLOCK_M); 381 | auto n_blocks = cdiv(N, BLOCK_N); 382 | auto k_blocks = cdiv(K, BLOCK_K); 383 | 384 | if (wgid == 0) { 385 | // Producer warpgroup. 386 | setmaxnreg_dec<40>(); 387 | 388 | if (wg_tid == 0) { 389 | int phase = 0; 390 | int stage = 0; 391 | for (auto bid = blockIdx.x; bid < m_blocks * n_blocks; bid += gridDim.x) { 392 | auto m = (bid / 2) % m_blocks; 393 | auto n = (bid / 2) / m_blocks * 2 + bid % 2; 394 | 395 | for (int k = 0; k < k_blocks; k++) { 396 | // Wait for consumer. 397 | wait_barrier(&cons[stage], phase); 398 | // Set expect bytes for TMA. 399 | expect_bytes( 400 | &prod[stage], 401 | sizeof(bf16) * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)); 402 | // Load A. 403 | tma_load( 404 | &smem.A[stage * BLOCK_K * BLOCK_M], 405 | &A, 406 | &prod[stage], 407 | k * BLOCK_K, 408 | m * BLOCK_M); 409 | // Load B. 410 | tma_load( 411 | &smem.B[stage * BLOCK_K * BLOCK_N], 412 | &B, 413 | &prod[stage], 414 | k * BLOCK_K, 415 | n * BLOCK_N); 416 | stage_next(stage, phase); 417 | } 418 | } 419 | } 420 | } else { 421 | // Consumer warpgroup. 
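    // Pingpong scheme in brief: the two consumer warpgroups take alternating
    // output tiles (bid strides by gridDim.x * NUM_CONSUMERS), and the two
    // barrier pairs keep their phases interleaved -- pingpong[0][*] admits one
    // warpgroup at a time to the wgmma mainloop, while pingpong[1][*]
    // serializes the epilogue (the shared smem.C staging buffer and its TMA
    // store).  Consumer 1 releases consumer 0's barriers up front and skips
    // ahead k_blocks stages in the A/B ring, so while one warpgroup is on the
    // tensor cores the other is storing its previous tile, roughly:
    //
    //   consumer 0: | MMA tile a | epilogue a | MMA tile c | epilogue c | ...
    //   consumer 1: |   (wait)   | MMA tile b | epilogue b | MMA tile d | ...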
422 | setmaxnreg_inc<232>(); 423 | 424 | int cons_id = wgid - 1; 425 | int stage = 0; 426 | int phase = 0; 427 | int pingpong_phase = 0; 428 | 429 | if (cons_id == 0 && wg_tid == 0) { 430 | for (int i = 0; i < STAGES; i++) { 431 | arrive_barrier(&cons[i], 1); 432 | } 433 | } 434 | 435 | if (cons_id == 1) { 436 | if (wg_tid == 0) { 437 | arrive_barrier(&pingpong[0][1 - cons_id], 1); 438 | arrive_barrier(&pingpong[1][1 - cons_id], 1); 439 | } 440 | stage_advance(stage, phase, k_blocks); 441 | } 442 | 443 | for (auto bid = blockIdx.x + gridDim.x * cons_id; bid < m_blocks * n_blocks; 444 | bid += (gridDim.x * NUM_CONSUMERS)) { 445 | auto m = (bid / 2) % m_blocks; 446 | auto n = (bid / 2) / m_blocks * 2 + bid % 2; 447 | 448 | float acc[WG_M / INST_M][8][8]; 449 | memset(acc, 0, sizeof(acc)); 450 | fence_memory(acc); 451 | 452 | // Mainloop, peeled to fill wgmma_commit_group pipeline. 453 | wait_barrier(&pingpong[0][cons_id], pingpong_phase); 454 | auto prev_stage = stage; 455 | { 456 | // Wait for producer. 457 | wait_barrier(&prod[stage], phase); 458 | wgmma_fence(); 459 | 460 | #pragma unroll 461 | for (int mma_m = 0; mma_m < WG_M / INST_M; mma_m++) { 462 | #pragma unroll 463 | for (int mma_k = 0; mma_k < BLOCK_K; mma_k += 16) { 464 | wgmma128<1, 1, 1, 0, 0>( 465 | acc[mma_m], 466 | &smem 467 | .A[stage * BLOCK_M * BLOCK_K + mma_m * INST_M * BLOCK_K + 468 | mma_k], 469 | &smem.B[stage * BLOCK_N * BLOCK_K + mma_k]); 470 | } 471 | } 472 | 473 | wgmma_commit_group(); 474 | stage_next(stage, phase); 475 | } 476 | // Mainloop. 477 | for (int k = 1; k < k_blocks; k++) { 478 | // Wait for producer. 479 | wait_barrier(&prod[stage], phase); 480 | wgmma_fence(); 481 | 482 | #pragma unroll 483 | for (int mma_m = 0; mma_m < WG_M / INST_M; mma_m++) { 484 | #pragma unroll 485 | for (int mma_k = 0; mma_k < BLOCK_K; mma_k += 16) { 486 | wgmma128<1, 1, 1, 0, 0>( 487 | acc[mma_m], 488 | &smem 489 | .A[stage * BLOCK_M * BLOCK_K + mma_m * INST_M * BLOCK_K + 490 | mma_k], 491 | &smem.B[stage * BLOCK_N * BLOCK_K + mma_k]); 492 | } 493 | } 494 | 495 | wgmma_commit_group(); 496 | wgmma_wait_group<1>(); 497 | 498 | // Arrive at consumer. 499 | if (wg_tid == 0) { 500 | arrive_barrier(&cons[prev_stage], 1); 501 | } 502 | prev_stage = stage; 503 | stage_next(stage, phase); 504 | } 505 | wgmma_wait_group<0>(); 506 | if (wg_tid == 0) { 507 | arrive_barrier(&cons[prev_stage], 1); 508 | } 509 | 510 | // Next k blocks handle by other pingpong consumer. 511 | stage_advance(stage, phase, k_blocks); 512 | 513 | if (wg_tid == 0) { 514 | arrive_barrier(&pingpong[0][1 - cons_id], 1); 515 | } 516 | 517 | // Write back to gmem. 518 | wait_barrier(&pingpong[1][cons_id], pingpong_phase); 519 | 520 | // stmatrix layout is a little mad, but matches the layout of the 8x8 521 | // matrices in 522 | // https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d 523 | // The key to remember here is that the data is already laid out in 524 | // registers the way stmatrix expects it. Your job is just to set the 525 | // address right in each thread; the addresses aren't really related to 526 | // the data in any meaningul way. 
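      // Worked example of the address math below (addresses are bf16 element
      // offsets into smem.C, viewed M-major with leading dimension
      // INST_M = 64):
      //   warp 0, lane 0 : row 0, col 0  -> addr   0
      //   warp 0, lane 9 : row 8, col 1  -> addr  72
      //   warp 0, lane 17: row 0, col 9  -> addr 576
      //   warp 0, lane 31: row 8, col 15 -> addr 968
      // Without the XOR below, lanes 0-7 would all land 128 bytes apart and
      // start in the same bank; folding in ((lane + smem_bias) & 7) << 3
      // spreads their 16-byte accesses across all 32 banks.  smem_bias picks
      // up bit 7 of smem.C's absolute address, presumably to keep this
      // pattern in phase with the 128B swizzle the C TMA descriptor applies
      // when the tile is stored back to gmem.  stmatrix.py in this repo
      // enumerates this mapping (and its bank assignments) for all 128
      // threads.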
527 | auto warp = wg_tid / 32; 528 | auto lane = wg_tid % 32; 529 | auto base_x1_row = warp * 16; 530 | auto base_x4_row = base_x1_row + (lane / 8 % 2) * 8; 531 | auto base_x4_col = lane % 8 + lane / 16 * 8; 532 | auto base_addr = base_x4_row + INST_M * base_x4_col; 533 | 534 | #pragma unroll 535 | for (int mma_m = 0; mma_m < WG_M / INST_M; mma_m++) { 536 | #pragma unroll 537 | for (int inst_n = 0; inst_n < BLOCK_N / 16; inst_n++) { 538 | auto mma_row = mma_m * INST_M * BLOCK_N; 539 | auto regs_col = inst_n * 16 * INST_M; 540 | auto addr = base_addr + mma_row + regs_col; 541 | auto smem_bias = 542 | (static_cast(__cvta_generic_to_shared(smem.C)) & 543 | 0x80) >> 544 | 7; 545 | auto lane_swizzle = ((lane + smem_bias) & 0x7) << 3; 546 | addr = addr ^ lane_swizzle; 547 | bf16 acc_bf16[8]; 548 | for (int i = 0; i < 8; i++) { 549 | acc_bf16[i] = f2bf(acc[mma_m][inst_n][i]); 550 | } 551 | stmatrix(&smem.C[addr], acc_bf16); 552 | } 553 | fence_async_proxy(); 554 | if (wg_tid == 0) { 555 | tma_store( 556 | &C, 557 | &smem.C[mma_m * INST_M * BLOCK_N], 558 | m * BLOCK_M + mma_m * INST_M, 559 | n * BLOCK_N); 560 | tma_commit_group(); 561 | } 562 | } 563 | 564 | tma_wait_group<0>(); 565 | if (wg_tid == 0) { 566 | arrive_barrier(&pingpong[1][1 - cons_id], 1); 567 | } 568 | pingpong_phase ^= 1; 569 | } 570 | } 571 | } 572 | 573 | } // namespace 574 | 575 | void run_pingpong(bf16* A, bf16* B, bf16* C, int M, int N, int K) { 576 | // Compute necessary shared memory for buffers. 577 | size_t smem_size = sizeof(SharedStorage); 578 | check(cudaFuncSetAttribute( 579 | gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); 580 | 581 | // Set up TMA descriptors 582 | auto descA = create_tma_desc(A, M, K, BLOCK_M, BLOCK_K); 583 | auto descB = create_tma_desc(B, N, K, BLOCK_N, BLOCK_K); 584 | auto descC = create_tma_desc(C, N, M, BLOCK_N, INST_M); 585 | 586 | // Launch kernel! 587 | gemm<<>>(descA, descB, descC, M, N, K); 588 | } 589 | 590 | void run_pingpong(void* A, void* B, void* C, int M, int N, int K) { 591 | run_pingpong((bf16*)A, (bf16*)B, (bf16*)C, M, N, K); 592 | } 593 | --------------------------------------------------------------------------------