├── lesson_1
│   └── readme.md
├── lesson_5
│   └── triton_contiguous_group_gemm
│       ├── readme.md
│       ├── benchmark_grouped_gemm.py
│       └── grouped_gemm.py
├── lesson_2_first_kernel
│   └── readme.md
├── readme.md
└── kernels
    ├── vector_addition_tutorial.py
    └── matmul_outer_k.py

/lesson_1/readme.md:
--------------------------------------------------------------------------------
1 | Let's establish some core fundamentals for GPU kernel programming:
2 | what Triton is, and the basics of the GPU memory hierarchy.
3 | 
4 | A starting video (todo - probably need to update this):
5 | https://www.youtube.com/watch?v=s1ILGG0TyYM
6 | 
--------------------------------------------------------------------------------
/lesson_5/triton_contiguous_group_gemm/readme.md:
--------------------------------------------------------------------------------
1 | Group General Matrix-Multiply (GEMM) is a key operation in Mixture-of-Experts (MoE) models.
2 | 
3 | The Group GEMM kernel batches the execution of multiple independent GEMM problems into a single fused kernel. Each group can have a different shape.
4 | 
5 | This tutorial will walk through how to write and optimize a Group GEMM kernel in Triton.
6 | 
7 | (TODO - make a video)
--------------------------------------------------------------------------------
/lesson_2_first_kernel/readme.md:
--------------------------------------------------------------------------------
1 | Let's get comfortable with some kernel basics and write our first kernel.
2 | 
3 | We'll solve the problem of adding two arrays (Python lists) together, element by element.
4 | 
5 | First up - we need to move mentally from sequential programming (i.e. looping) to leveraging the core advantage of GPUs: massive parallelization.
6 | 
7 | Thus:
8 | making the shift to parallel programming:
9 | https://youtu.be/MEZ7XhzTLEg
10 | 
11 | writing your first Triton kernel:
12 | https://youtu.be/8P0M-DXr774
13 | 
14 | verifying your kernel (always important - we need numerical fidelity even though we have parallelized the work):
15 | https://youtu.be/kEGW0SemWWw
16 | 
17 | and finally, let's compare our Triton kernel vs PyTorch in terms of speed:
18 | https://youtu.be/Nh5QIkGuExQ
19 | 
20 | A self-contained kernel, wrapper, and test-case driver with extensive code comments is here (TODO - make a video walkthrough):
21 | [Vector addition kernel](https://github.com/gpu-mode/triton-tutorials/blob/main/kernels/vector_addition_tutorial.py)
22 | 
23 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | Work in Progress!
2 | 
3 | General summary:
4 | work through progressive kernels to learn Triton with:
5 | a - [vector add](https://github.com/gpu-mode/triton-tutorials/blob/main/kernels/vector_addition_tutorial.py)
6 | b - [simple matmul](https://github.com/gpu-mode/triton-tutorials/blob/main/kernels/matmul_outer_k.py) (outer k loop and tiled)
7 | c - fused softmax (in PyTorch, then in Triton)
8 | d - softmax with backward
9 | e - flash attention 2
10 | f - group gemm with backward (bf16, then fp8)
11 | g - MoE permute / unpermute kernels
12 | 
13 | ## Getting Started:
14 | Lessons are arranged in order, starting with lesson_1, etc.
15 | 
16 | Some of the above are also examples in the core Triton repo...what's the difference?
17 | 
18 | From an admittedly opinionated view, the examples in the core repo tend to put too much in, too soon, obscuring the core lessons one should take from them.
19 | 
20 | We'll stick with simpler versions that let the core tenets of kernel programming shine through - bypassing autotuning, certain cache optimizations, and other aspects - and pull those in during later lessons.
21 | In addition, the goal is to have almost every line in each kernel annotated with clear, concise comments, making it easy to map what the kernel code is doing.
22 | 
23 | External contributions are very welcome - this is meant to be an evolving, iteratively improving tutorial series.
24 | 
25 | 
26 | 
27 | 
--------------------------------------------------------------------------------
/lesson_5/triton_contiguous_group_gemm/benchmark_grouped_gemm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from typing import Tuple
8 | 
9 | import torch
10 | 
11 | # note: triton itself is not needed in this benchmark file; the kernel comes in via grouped_gemm
12 | from grouped_gemm import grouped_gemm_persistent
13 | 
14 | 
15 | def construct_grouped_gemm(
16 |     M: int,
17 |     K: int,
18 |     N: int,
19 |     num_experts: int,
20 |     group_size_m: int = 128,
21 |     device: str = "cuda",
22 |     dtype: torch.dtype = torch.bfloat16,
23 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
24 |     """
25 |     Create test data with proper block alignment.
26 | 
27 |     Args:
28 |         M: Number of tokens per expert (the total row count is M * num_experts,
29 |            padded up to a multiple of group_size_m)
30 |         K: Hidden dimension (K)
31 |         N: Output dimension (N)
32 |         num_experts: Number of experts
33 |         group_size_m: Size of expert groups
34 |         device: Device to create tensors on
35 |         dtype: Data type for inputs and weights
36 | 
37 |     Returns:
38 |         Tuple of (inputs, expert_weights, expert_indices)
39 |     """
40 |     # Calculate total number of tokens
41 |     M_total = M * num_experts
42 | 
43 |     # Ensure M_total is a multiple of group_size_m
44 |     padded_M = ((M_total + group_size_m - 1) // group_size_m) * group_size_m
45 |     padding_needed = padded_M - M_total
46 | 
47 |     if padding_needed > 0:
48 |         print(f"Padding input from {M_total} to {padded_M} to ensure group alignment")
49 |         M_total = padded_M
50 | 
51 |     # Create inputs
52 |     inputs = torch.randn((M_total, K), dtype=dtype, device=device)
53 | 
54 |     # Create expert weights
55 |     expert_weights = torch.randn(
56 |         (num_experts, N, K), dtype=dtype, device=device
57 |     )
58 | 
59 |     # Create expert indices with proper group alignment
60 |     expert_indices = torch.zeros(M_total, dtype=torch.int32, device=device)
61 | 
62 |     # Assign experts in contiguous blocks of group_size_m
63 |     num_groups = M_total // group_size_m
64 | 
65 |     for group_idx in range(num_groups):
66 |         start_idx = group_idx * group_size_m
67 |         end_idx = start_idx + group_size_m
68 | 
69 |         # Assign this entire group to one expert
70 |         expert_idx = group_idx % num_experts
71 |         expert_indices[start_idx:end_idx] = expert_idx
72 | 
73 |     return inputs, expert_weights, expert_indices
74 | 
75 | 
76 | def pytorch_reference_gemm(
77 |     inputs: torch.Tensor,
78 |     expert_weights: torch.Tensor,
79 |     expert_indices: torch.Tensor,
80 |     group_size_m: int = 128,
81 | ) -> torch.Tensor:
82 |     """
83 |     Reference implementation using PyTorch for verification.
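    Assumes every contiguous block of group_size_m rows is routed to a single
    expert (the layout produced by construct_grouped_gemm above), so the expert
    id for each block can be read from its first row.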
84 |     """
85 |     M_total, K = inputs.shape
86 |     num_experts, N, _ = expert_weights.shape
87 | 
88 |     output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
89 | 
90 |     # Process each group
91 |     for i in range(0, M_total, group_size_m):
92 |         end_idx = min(i + group_size_m, M_total)
93 | 
94 |         # Get expert index for this group
95 |         expert_idx = expert_indices[i].item()
96 | 
97 |         # Get expert weights
98 |         expert_weight = expert_weights[expert_idx]
99 | 
100 |         # Compute output for this group
101 |         output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.t())
102 | 
103 |     return output
104 | 
105 | # PyTorch Benchmarking
106 | import torch.utils.benchmark as benchmark
107 | 
108 | 
109 | def triton_gemm_func(a, b, expert_indices):
110 |     return grouped_gemm_persistent(a, b, expert_indices)
111 | 
112 | def gemm_func_torch(a, b, expert_indices):
113 |     return pytorch_reference_gemm(a, b, expert_indices)
114 | 
115 | num_threads = torch.get_num_threads()
116 | print(f'Benchmarking on {num_threads} threads')
117 | results = []
118 | 
119 | for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
120 | 
121 |     a, b, expert_indices = construct_grouped_gemm(m, k, n, num_groups)
122 | 
123 |     label = 'BF16 Grouped GEMM Performance'
124 |     sub_label = f'num_groups: {num_groups}, m: {m}, n: {n}, k: {k}'
125 | 
126 |     results.append(benchmark.Timer(
127 |         stmt='triton_gemm_func(a, b, expert_indices)',
128 |         setup='from __main__ import triton_gemm_func',
129 |         globals={'a': a, 'b': b, 'expert_indices': expert_indices},
130 |         num_threads=num_threads,
131 |         label=label,
132 |         sub_label=sub_label,
133 |         description='Triton Group GEMM').blocked_autorange(min_run_time=1))
134 | 
135 |     results.append(benchmark.Timer(
136 |         stmt='gemm_func_torch(a, b, expert_indices)',
137 |         setup='from __main__ import gemm_func_torch',
138 |         globals={'a': a, 'b': b, 'expert_indices': expert_indices},
139 |         num_threads=num_threads,
140 |         label=label,
141 |         sub_label=sub_label,
142 |         description='PyTorch Reference Group GEMM').blocked_autorange(min_run_time=1))
143 | 
144 | 
145 | compare = benchmark.Compare(results)
146 | compare.print()
--------------------------------------------------------------------------------
/kernels/vector_addition_tutorial.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | 
4 | import triton.language as tl  # gives access to the Triton language components
5 | # see https://github.com/triton-lang/triton/blob/main/python/triton/language/__init__.py for all tl components
6 | 
7 | # we need to create a minimum of two components: the kernel and the kernel wrapper that both launches the kernel and interfaces with PyTorch
8 | # the kernel is the actual Triton program that we want to run on the GPU
9 | 
10 | # @triton.jit is a decorator that designates a function as a Triton kernel
11 | @triton.jit
12 | def vector_addition_kernel(
13 | 
14 |     # the first three arguments are pointers to the two input tensors and the output tensor
15 |     # we will see how to pass these arguments in the wrapper function below
16 |     x_ptr,
17 |     y_ptr,
18 |     output_ptr,
19 | 
20 |     # the total count of the inputs we are summing over
21 |     n_elements,
22 | 
23 |     # the strides of the input and output tensors
24 |     # tl.constexpr is a decorator that designates a variable as a compile-time constant
25 |     # this allows Triton to perform compile-time optimizations b/c it knows
26 |     # (and you guarantee via this decorator) the exact value of the variable at compile time
27 |     # allowing the compiler to optimize for that value
28 | 
29 |     x_stride: tl.constexpr,
30 |     y_stride: tl.constexpr,
31 |     output_stride: tl.constexpr,
32 | 
33 |     # the block size for the grid and the multiplier in the mapping of program ID to offsets
34 |     BLOCK_SIZE: tl.constexpr,
35 | 
36 | ):
37 |     # the program ID is the unique identifier for each threadblock (note block!) in the grid
38 |     # each PID will handle 'BLOCK_SIZE' elements
39 |     pid = tl.program_id(axis=0)  # the grid is 1D, so axis=0 means we move along the x-axis, or columns in this case
40 |     block_start = pid * BLOCK_SIZE  # example: assuming pid==2, BLOCK_SIZE==128, then our current program will start at block_start==256
41 | 
42 |     # tl.arange(0, BLOCK_SIZE) will give us a range of 0 to 127, added to the block_start,
43 |     # so offsets will be 256 to 383
44 |     # note: tl.arange requires a power-of-2 length, so we need to use the mask to
45 |     # handle cases where the input is not a power-of-2 size (e.g. 257 elements)
46 |     offsets = block_start + tl.arange(0, BLOCK_SIZE)
47 | 
48 |     # the offsets for the input tensors
49 |     x_offsets = offsets * x_stride  # stride is 1, so x_offsets will be 256 to 383
50 |     y_offsets = offsets * y_stride  # stride is 1, so y_offsets will be 256 to 383
51 | 
52 |     # the mask for elements outside the range of [0, n_elements)
53 |     mask = offsets < n_elements  # we need to handle cases where the input is not a power-of-2 size (e.g. 257 elements)
54 | 
55 |     # load the input tensors
56 |     x = tl.load(x_ptr + x_offsets, mask=mask)
57 |     y = tl.load(y_ptr + y_offsets, mask=mask)
58 | 
59 |     # compute the output
60 |     output = x + y
61 | 
62 |     # store the output
63 |     tl.store(output_ptr + x_offsets, output, mask=mask)  # output_stride is also 1, so we can reuse x_offsets
64 | 
65 | 
66 | # the kernel wrapper is a regular Python function that interfaces with PyTorch
67 | # it is responsible for verifying the inputs, creating output buffers for results, and launching the kernel
68 | # with the appropriate grid size and information such as input strides
69 | def vector_addition(x, y):
70 | 
71 |     # let's first make sure that the input tensors are on the GPU
72 |     assert x.is_cuda and y.is_cuda, "Input tensors must be on GPU!"
73 |     # we also need to make sure that x and y are the same size
74 |     assert x.numel() == y.numel(), "Input tensors must be the same size!"
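    # the launch below hardcodes a stride of 1 for every tensor, so as an extra
    # safety check (a suggested addition, not in the original file) we can also
    # require contiguous inputs:
    assert x.is_contiguous() and y.is_contiguous(), "Input tensors must be contiguous!"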
75 | 76 | 77 | # the output shape is the same as the input shape 78 | output = torch.empty_like(x) 79 | 80 | # the grid is the number of program blocks we need to launch to handle the entire input 81 | # cdiv just means ceiling division, so it returns the smallest integer >= x / y ensuring we always 82 | # launch enough program blocks to handle the entire input 83 | grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) 84 | 85 | # launch the kernel, note the grid is a lambda function that takes in the meta-arguments 86 | vector_addition_kernel[grid]( 87 | # pass the pointers to the input and output tensors 88 | x_ptr=x, 89 | y_ptr=y, 90 | output_ptr=output, 91 | 92 | # pass the total count of the input we are summing 93 | n_elements=x.numel(), 94 | 95 | # pass the strides (how many data element size jumps to get to the next element) of the input and output tensors 96 | # in this case, everything is contiguous, so we pass 1 97 | x_stride=1, 98 | y_stride=1, 99 | output_stride=1, 100 | 101 | # pass the block size for the grid and the mapping of program ID to offsets 102 | BLOCK_SIZE=128, 103 | 104 | ) 105 | 106 | # return the output tensor 107 | return output 108 | 109 | 110 | if __name__ == "__main__": 111 | # two tests - one for power of 2 size and one for non power of 2 size 112 | # the non power of 2 size is to test the mask functionality 113 | # create a random, power of 2 size, input tensor 114 | 115 | # Test 1: power of 2 size (1024!) 116 | 117 | x = torch.randn(1024, device='cuda') 118 | y = torch.randn(1024, device='cuda') 119 | 120 | # we can then use the kernel wrapper to perform a vector addition 121 | 122 | output = vector_addition(x, y) 123 | 124 | # verify the result with PyTorch reference implementation 125 | output_ref = x + y 126 | assert torch.allclose(output, output_ref) 127 | print("Success with power of 2 size (1024!") 128 | 129 | print(f"{output=}") 130 | 131 | # Test 2: non power of 2 size (257!) 132 | # create a non power of 2 input tensor 133 | x = torch.randn(257, device='cuda') 134 | y = torch.randn(257, device='cuda') 135 | 136 | # we can now use the kernel wrapper to perform a vector addition 137 | 138 | output_np2 = vector_addition(x, y) 139 | 140 | # verify the result with PyTorch reference implementation 141 | output_ref_np2 = x + y 142 | assert torch.allclose(output_np2, output_ref_np2) 143 | print("Success with non power of 2 size (num_elems = 257!)") 144 | 145 | print(f"{output_np2[0:5]=}") 146 | -------------------------------------------------------------------------------- /kernels/matmul_outer_k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def outerk_matmul_kernel( 7 | # input pointers 8 | a_ptr, 9 | b_ptr, 10 | # output ptr 11 | c_ptr, 12 | # matrix dimensions 13 | M, N, K, 14 | # the stride variables represent how much to increase the ptr by when moving by 1 15 | # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr 16 | # by to get the element one row down (A has M rows) 17 | stride_am, stride_ak, 18 | stride_bk, stride_bn, # b is transposed, so k is now row dimension 19 | stride_cm, stride_cn, 20 | # meta-parameters 21 | BLOCK_SIZE_M: tl.constexpr, 22 | BLOCK_SIZE_N: tl.constexpr, 23 | BLOCK_SIZE_K: tl.constexpr, 24 | ): 25 | """ 26 | Compute the matrix multiplication C = A @ B 27 | 28 | A is of shape (M, K) 29 | B is of shape (K, N) # note that B is transposed to achieve this 30 | C is of shape (M, N) 31 | """ 32 | 33 | # map program ids to blocks in the matrices 34 | pid_m = tl.program_id(axis=0) # row for A and C 35 | pid_n = tl.program_id(axis=1) # col for B and C 36 | 37 | # calculate our starting position for this block 38 | start_m = pid_m * BLOCK_SIZE_M 39 | start_n = pid_n * BLOCK_SIZE_N 40 | 41 | # create offsets for accessing elements within the block 42 | offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) 43 | offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) 44 | 45 | # masking to handle non-multiple of block size cases 46 | mask_m = offsets_m < M 47 | mask_n = offsets_n < N 48 | 49 | # init our accumulator 50 | # it is the size of our C output block (BLOCK_SIZE_M, BLOCK_SIZE_N) 51 | acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 52 | 53 | # now we iterate over the K dimension in blocks of BLOCK_SIZE_K 54 | for k in range(0, K, BLOCK_SIZE_K): 55 | # create offsets for the k dimension 56 | offsets_k = k + tl.arange(0, BLOCK_SIZE_K) 57 | 58 | # mask for K dimension 59 | mask_k = offsets_k < K 60 | 61 | # compute memory addresses for A and B blocks 62 | # note that we are using a column vector of M or K dimension offsets 63 | # and a row vector of K or N dimension offsets to create a 2D grid of all offsets for ptrs 64 | a_ptrs = a_ptr + (offsets_m[:, None] * stride_am + offsets_k[None, :] * stride_ak) 65 | b_ptrs = b_ptr + (offsets_k[:, None] * stride_bk + offsets_n[None, :] * stride_bn) 66 | 67 | # load A and B blocks using our K mask. 
We set other=0.0 to fill remaining spots 68 | a = tl.load(a_ptrs, mask= mask_m[:, None] & mask_k[None, :], other=0.0) 69 | b = tl.load(b_ptrs, mask = mask_k[:, None] & mask_n[None, :], other=0.0) 70 | 71 | # perform our current k block matmul and add it to the accumulator for C 72 | acc += tl.dot(a,b) 73 | 74 | # store result back to global memory 75 | c_ptrs = c_ptr + (offsets_m[:, None] * stride_cm + offsets_n[None, :] * stride_cn) 76 | tl.store(c_ptrs, acc, mask= mask_m[:,None] & mask_n[None, :]) 77 | 78 | 79 | 80 | # our triton kernel wrapper/interface function 81 | def triton_outer_k_matmul(a, b): 82 | """ 83 | Compute matmul of C = A @ B using Triton block tiled kernel 84 | 85 | Inputs: 86 | a: torch.tensor of shape (M, K) 87 | b: torch.tensor of shape (K, N) 88 | 89 | Returns: 90 | C: shape (M, N) 91 | 92 | """ 93 | 94 | # verify our inputs 95 | assert a.is_cuda and b.is_cuda, "a and b must be on GPU" 96 | assert a.shape[1] == b.shape[0], "mismatch between inner dimensions" 97 | 98 | M = a.shape[0] 99 | K = a.shape[1] 100 | N = b.shape[1] 101 | 102 | # allocate our output C tensor 103 | c = torch.empty((M, N), device = a.device, dtype = torch.float32) 104 | 105 | # calculate the strides for our kernel 106 | stride_am, stride_ak = a.stride(0), a.stride(1) 107 | stride_bk, stride_bn = b.stride(0), b.stride(1) 108 | stride_cm, stride_cn = c.stride(0), c.stride(1) 109 | 110 | # define block sizes we will use to process the matmul 111 | # Note - we will tune this later with autotune, but for now we can hand tune 112 | BLOCK_SIZE_M = 64 113 | BLOCK_SIZE_N = 64 114 | BLOCK_SIZE_K = 64 115 | 116 | # calculate our grid 117 | grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N)) 118 | 119 | # launch our kernel 120 | outerk_matmul_kernel[grid] ( 121 | a, b, c, 122 | M, N, K, 123 | stride_am, stride_ak, 124 | stride_bk, stride_bn, 125 | stride_cm, stride_cn, 126 | BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 127 | ) 128 | 129 | return c 130 | 131 | 132 | # Example usage and performance comparison 133 | def benchmark_matmul(): 134 | # Create random matrices 135 | M, N, K = 8192, 8192, 4096 136 | a = torch.randn((M, K), device='cuda', dtype=torch.float32) 137 | b = torch.randn((K, N), device='cuda', dtype=torch.float32) 138 | 139 | 140 | 141 | # Verify correctness 142 | torch_output = torch.matmul(a, b) 143 | triton_output = triton_outer_k_matmul(a, b) 144 | assert torch.allclose(torch_output, triton_output, rtol=1e-2, atol=1e-1), \ 145 | "Triton and PyTorch matmul results don't match!" 
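    # (suggested addition, not in the original) report the largest absolute error
    # alongside the tolerance-based check above
    max_err = (torch_output - triton_output).abs().max().item()
    print(f"Max abs error vs torch.matmul: {max_err:.3e}")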
146 | 147 | 148 | 149 | # Benchmark PyTorch matmul 150 | torch.cuda.synchronize() 151 | start = torch.cuda.Event(enable_timing=True) 152 | end = torch.cuda.Event(enable_timing=True) 153 | 154 | torch.matmul(a, b) # warmup 155 | start.record() 156 | for _ in range(10): 157 | torch.matmul(a, b) 158 | end.record() 159 | torch.cuda.synchronize() 160 | pytorch_time = start.elapsed_time(end) / 10 161 | 162 | # Benchmark Triton matmul 163 | triton_outer_k_matmul(a, b) # warmup 164 | torch.cuda.synchronize() 165 | start.record() 166 | for _ in range(10): 167 | triton_outer_k_matmul(a, b) 168 | end.record() 169 | torch.cuda.synchronize() 170 | triton_time = start.elapsed_time(end) / 10 171 | 172 | print(f"PyTorch matmul time: {pytorch_time:.2f} ms") 173 | print(f"Triton matmul time: {triton_time:.2f} ms") 174 | print(f"Speedup: {pytorch_time / triton_time:.2f}x") 175 | 176 | 177 | if __name__ == "__main__": 178 | benchmark_matmul() 179 | -------------------------------------------------------------------------------- /lesson_5/triton_contiguous_group_gemm/grouped_gemm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Credit: Less Wright (Meta) and Adnan Hoque (IBM) 8 | 9 | import triton 10 | import triton.language as tl 11 | import torch 12 | 13 | @triton.jit 14 | def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m): 15 | group_id = tile_id // num_pid_in_group 16 | first_pid_m = group_id * super_group_m 17 | group_size_m = min(num_pid_m - first_pid_m, super_group_m) 18 | pid_m = first_pid_m + (tile_id % group_size_m) 19 | pid_n = (tile_id % num_pid_in_group) // group_size_m 20 | return pid_m, pid_n 21 | 22 | 23 | @triton.jit 24 | def _kernel_grouped_gemm_persistent_bf16( 25 | # Pointers to matrices 26 | a_ptr, 27 | b_ptr, 28 | c_ptr, 29 | # Pointer to indices array 30 | indices_ptr, 31 | # Matrix dimensions 32 | M_TOTAL: tl.constexpr, # Total M dimension (sum of all groups) 33 | N: tl.constexpr, # N dimension 34 | K: tl.constexpr, # K dimension 35 | # Number of experts 36 | NUM_EXPERTS: tl.constexpr, 37 | # Tiling parameters 38 | BLOCK_SIZE_M: tl.constexpr, 39 | BLOCK_SIZE_N: tl.constexpr, 40 | BLOCK_SIZE_K: tl.constexpr, 41 | NUM_SMS: tl.constexpr, 42 | # NUM_CONSUMER_GROUPS: tl.constexpr, 43 | # Group size (for aligned loads) 44 | GROUP_SIZE_M: tl.constexpr = 128, 45 | SUPER_GROUP_M: tl.constexpr = 32, # 32 works best 46 | ): 47 | """ 48 | Contiguous Grouped GEMM kernel forward. 49 | IMPORTANT: Assumes GROUP_SIZE_M is a multiple of BLOCK_SIZE_M or vice versa, 50 | and all inputs are pre-aligned to these block boundaries. 
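    Persistent scheduling: the grid is sized to NUM_SMS programs, and each program
    walks over output tiles in strides of NUM_SMS, so a single program may compute
    several (BLOCK_SIZE_M x BLOCK_SIZE_N) tiles of C.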
51 | """ 52 | 53 | c_type = c_ptr.dtype.element_ty 54 | 55 | start_pid = tl.program_id(axis=0) 56 | num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M) 57 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 58 | k_tiles = tl.cdiv(K, BLOCK_SIZE_K) 59 | num_tiles = num_pid_m * num_pid_n 60 | tile_id_c = start_pid - NUM_SMS 61 | num_pid_in_group = SUPER_GROUP_M * num_pid_n 62 | 63 | for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): 64 | 65 | tile_m_idx, tile_n_idx = _compute_pid(tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M) 66 | 67 | # starting indices for this tile 68 | m_start = tile_m_idx * BLOCK_SIZE_M 69 | n_start = tile_n_idx * BLOCK_SIZE_N 70 | 71 | # Only process if in bounds 72 | if m_start < M_TOTAL: 73 | 74 | offs_m = m_start + tl.arange(0, BLOCK_SIZE_M) 75 | offs_n = n_start + tl.arange(0, BLOCK_SIZE_N) 76 | 77 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 78 | for ki in range(k_tiles): 79 | 80 | # Offsets for K dim 81 | offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 82 | 83 | # Create masks for bounds checking 84 | mask_m = offs_m < M_TOTAL 85 | mask_n = offs_n < N 86 | mask_k = offs_k < K 87 | 88 | # masks for A and B 89 | mask_a = mask_m[:, None] & mask_k[None, :] 90 | mask_b = mask_n[:, None] & mask_k[None, :] 91 | 92 | # Determine the expert group index and load expert ID 93 | group_idx = m_start // GROUP_SIZE_M 94 | expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M) 95 | 96 | # Load inputs (A) with bounds checking 97 | a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] 98 | a = tl.load(a_ptrs, mask=mask_a, other=0.0) 99 | 100 | # Load expert weights (B) for the expert assigned to this block 101 | b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :] 102 | b = tl.load(b_ptrs, mask=mask_b, other=0.0) 103 | 104 | # Accumulate matrix multiplication for this K tile 105 | accumulator += tl.dot(a, b.T) 106 | 107 | tile_id_c += NUM_SMS 108 | tile_m_idx, tile_n_idx = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M) 109 | 110 | offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 111 | offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 112 | 113 | # Create masks for bounds checking 114 | mask_m = offs_m < M_TOTAL 115 | mask_n = offs_n < N 116 | mask_c = mask_m[:, None] & mask_n[None, :] 117 | 118 | c = accumulator.to(tl.float32) 119 | 120 | # Store output (C) with bounds checking 121 | c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] 122 | tl.store(c_ptrs, c.to(c_type), mask=mask_c) 123 | 124 | # =============== Wrapper for Grouped GEMM ================= 125 | def _grouped_gemm_persistent( 126 | inputs: torch.Tensor, # [M_total, K] 127 | expert_weights: torch.Tensor, # [num_experts, N, K] 128 | expert_indices: torch.Tensor, # [M_total] 129 | group_size_m: int = 128, 130 | ) -> torch.Tensor: 131 | """ 132 | contiguous grouped GEMM forward pass for MoE. 133 | All tokens mapped to the same expert must be in contiguous blocks of size group_size_m. 
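    M_total must be a multiple of group_size_m (this is asserted below).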
134 | 
135 |     Args:
136 |         inputs: Input tensor of shape [M_total, K]
137 |         expert_weights: Expert weight tensor of shape [num_experts, N, K]
138 |         expert_indices: Indices tensor of shape [M_total] mapping each token to its expert
139 |         group_size_m: Size of contiguous token blocks for each expert (default: 128)
140 |             (note: all tokens within a group_size_m block must share one expert;
141 |             the expert id for a block is read from its first token)
142 |     Returns:
143 |         Output tensor of shape [M_total, N]
144 |     """
145 |     # Validate inputs
146 |     assert inputs.is_contiguous(), "Input tensor must be contiguous"
147 |     assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
148 |     assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
149 | 
150 | 
151 |     # Check if inputs are properly aligned
152 |     M_total, K = inputs.shape
153 |     assert (
154 |         M_total % group_size_m == 0
155 |     ), f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
156 | 
157 |     # Convert expert_indices to int32 if needed
158 |     if expert_indices.dtype != torch.int32:
159 |         expert_indices = expert_indices.to(torch.int32)
160 | 
161 |     # Get dimensions
162 |     num_experts, N, K_weights = expert_weights.shape
163 | 
164 |     # Validate dimensions
165 |     assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
166 |     assert (
167 |         expert_indices.shape[0] == M_total
168 |     ), "Expert indices length must match M_total"
169 | 
170 |     # Create output tensor
171 |     output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
172 | 
173 |     # Calculate grid size for the kernel
174 |     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
175 |     grid = (NUM_SMS, 1, 1)
176 | 
177 |     # Set block sizes
178 |     BLOCK_M = 64
179 |     BLOCK_N = 64
180 |     BLOCK_K = 64
181 | 
182 |     # Launch kernel
183 |     _kernel_grouped_gemm_persistent_bf16[grid](
184 |         inputs,
185 |         expert_weights,
186 |         output,
187 |         expert_indices,
188 |         M_TOTAL=M_total,
189 |         N=N,
190 |         K=K,
191 |         NUM_EXPERTS=num_experts,
192 |         GROUP_SIZE_M=group_size_m,
193 |         NUM_SMS=NUM_SMS,
194 |         BLOCK_SIZE_M=BLOCK_M,
195 |         BLOCK_SIZE_N=BLOCK_N,
196 |         BLOCK_SIZE_K=BLOCK_K,
197 |         num_warps=4,
198 |         num_stages=3,
199 |     )
200 |     return output
201 | 
202 | 
203 | def grouped_gemm_persistent(
204 |     inputs: torch.Tensor,  # [M_total, K]
205 |     expert_weights: torch.Tensor,  # [num_experts, N, K]
206 |     expert_indices: torch.Tensor,  # [M_total]
207 | ) -> torch.Tensor:
208 |     return _grouped_gemm_persistent(inputs, expert_weights, expert_indices)
--------------------------------------------------------------------------------
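A minimal usage sketch for grouped_gemm_persistent, assuming a CUDA device and that grouped_gemm.py is importable; the shapes and variable names below are illustrative only, not taken from the repo:

import torch

from grouped_gemm import grouped_gemm_persistent

num_experts, tokens_per_expert, K, N, group_size_m = 4, 512, 256, 128, 128
M_total = num_experts * tokens_per_expert  # 2048, already a multiple of group_size_m

inputs = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
expert_weights = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)

# each contiguous block of group_size_m rows is routed to one expert
expert_indices = (
    torch.arange(M_total, device="cuda") // group_size_m % num_experts
).to(torch.int32)

out = grouped_gemm_persistent(inputs, expert_weights, expert_indices)  # [M_total, N]

# sanity check against a plain PyTorch loop over the groups
ref = torch.empty_like(out)
for start in range(0, M_total, group_size_m):
    expert = int(expert_indices[start])
    ref[start:start + group_size_m] = inputs[start:start + group_size_m] @ expert_weights[expert].t()

print("max abs diff:", (out.float() - ref.float()).abs().max().item())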