├── 01-vector-addition
│   ├── CUDA
│   │   └── native.cu
│   ├── MOJO
│   │   └── native.mojo
│   └── Triton
│       └── native.py
├── 02-matrix-multiplication
│   ├── CUDA
│   │   └── native.cu
│   ├── MOJO
│   │   └── native.mojo
│   └── Triton
│       ├── native.py
│       ├── use_tma.py
│       ├── with_dot_v1.py
│       ├── with_dot_v2.py
│       └── with_dot_v3.py
├── 03-matrix-transpose
│   ├── CUDA
│   │   ├── native.cu
│   │   └── use_shared.cu
│   ├── MOJO
│   │   └── native.mojo
│   └── Triton
│       └── native.py
├── 13-softmax
│   ├── CUDA
│   │   └── native.cu
│   ├── MOJO
│   │   └── native.mojo
│   └── Triton
│       ├── reduce_in_one_block.py
│       └── three_kernel.py
├── README.md
└── utils
    └── createFolder.py

/01-vector-addition/CUDA/native.cu:
--------------------------------------------------------------------------------
1 | #include "solve.h"
2 | #include <cuda_runtime.h>
3 | 
4 | __global__ void vector_add(const float* A, const float* B, float* C, int N) {
5 |     int i = blockDim.x * blockIdx.x + threadIdx.x;
6 |     if (i < N) {
7 |         C[i] = A[i] + B[i];
8 |     }
9 | }
10 | 
11 | // A, B, C are device pointers (i.e. pointers to memory on the GPU)
12 | void solve(const float* A, const float* B, float* C, int N) {
13 |     // 128 or 256 threads per block works well for this memory-bound kernel
14 |     int threadsPerBlock = 128;
15 |     int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
16 | 
17 |     vector_add<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, N);
18 |     cudaDeviceSynchronize();
19 | }
20 | 
--------------------------------------------------------------------------------
/01-vector-addition/MOJO/native.mojo:
--------------------------------------------------------------------------------
1 | from gpu.host import DeviceContext
2 | from gpu.id import block_dim, block_idx, thread_idx
3 | from memory import UnsafePointer
4 | from math import ceildiv
5 | 
6 | @parameter
7 | fn vector_add_kernel(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], N: Int32):
8 |     var idx = block_dim.x * block_idx.x + thread_idx.x
9 |     if Int32(idx) < N:
10 |         C[idx] = A[idx] + B[idx]
11 | 
12 | # A, B, C are device pointers (i.e. pointers to memory on the GPU)
13 | @export
14 | def solve(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], N: Int32):
15 |     var BLOCK_SIZE: Int32 = 128
16 |     var ctx = DeviceContext()
17 |     var num_blocks = ceildiv(N, BLOCK_SIZE)
18 | 
19 |     ctx.enqueue_function[vector_add_kernel](
20 |         A, B, C, N,
21 |         grid_dim = num_blocks,
22 |         block_dim = BLOCK_SIZE
23 |     )
24 | 
25 |     ctx.synchronize()
26 | 
--------------------------------------------------------------------------------
/01-vector-addition/Triton/native.py:
--------------------------------------------------------------------------------
1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
2 | import triton
3 | import triton.language as tl
4 | 
5 | @triton.jit
6 | def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
7 |     a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
8 |     b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
9 |     c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
10 |     # Multiple "programs" process different data chunks. Here we determine which one we are.
11 |     pid = tl.program_id(axis=0)  # We launch a 1D grid, so the axis is 0.
12 |     # This program will handle input starting from a certain offset.
13 |     # For example, if the vector length is 256 and the block size is 64,
14 |     # programs will access elements [0:64), [64:128), [128:192), [192:256) respectively.
15 |     block_start = pid * BLOCK_SIZE
16 |     # Note: `offsets` is a vector of indices; adding it to a base pointer yields a vector of pointers.
17 |     offsets = block_start + tl.arange(0, BLOCK_SIZE)
18 |     # Create a mask to prevent out-of-bounds memory access
19 |     mask = offsets < n_elements
20 |     # Load a and b from DRAM; the mask ensures we avoid reading beyond the input size
21 |     a = tl.load(a_ptr + offsets, mask=mask)
22 |     b = tl.load(b_ptr + offsets, mask=mask)
23 |     c = a + b
24 |     # Write a + b back to DRAM
25 |     tl.store(c_ptr + offsets, c, mask=mask)
26 | 
27 | # a_ptr, b_ptr, c_ptr are raw device pointers
28 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, N: int):
29 |     # 128 or 256 threads per block generally performs best for this kernel
30 |     BLOCK_SIZE = 128
31 |     grid = (triton.cdiv(N, BLOCK_SIZE),)
32 |     vector_add_kernel[grid](a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE)
33 | 
--------------------------------------------------------------------------------
/02-matrix-multiplication/CUDA/native.cu:
--------------------------------------------------------------------------------
1 | #include "solve.h"
2 | #include <cuda_runtime.h>
3 | 
4 | __global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
5 |     int row = blockIdx.y * blockDim.y + threadIdx.y;
6 |     int col = blockIdx.x * blockDim.x + threadIdx.x;
7 | 
8 |     if (row < M && col < K) {
9 |         float sum = 0.0f;
10 |         for (int i = 0; i < N; i++) {
11 |             sum += A[row * N + i] * B[i * K + col];
12 |         }
13 |         C[row * K + col] = sum;
14 |     }
15 | }
16 | 
17 | // A, B, C are device pointers (i.e. pointers to memory on the GPU)
18 | void solve(const float* A, const float* B, float* C, int M, int N, int K) {
19 |     dim3 threadsPerBlock(32, 32);
20 |     dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
21 |                        (M + threadsPerBlock.y - 1) / threadsPerBlock.y);
22 | 
23 |     matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
24 |     cudaDeviceSynchronize();
25 | }
26 | 
--------------------------------------------------------------------------------
/02-matrix-multiplication/MOJO/native.mojo:
--------------------------------------------------------------------------------
1 | from gpu.host import DeviceContext
2 | from gpu.id import block_dim, block_idx, thread_idx
3 | from memory import UnsafePointer
4 | from math import ceildiv
5 | 
6 | @parameter
7 | fn matrix_multiplication_kernel(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], M: Int32, N: Int32, K: Int32):
8 |     var col = block_idx.x * block_dim.x + thread_idx.x
9 |     var row = block_idx.y * block_dim.y + thread_idx.y
10 | 
11 |     if Int32(row) < M and Int32(col) < K:
12 |         var sum: Float32 = 0.0
13 |         for i in range(N):
14 |             sum += A[row * N + i] * B[i * K + col]
15 |         C[row * K + col] = sum
16 | 
17 | # A, B, C are device pointers (i.e. pointers to memory on the GPU)
18 | @export
19 | def solve(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], M: Int32, N: Int32, K: Int32):
20 |     var BLOCK_SIZE: Int32 = 16
21 |     var ctx = DeviceContext()
22 | 
23 |     var grid_dim_x = ceildiv(K, BLOCK_SIZE)
24 |     var grid_dim_y = ceildiv(M, BLOCK_SIZE)
25 | 
26 |     ctx.enqueue_function[matrix_multiplication_kernel](
27 |         A, B, C, M, N, K,
28 |         grid_dim = (grid_dim_x, grid_dim_y),
29 |         block_dim = (BLOCK_SIZE, BLOCK_SIZE)
30 |     )
31 | 
32 |     ctx.synchronize()
33 | 
--------------------------------------------------------------------------------
/02-matrix-multiplication/Triton/native.py:
--------------------------------------------------------------------------------
1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
2 | import triton
3 | import triton.language as tl
4 | 
5 | @triton.jit
6 | def matrix_multiplication_kernel(
7 |     a_ptr, b_ptr, c_ptr,
8 |     M, N, K,
9 |     stride_am, stride_an,
10 |     stride_bn, stride_bk,
11 |     stride_cm, stride_ck,
12 |     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
13 | ):
14 |     a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
15 |     b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
16 |     c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
17 |     pid_k = tl.program_id(axis=0)
18 |     pid_m = tl.program_id(axis=1)
19 | 
20 |     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
21 |     offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
22 | 
23 |     # initialize pointers for a and b
24 |     a_ptrs = a_ptr + offs_m[:, None] * stride_am
25 |     b_ptrs = b_ptr + offs_k[None, :] * stride_bk
26 | 
27 |     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
28 | 
29 |     # accumulate along the N dimension
30 |     for n in range(N):
31 |         # load the current column of a and row of b with a boundary check on M and K
32 |         a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
33 |         b = tl.load(b_ptrs, mask=offs_k[None, :] < K, other=0.0)
34 |         accumulator += a * b
35 |         a_ptrs += stride_an
36 |         b_ptrs += stride_bn
37 | 
38 |     # write result back to c
39 |     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
40 |     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
41 |     offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
42 |     c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
43 |     tl.store(c_ptrs, accumulator, mask=c_mask)
44 | 
45 | # a_ptr, b_ptr, c_ptr are raw device pointers
46 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
47 |     stride_am, stride_an = N, 1
48 |     stride_bn, stride_bk = K, 1
49 |     stride_cm, stride_ck = K, 1
50 | 
51 |     grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
52 |     matrix_multiplication_kernel[grid](
53 |         a_ptr, b_ptr, c_ptr,
54 |         M, N, K,
55 |         stride_am, stride_an,
56 |         stride_bn, stride_bk,
57 |         stride_cm, stride_ck,
58 |         BLOCK_SIZE_M=16,
59 |         BLOCK_SIZE_K=16,
60 |     )
61 | 
--------------------------------------------------------------------------------
/02-matrix-multiplication/Triton/use_tma.py:
--------------------------------------------------------------------------------
1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def matrix_multiplication_kernel( 7 | a_ptr, b_ptr, c_ptr, 8 | M, N, K, 9 | stride_am, stride_an, 10 | stride_bn, stride_bk, 11 | stride_cm, stride_ck, 12 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, 13 | GROUP_SIZE_M: tl.constexpr 14 | ): 15 | a_ptr = a_ptr.to(tl.pointer_type(tl.float32)) 16 | b_ptr = b_ptr.to(tl.pointer_type(tl.float32)) 17 | c_ptr = c_ptr.to(tl.pointer_type(tl.float32)) 18 | 19 | pid = tl.program_id(axis=0) 20 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 21 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 22 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 23 | group_id = pid // num_pid_in_group 24 | first_pid_m = group_id * GROUP_SIZE_M 25 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 26 | pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 27 | pid_k = (pid % num_pid_in_group) // group_size_m 28 | 29 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 30 | offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 31 | offs_n = tl.arange(0, BLOCK_SIZE_N) 32 | 33 | # initialize pointers for a and b 34 | a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an 35 | b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk 36 | 37 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 38 | 39 | # accumulate along the N dimension 40 | for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): 41 | # load current blocks of a and b with boundary check 42 | a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0) 43 | b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0) 44 | # compute matrix multiplication and accumulate 45 | accumulator = tl.dot(a, b, accumulator, input_precision="ieee") 46 | a_ptrs += BLOCK_SIZE_N * stride_an 47 | b_ptrs += BLOCK_SIZE_N * stride_bn 48 | 49 | # write result back to c 50 | c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck 51 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 52 | offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 53 | c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K) 54 | tl.store(c_ptrs, accumulator, mask=c_mask) 55 | 56 | @triton.jit 57 | def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr, # 58 | M, N, K, # 59 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, 60 | BLOCK_SIZE_K: tl.constexpr, # 61 | ): 62 | a_ptr = a_ptr.to(tl.pointer_type(tl.float32)) 63 | b_ptr = b_ptr.to(tl.pointer_type(tl.float32)) 64 | c_ptr = c_ptr.to(tl.pointer_type(tl.float32)) 65 | pid = tl.program_id(axis=0) 66 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 67 | pid_m = pid % num_pid_m 68 | pid_n = pid // num_pid_m 69 | offs_am = pid_m * BLOCK_SIZE_M 70 | offs_bn = pid_n * BLOCK_SIZE_N 71 | offs_k = 0 72 | 73 | a_desc = tl._experimental_make_tensor_descriptor( 74 | a_ptr, 75 | shape=[M, K], 76 | strides=[K, 1], 77 | block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], 78 | ) 79 | b_desc = tl._experimental_make_tensor_descriptor( 80 | b_ptr, 81 | shape=[K, N], 82 | strides=[N, 1], 83 | block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], 84 | ) 85 | c_desc = tl._experimental_make_tensor_descriptor( 86 | c_ptr, 87 | shape=[M, N], 88 | strides=[N, 1], 89 | block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], 90 | ) 91 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 92 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 93 | a = a_desc.load([offs_am, offs_k]) 94 | # tl.device_print("a: ", a) 95 
| b = b_desc.load([offs_k, offs_bn]) 96 | accumulator = tl.dot(a, b, acc=accumulator, input_precision="ieee") 97 | offs_k += BLOCK_SIZE_K 98 | accumulator = accumulator.to(a_desc.dtype) 99 | c_desc.store([offs_am, offs_bn], accumulator) 100 | 101 | # a_ptr, b_ptr, c_ptr are raw device pointers 102 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int): 103 | grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(K, META['BLOCK_SIZE_K']), ) 104 | # Leading dimensions must be multiples of 16-byte strides 105 | if M % 4 == 0 and N % 4 == 0 and K % 4 == 0: 106 | # alloc_fn need use cudaMalloc by ctypes in LeetGPU 107 | import ctypes 108 | cudart = ctypes.CDLL("libcudart.so") 109 | cudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t] 110 | cudart.cudaMalloc.restype = ctypes.c_int 111 | from typing import Optional 112 | # TMA descriptors require a global memory allocation 113 | def alloc_fn(size: int, alignment: int, stream: Optional[int]): 114 | ptr = ctypes.c_void_p() 115 | err = cudart.cudaMalloc(ctypes.byref(ptr), size) 116 | if err != 0: 117 | raise RuntimeError(f"cudaMalloc failed, code {err}") 118 | return ptr.value 119 | 120 | triton.set_allocator(alloc_fn) 121 | matmul_kernel_make_tensor_desciptor[grid]( 122 | a_ptr, b_ptr, c_ptr, 123 | M, K, N, 124 | BLOCK_SIZE_M=32, 125 | BLOCK_SIZE_K=32, 126 | BLOCK_SIZE_N=32, 127 | ) 128 | else: 129 | matrix_multiplication_kernel[grid]( 130 | a_ptr, b_ptr, c_ptr, 131 | M, N, K, 132 | N, 1, 133 | K, 1, 134 | K, 1, 135 | BLOCK_SIZE_M=64, 136 | BLOCK_SIZE_K=64, 137 | BLOCK_SIZE_N=64, 138 | GROUP_SIZE_M=8 139 | ) 140 | -------------------------------------------------------------------------------- /02-matrix-multiplication/Triton/with_dot_v1.py: -------------------------------------------------------------------------------- 1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking. 
2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def matrix_multiplication_kernel( 7 | a_ptr, b_ptr, c_ptr, 8 | M, N, K, 9 | stride_am, stride_an, 10 | stride_bn, stride_bk, 11 | stride_cm, stride_ck, 12 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr 13 | ): 14 | a_ptr = a_ptr.to(tl.pointer_type(tl.float32)) 15 | b_ptr = b_ptr.to(tl.pointer_type(tl.float32)) 16 | c_ptr = c_ptr.to(tl.pointer_type(tl.float32)) 17 | pid_m = tl.program_id(axis=0) 18 | pid_k = tl.program_id(axis=1) 19 | 20 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 21 | offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 22 | offs_n = tl.arange(0, BLOCK_SIZE_N) 23 | 24 | # initialize pointers for a and b 25 | a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an 26 | b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk 27 | 28 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 29 | 30 | # accumulate along the N dimension 31 | for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): 32 | # load current blocks of a and b with boundary check 33 | a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0) 34 | b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0) 35 | # compute matrix multiplication and accumulate 36 | accumulator = tl.dot(a, b, accumulator, input_precision="ieee") 37 | a_ptrs += BLOCK_SIZE_N * stride_an 38 | b_ptrs += BLOCK_SIZE_N * stride_bn 39 | 40 | # write result back to c 41 | c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck 42 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 43 | offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 44 | c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K) 45 | tl.store(c_ptrs, accumulator, mask=c_mask) 46 | 47 | # a_ptr, b_ptr, c_ptr are raw device pointers 48 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int): 49 | stride_am, stride_an = N, 1 50 | stride_bn, stride_bk = K, 1 51 | stride_cm, stride_ck = K, 1 52 | 53 | grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(K, META['BLOCK_SIZE_K'])) 54 | matrix_multiplication_kernel[grid]( 55 | a_ptr, b_ptr, c_ptr, 56 | M, N, K, 57 | stride_am, stride_an, 58 | stride_bn, stride_bk, 59 | stride_cm, stride_ck, 60 | BLOCK_SIZE_M=16, 61 | BLOCK_SIZE_K=16, 62 | BLOCK_SIZE_N=16 63 | ) 64 | -------------------------------------------------------------------------------- /02-matrix-multiplication/Triton/with_dot_v2.py: -------------------------------------------------------------------------------- 1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking. 
2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def matrix_multiplication_kernel( 7 | a_ptr, b_ptr, c_ptr, 8 | M, N, K, 9 | stride_am, stride_an, 10 | stride_bn, stride_bk, 11 | stride_cm, stride_ck, 12 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, 13 | GROUP_SIZE_M: tl.constexpr 14 | ): 15 | a_ptr = a_ptr.to(tl.pointer_type(tl.float32)) 16 | b_ptr = b_ptr.to(tl.pointer_type(tl.float32)) 17 | c_ptr = c_ptr.to(tl.pointer_type(tl.float32)) 18 | 19 | pid = tl.program_id(axis=0) 20 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 21 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 22 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 23 | group_id = pid // num_pid_in_group 24 | first_pid_m = group_id * GROUP_SIZE_M 25 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 26 | pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 27 | pid_k = (pid % num_pid_in_group) // group_size_m 28 | 29 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 30 | offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 31 | offs_n = tl.arange(0, BLOCK_SIZE_N) 32 | 33 | # initialize pointers for a and b 34 | a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an 35 | b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk 36 | 37 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 38 | 39 | # accumulate along the N dimension 40 | for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): 41 | # load current blocks of a and b with boundary check 42 | a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0) 43 | b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0) 44 | # compute matrix multiplication and accumulate 45 | accumulator = tl.dot(a, b, accumulator, input_precision="ieee") 46 | a_ptrs += BLOCK_SIZE_N * stride_an 47 | b_ptrs += BLOCK_SIZE_N * stride_bn 48 | 49 | # write result back to c 50 | c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck 51 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 52 | offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 53 | c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K) 54 | tl.store(c_ptrs, accumulator, mask=c_mask) 55 | 56 | # a_ptr, b_ptr, c_ptr are raw device pointers 57 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int): 58 | stride_am, stride_an = N, 1 59 | stride_bn, stride_bk = K, 1 60 | stride_cm, stride_ck = K, 1 61 | 62 | grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(K, META['BLOCK_SIZE_K']), ) 63 | matrix_multiplication_kernel[grid]( 64 | a_ptr, b_ptr, c_ptr, 65 | M, N, K, 66 | stride_am, stride_an, 67 | stride_bn, stride_bk, 68 | stride_cm, stride_ck, 69 | BLOCK_SIZE_M=64, 70 | BLOCK_SIZE_K=64, 71 | BLOCK_SIZE_N=64, 72 | GROUP_SIZE_M=8 73 | ) 74 | -------------------------------------------------------------------------------- /02-matrix-multiplication/Triton/with_dot_v3.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.jit 5 | def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M): 6 | group_id = tile_id // num_pid_in_group 7 | first_pid_m = group_id * GROUP_SIZE_M 8 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 9 | pid_m = first_pid_m + (tile_id % group_size_m) 10 | pid_k = (tile_id % num_pid_in_group) // group_size_m 11 | return pid_m, pid_k 12 | 13 | @triton.jit 14 | def 
matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # 15 | M, N, K, # 16 | stride_am, stride_an, # 17 | stride_bn, stride_bk, # 18 | stride_cm, stride_ck, # 19 | BLOCK_SIZE_M: tl.constexpr, # 20 | BLOCK_SIZE_K: tl.constexpr, # 21 | BLOCK_SIZE_N: tl.constexpr, # 22 | GROUP_SIZE_M: tl.constexpr, # 23 | NUM_SMS: tl.constexpr, # 24 | ): 25 | a_ptr = a_ptr.to(tl.pointer_type(tl.float32)) 26 | b_ptr = b_ptr.to(tl.pointer_type(tl.float32)) 27 | c_ptr = c_ptr.to(tl.pointer_type(tl.float32)) 28 | start_pid = tl.program_id(axis=0) 29 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 30 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 31 | n_tiles = tl.cdiv(N, BLOCK_SIZE_N) 32 | num_tiles = num_pid_m * num_pid_k 33 | 34 | # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being 35 | # used in both the prologue and epilogue, so we duplicate the counters as a work-around. 36 | tile_id_c = start_pid - NUM_SMS 37 | 38 | offs_n_for_mask = tl.arange(0, BLOCK_SIZE_N) 39 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 40 | 41 | for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): 42 | pid_m, pid_k = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M) 43 | start_m = pid_m * BLOCK_SIZE_M 44 | start_k = pid_k * BLOCK_SIZE_K 45 | offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) 46 | offs_bk = start_k + tl.arange(0, BLOCK_SIZE_K) 47 | offs_am = tl.where(offs_am < M, offs_am, 0) 48 | offs_bk = tl.where(offs_bk < K, offs_bk, 0) 49 | offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) 50 | offs_bk = tl.max_contiguous(tl.multiple_of(offs_bk, BLOCK_SIZE_K), BLOCK_SIZE_K) 51 | 52 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 53 | for ni in range(n_tiles): 54 | offs_k = ni * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 55 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_an) 56 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bn + offs_bk[None, :] * stride_bk) 57 | 58 | a = tl.load(a_ptrs, mask=offs_n_for_mask[None, :] < N - ni * BLOCK_SIZE_N, other=0.0) 59 | b = tl.load(b_ptrs, mask=offs_n_for_mask[:, None] < N - ni * BLOCK_SIZE_N, other=0.0) 60 | accumulator = tl.dot(a, b, accumulator, input_precision="ieee") 61 | 62 | tile_id_c += NUM_SMS 63 | pid_m, pid_k = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M) 64 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 65 | offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 66 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_ck * offs_ck[None, :] 67 | c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K) 68 | c = accumulator.to(tl.float32) 69 | tl.store(c_ptrs, c, mask=c_mask) 70 | 71 | # a_ptr, b_ptr, c_ptr are raw device pointers 72 | def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int): 73 | stride_am, stride_an = N, 1 74 | stride_bn, stride_bk = K, 1 75 | stride_cm, stride_ck = K, 1 76 | # SM count for Tesla T4 (avoid torch API) 77 | # you can get the value with: torch.cuda.get_device_properties("cuda").multi_processor_count 78 | NUM_SMS = 40 79 | # 1D launch kernel where each block gets its own program. 
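    # Alternative (hedged sketch): instead of hardcoding NUM_SMS = 40 above, the SM count
    # can be queried without torch through the CUDA runtime via ctypes
    # (attribute 16 is cudaDevAttrMultiProcessorCount, queried here for device 0):
    #   import ctypes
    #   cudart = ctypes.CDLL("libcudart.so")
    #   sm_count = ctypes.c_int()
    #   cudart.cudaDeviceGetAttribute(ctypes.byref(sm_count), 16, 0)
    #   NUM_SMS = sm_count.value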
80 |     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"])), )
81 |     kernel = matmul_kernel_persistent[grid](
82 |         a_ptr, b_ptr, c_ptr,  #
83 |         M, N, K,  #
84 |         stride_am, stride_an,
85 |         stride_bn, stride_bk,
86 |         stride_cm, stride_ck,
87 |         BLOCK_SIZE_M=64,
88 |         BLOCK_SIZE_K=64,
89 |         BLOCK_SIZE_N=64,
90 |         GROUP_SIZE_M=8,
91 |         NUM_SMS=NUM_SMS,  #
92 |     )
93 | 
--------------------------------------------------------------------------------
/03-matrix-transpose/CUDA/native.cu:
--------------------------------------------------------------------------------
1 | #include "solve.h"
2 | #include <cuda_runtime.h>
3 | 
4 | __global__ void matrix_transpose_kernel(const float* input, float* output, int rows, int cols) {
5 |     int x = blockIdx.x * blockDim.x + threadIdx.x;  // Column index
6 |     int y = blockIdx.y * blockDim.y + threadIdx.y;  // Row index
7 | 
8 |     if (x < cols && y < rows) {
9 |         output[x * rows + y] = input[y * cols + x];  // Transpose operation
10 |     }
11 | }
12 | 
13 | // input, output are device pointers (i.e. pointers to memory on the GPU)
14 | void solve(const float* input, float* output, int rows, int cols) {
15 |     const int BLOCK_SIZE = 16;
16 |     dim3 threadsPerBlock(BLOCK_SIZE, BLOCK_SIZE);
17 |     dim3 blocksPerGrid((cols + BLOCK_SIZE - 1) / BLOCK_SIZE,
18 |                        (rows + BLOCK_SIZE - 1) / BLOCK_SIZE);
19 | 
20 |     matrix_transpose_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, output, rows, cols);
21 |     cudaDeviceSynchronize();
22 | }
23 | 
--------------------------------------------------------------------------------
/03-matrix-transpose/CUDA/use_shared.cu:
--------------------------------------------------------------------------------
1 | // refer to https://zhuanlan.zhihu.com/p/692010210 for more details.
2 | 
3 | #include "solve.h"
4 | #include <cuda_runtime.h>
5 | 
6 | template <int BLOCK_SZ, int NUM_PER_THREAD>
7 | __global__ void mat_transpose_kernel_v3(const float* idata, float* odata, int M, int N) {
8 |     const int bx = blockIdx.x, by = blockIdx.y;
9 |     const int tx = threadIdx.x, ty = threadIdx.y;
10 | 
11 |     __shared__ float sdata[BLOCK_SZ][BLOCK_SZ + 1];
12 | 
13 |     int x = bx * BLOCK_SZ + tx;
14 |     int y = by * BLOCK_SZ + ty;
15 | 
16 |     constexpr int ROW_STRIDE = BLOCK_SZ / NUM_PER_THREAD;
17 | 
18 |     if (x < N) {
19 |         #pragma unroll
20 |         for (int y_off = 0; y_off < BLOCK_SZ; y_off += ROW_STRIDE) {
21 |             if (y + y_off < M) {
22 |                 sdata[ty + y_off][tx] = idata[(y + y_off) * N + x];
23 |             }
24 |         }
25 |     }
26 |     __syncthreads();
27 |     x = by * BLOCK_SZ + tx;
28 |     y = bx * BLOCK_SZ + ty;
29 |     if (x < M) {
30 |         for (int y_off = 0; y_off < BLOCK_SZ; y_off += ROW_STRIDE) {
31 |             if (y + y_off < N) {
32 |                 odata[(y + y_off) * M + x] = sdata[tx][ty + y_off];
33 |             }
34 |         }
35 |     }
36 | }
37 | 
38 | // input, output are device pointers (i.e. pointers to memory on the GPU)
39 | void solve(const float* input, float* output, int rows, int cols) {
40 |     constexpr int BLOCK_SZ = 32;
41 |     constexpr int NUM_PER_THREAD = 4;
42 |     dim3 block(BLOCK_SZ, BLOCK_SZ / NUM_PER_THREAD);
43 |     dim3 grid((cols + BLOCK_SZ - 1) / BLOCK_SZ, (rows + BLOCK_SZ - 1) / BLOCK_SZ);
44 |     mat_transpose_kernel_v3<BLOCK_SZ, NUM_PER_THREAD><<<grid, block>>>(input, output, rows, cols);
45 |     cudaDeviceSynchronize();
46 | }
--------------------------------------------------------------------------------
/03-matrix-transpose/MOJO/native.mojo:
--------------------------------------------------------------------------------
1 | from gpu.host import DeviceContext
2 | from gpu.id import block_dim, block_idx, thread_idx
3 | from memory import UnsafePointer
4 | from math import ceildiv
5 | 
6 | fn matrix_transpose_kernel(input: UnsafePointer[Float32], output: UnsafePointer[Float32], rows: Int32, cols: Int32):
7 |     var col = block_idx.x * block_dim.x + thread_idx.x
8 |     var row = block_idx.y * block_dim.y + thread_idx.y
9 |     if Int32(col) < cols and Int32(row) < rows:
10 |         output[col * rows + row] = input[row * cols + col]
11 | 
12 | # input, output are device pointers (i.e. pointers to memory on the GPU)
13 | @export
14 | def solve(input: UnsafePointer[Float32], output: UnsafePointer[Float32], rows: Int32, cols: Int32):
15 |     var BLOCK_SIZE: Int32 = 32
16 |     var ctx = DeviceContext()
17 | 
18 |     var grid_dim_x = ceildiv(cols, BLOCK_SIZE)
19 |     var grid_dim_y = ceildiv(rows, BLOCK_SIZE)
20 | 
21 |     ctx.enqueue_function[matrix_transpose_kernel](
22 |         input, output, rows, cols,
23 |         grid_dim = (grid_dim_x, grid_dim_y),
24 |         block_dim = (BLOCK_SIZE, BLOCK_SIZE)
25 |     )
26 | 
27 |     ctx.synchronize()
28 | 
--------------------------------------------------------------------------------
/03-matrix-transpose/Triton/native.py:
--------------------------------------------------------------------------------
1 | import triton
2 | import triton.language as tl
3 | 
4 | @triton.jit
5 | def matrix_transpose_kernel(
6 |     input_ptr, output_ptr,
7 |     M, N,
8 |     stride_ir, stride_ic,
9 |     stride_or, stride_oc,
10 |     BLOCK_SIZE: tl.constexpr
11 | ):
12 |     input_ptr = input_ptr.to(tl.pointer_type(tl.float32))
13 |     output_ptr = output_ptr.to(tl.pointer_type(tl.float32))
14 |     # 1. determine the input tile coordinates this thread block is responsible for
15 |     pid_m = tl.program_id(0)  # block index in M dimension
16 |     pid_n = tl.program_id(1)  # block index in N dimension
17 | 
18 |     # 2. compute element-wise offsets within the tile
19 |     offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
20 |     offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
21 | 
22 |     # 3. define global memory pointers for input tile (row-major)
23 |     input_ptrs = input_ptr + offs_m[:, None] * stride_ir + offs_n[None, :] * stride_ic
24 | 
25 |     # 4. load input tile from global memory with boundary check
26 |     mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
27 |     block = tl.load(input_ptrs, mask=mask, other=0)
28 | 
29 |     # 5. transpose the tile (swap rows and columns)
30 |     transposed_block = tl.trans(block)  # Triton built-in transpose function
31 | 
32 |     # 6. compute global memory pointers for output tile (column-major)
33 |     output_ptrs = output_ptr + offs_n[:, None] * M + offs_m[None, :]  # M is the row stride after transpose
34 | 
35 |     # 7. store the transposed tile to global memory
36 |     tl.store(output_ptrs, transposed_block, mask=mask.T)  # transpose the mask as well
37 | 
38 | 
39 | # input_ptr, output_ptr are raw device pointers
40 | def solve(input_ptr: int, output_ptr: int, rows: int, cols: int):
41 |     stride_ir, stride_ic = cols, 1
42 |     stride_or, stride_oc = rows, 1
43 | 
44 |     grid = lambda META: (triton.cdiv(rows, META['BLOCK_SIZE']), triton.cdiv(cols, META['BLOCK_SIZE']))
45 |     matrix_transpose_kernel[grid](
46 |         input_ptr, output_ptr,
47 |         rows, cols,
48 |         stride_ir, stride_ic,
49 |         stride_or, stride_oc,
50 |         BLOCK_SIZE=32
51 |     )
52 | 
--------------------------------------------------------------------------------
/13-softmax/CUDA/native.cu:
--------------------------------------------------------------------------------
1 | #include "solve.h"
2 | #include <cuda_runtime.h>
3 | #include <cfloat>
4 | 
5 | __global__ void softmax_kernel(const float* input, float* output, int N) {
6 |     // dynamic shared memory for reduction
7 |     extern __shared__ float shared_mem[];
8 |     float* max_shared = shared_mem;
9 |     float* sum_shared = &shared_mem[blockDim.x];
10 | 
11 |     int tid = threadIdx.x;
12 | 
13 |     // compute maximum value
14 |     float local_max = -FLT_MAX;
15 |     for (int i = tid; i < N; i += blockDim.x) {
16 |         local_max = fmaxf(local_max, input[i]);
17 |     }
18 |     max_shared[tid] = local_max;
19 |     __syncthreads();
20 | 
21 |     // block reduction to find maximum
22 |     for (int s = blockDim.x / 2; s > 0; s >>= 1) {
23 |         if (tid < s) {
24 |             max_shared[tid] = fmaxf(max_shared[tid], max_shared[tid + s]);
25 |         }
26 |         __syncthreads();
27 |     }
28 |     float max_val = max_shared[0];
29 |     __syncthreads();
30 | 
31 |     // compute exponential sum
32 |     float local_sum = 0.0f;
33 |     for (int i = tid; i < N; i += blockDim.x) {
34 |         local_sum += expf(input[i] - max_val);
35 |     }
36 |     sum_shared[tid] = local_sum;
37 |     __syncthreads();
38 | 
39 |     // block reduction to compute sum
40 |     for (int s = blockDim.x / 2; s > 0; s >>= 1) {
41 |         if (tid < s) {
42 |             sum_shared[tid] += sum_shared[tid + s];
43 |         }
44 |         __syncthreads();
45 |     }
46 |     float sum_exp = sum_shared[0];
47 |     __syncthreads();
48 | 
49 |     // compute final results
50 |     for (int i = tid; i < N; i += blockDim.x) {
51 |         output[i] = expf(input[i] - max_val) / sum_exp;
52 |     }
53 | }
54 | 
55 | // input, output are device pointers (i.e. pointers to memory on the GPU)
56 | void solve(const float* input, float* output, int N) {
57 |     int threadsPerBlock = 256;
58 |     int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
59 |     size_t sharedSize = 2 * threadsPerBlock * sizeof(float);
60 |     softmax_kernel<<<blocksPerGrid, threadsPerBlock, sharedSize>>>(input, output, N);
61 |     cudaDeviceSynchronize();
62 | }
63 | 
--------------------------------------------------------------------------------
/13-softmax/MOJO/native.mojo:
--------------------------------------------------------------------------------
1 | from gpu.host import DeviceContext
2 | from gpu.id import block_dim, block_idx, thread_idx
3 | from gpu import barrier
4 | from memory import UnsafePointer
5 | from buffer import NDBuffer
6 | from math import ceildiv, exp
7 | from gpu.memory import AddressSpace
8 | from algorithm._gpu.reduction import block_reduce
9 | 
10 | @parameter
11 | fn softmax_kernel(input: UnsafePointer[Float32], output: UnsafePointer[Float32], N: Int32):
12 |     alias BLOCK_SIZE: Int = 1024
13 |     var tid = thread_idx.x
14 |     if tid == 0:
15 |         output[0] = 1
16 |     var max_buf = NDBuffer[
17 |         DType.float32, 1, MutableAnyOrigin, 1, address_space = AddressSpace.SHARED
18 |     ].stack_allocation()
19 |     var sum_buf = NDBuffer[
20 |         DType.float32, 1, MutableAnyOrigin, 1, address_space = AddressSpace.SHARED
21 |     ].stack_allocation()
22 | 
23 |     # Step 1: compute max
24 |     var local_max = Scalar[DType.float32](input[tid])
25 |     for i in range(tid + BLOCK_SIZE, N, BLOCK_SIZE):
26 |         local_max = max(local_max, input[i])
27 | 
28 |     @parameter
29 |     @always_inline
30 |     fn _max[
31 |         type: DType, width: Int
32 |     ](x: SIMD[type, width], y: SIMD[type, width]) -> SIMD[type, width]:
33 |         return max(x, y)
34 | 
35 |     var block_max = block_reduce[BLOCK_SIZE, _max](local_max, 0)
36 | 
37 |     if tid == 0:
38 |         max_buf[0] = block_max
39 |     barrier()
40 | 
41 |     # Step 2: out[i] = exp(in[i] - max) and compute sum of out[i]
42 |     var local_sum = Scalar[DType.float32](0)
43 |     for i in range(tid, N, BLOCK_SIZE):
44 |         local_sum += exp(input[i] - max_buf[0])
45 | 
46 |     @parameter
47 |     @always_inline
48 |     fn _sum[
49 |         type: DType, width: Int
50 |     ](x: SIMD[type, width], y: SIMD[type, width]) -> SIMD[type, width]:
51 |         return x + y
52 | 
53 |     var block_sum = block_reduce[BLOCK_SIZE, _sum](local_sum, 0)
54 | 
55 |     if tid == 0:
56 |         sum_buf[0] = block_sum
57 |     barrier()
58 | 
59 |     # Step 3: Normalize output
60 |     for i in range(tid, N, BLOCK_SIZE):
61 |         output[i] = exp(input[i] - max_buf[0]) / sum_buf[0]
62 | 
63 | @export
64 | def solve(input: UnsafePointer[Float32], output: UnsafePointer[Float32], N: Int32):
65 |     var BLOCK_SIZE: Int32 = 1024
66 |     var ctx = DeviceContext()
67 |     var num_blocks = ceildiv(N, BLOCK_SIZE)
68 | 
69 |     ctx.enqueue_function[softmax_kernel](
70 |         input, output, N,
71 |         grid_dim = num_blocks,
72 |         block_dim = BLOCK_SIZE
73 |     )
74 | 
75 |     ctx.synchronize()
76 | 
--------------------------------------------------------------------------------
/13-softmax/Triton/reduce_in_one_block.py:
--------------------------------------------------------------------------------
1 | # The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def softmax_kernel( 7 | input_ptr, output_ptr, 8 | N, 9 | BLOCK_SIZE: tl.constexpr 10 | ): 11 | input_ptr = input_ptr.to(tl.pointer_type(tl.float32)) 12 | output_ptr = output_ptr.to(tl.pointer_type(tl.float32)) 13 | _max = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf") 14 | for off in range(0, N, BLOCK_SIZE): 15 | cols = off + tl.arange(0, BLOCK_SIZE) 16 | a = tl.load(input_ptr + cols, mask=cols < N, other=-float("inf")) 17 | _max = tl.maximum(a, _max) 18 | max = tl.max(_max, axis=0) 19 | _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 20 | for off in range(0, N, BLOCK_SIZE): 21 | cols = off + tl.arange(0, BLOCK_SIZE) 22 | a = tl.load(input_ptr + cols, mask=cols < N, other=-float("inf")) 23 | _sum += tl.exp(a - max) 24 | sum = tl.sum(_sum, axis=0) 25 | pid = tl.program_id(0) 26 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 27 | mask = offset < N 28 | x = tl.load(input_ptr + offset, mask=mask) 29 | y = tl.exp(x - max) / sum 30 | tl.store(output_ptr + offset, y, mask=mask) 31 | 32 | # input_ptr, output_ptr are raw device pointers 33 | def solve(input_ptr: int, output_ptr: int, N: int): 34 | BLOCK_SIZE = 32768 35 | grid = (triton.cdiv(N, BLOCK_SIZE),) 36 | softmax_kernel[grid]( 37 | input_ptr, output_ptr, N, 38 | BLOCK_SIZE=BLOCK_SIZE 39 | ) 40 | -------------------------------------------------------------------------------- /13-softmax/Triton/three_kernel.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.jit 5 | def partial_max_value_kernel(X, partial_max, N, BLOCK_SIZE: tl.constexpr): 6 | X = X.to(tl.pointer_type(tl.float32)) 7 | partial_max = partial_max.to(tl.pointer_type(tl.float32)) 8 | pid = tl.program_id(0) 9 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 10 | mask = offset < N 11 | x = tl.load(X + offset, mask=mask, other=-float("inf")) 12 | local_max = tl.max(x, axis=0) 13 | tl.store(partial_max + pid, local_max) 14 | 15 | @triton.jit 16 | def partial_exp_sum_value_kernel(X, partial_sum, global_max, N, BLOCK_SIZE: tl.constexpr): 17 | X = X.to(tl.pointer_type(tl.float32)) 18 | partial_sum = partial_sum.to(tl.pointer_type(tl.float32)) 19 | global_max = global_max.to(tl.pointer_type(tl.float32)) 20 | pid = tl.program_id(0) 21 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 22 | mask = offset < N 23 | x = tl.load(X + offset, mask=mask, other=-float("inf")) 24 | gmax = tl.load(global_max) 25 | local_sum = tl.sum(tl.exp(x - gmax), axis=0) 26 | tl.store(partial_sum + pid, local_sum) 27 | 28 | @triton.jit 29 | def normalize_kernel(X, Y, N, global_max, global_sum, BLOCK_SIZE: tl.constexpr): 30 | X = X.to(tl.pointer_type(tl.float32)) 31 | Y = Y.to(tl.pointer_type(tl.float32)) 32 | global_max = global_max.to(tl.pointer_type(tl.float32)) 33 | global_sum = global_sum.to(tl.pointer_type(tl.float32)) 34 | pid = tl.program_id(0) 35 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 36 | mask = offset < N 37 | x = tl.load(X + offset, mask=mask) 38 | gmax = tl.load(global_max) 39 | gsum = tl.load(global_sum) 40 | y = tl.exp(x - gmax) / gsum 41 | tl.store(Y + offset, y, mask=mask) 42 | 43 | @triton.jit 44 | def get_max_value(partial_max, global_max, BLOCK_SIZE: tl.constexpr): 45 | partial_max = partial_max.to(tl.pointer_type(tl.float32)) 46 | global_max = global_max.to(tl.pointer_type(tl.float32)) 47 | offset = tl.arange(0, BLOCK_SIZE) 48 | x = tl.load(partial_max + offset) 49 | local_max = 
tl.max(x, axis=0)
50 |     tl.store(global_max, local_max)
51 | 
52 | @triton.jit
53 | def get_sum_value(partial_sum, global_sum, BLOCK_SIZE: tl.constexpr):
54 |     partial_sum = partial_sum.to(tl.pointer_type(tl.float32))
55 |     global_sum = global_sum.to(tl.pointer_type(tl.float32))
56 |     offset = tl.arange(0, BLOCK_SIZE)
57 |     x = tl.load(partial_sum + offset)
58 |     local_sum = tl.sum(x, axis=0)
59 |     tl.store(global_sum, local_sum)
60 | 
61 | def cudaEmpty(num_elements: int):
62 |     import ctypes
63 |     cudart = ctypes.CDLL("libcudart.so")
64 |     cudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
65 |     cudart.cudaMalloc.restype = ctypes.c_int
66 |     ptr = ctypes.c_void_p()
67 |     err = cudart.cudaMalloc(ctypes.byref(ptr), num_elements * 4)
68 |     if err != 0:
69 |         raise RuntimeError(f"cudaMalloc failed, code {err}")
70 |     return ptr.value
71 | 
72 | 
73 | # input_ptr, output_ptr are raw device pointers
74 | def solve(input_ptr: int, output_ptr: int, N: int):
75 |     BLOCK_SIZE = 32768
76 |     num_blocks = triton.cdiv(N, BLOCK_SIZE)
77 |     grid = (num_blocks,)
78 |     partial_max = cudaEmpty(num_blocks)  # one partial max per block
79 |     partial_max_value_kernel[grid](
80 |         input_ptr, partial_max, N,
81 |         BLOCK_SIZE=BLOCK_SIZE
82 |     )
83 |     global_max = cudaEmpty(1)
84 |     get_max_value[1,](partial_max, global_max, BLOCK_SIZE=num_blocks)  # tl.arange needs a power-of-two size, so this assumes num_blocks is one
85 |     partial_sum = cudaEmpty(num_blocks)
86 |     partial_exp_sum_value_kernel[grid](
87 |         input_ptr, partial_sum, global_max, N,
88 |         BLOCK_SIZE=BLOCK_SIZE
89 |     )
90 |     global_sum = cudaEmpty(1)
91 |     get_sum_value[1,](partial_sum, global_sum, BLOCK_SIZE=num_blocks)  # same power-of-two assumption as above
92 |     normalize_kernel[grid](
93 |         input_ptr, output_ptr, N,
94 |         global_max, global_sum,
95 |         BLOCK_SIZE=BLOCK_SIZE
96 |     )
97 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LeetGPU Solutions
2 | 
3 | This repository contains my personal solutions to various problems from [LeetGPU](https://leetgpu.com/challenges). The solutions are written primarily in **Triton**, with **CUDA** and **MOJO🔥** versions for several problems, and are organized by problem. My nickname is [BobHuang](https://leetgpu.com/profile?display_name=BobHuang).
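
Each solution exposes a `solve` entry point, which is the function LeetGPU invokes; in the Triton solutions it receives the tensors as raw device pointers (plain Python integers). As a rough sketch (condensed from `01-vector-addition/Triton/native.py`), a Triton solution looks like this:

```python
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # reinterpret the raw integer addresses as float32 pointers
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail block
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, N: int):
    BLOCK_SIZE = 128
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    vector_add_kernel[grid](a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE)
```

The CUDA and MOJO🔥 solutions follow the same pattern: an exported `solve` that receives device pointers and launches the kernel(s).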
4 | 
5 | If you're a beginner just getting started with these problems and you're comfortable reading Chinese, feel free to check out my introductory tutorials on Zhihu: 👉 [LeetGPU入门教程 (CUDA guide最佳实践)](https://zhuanlan.zhihu.com/p/1899956367734867434) and [LeetGPU的MOJO 🔥 实践](https://zhuanlan.zhihu.com/p/1908980999993402643).
6 | 
--------------------------------------------------------------------------------
/utils/createFolder.py:
--------------------------------------------------------------------------------
1 | from playwright.sync_api import sync_playwright
2 | import time
3 | import os
4 | script_dir = os.path.dirname(os.path.abspath(__file__))
5 | parent_dir = os.path.dirname(script_dir)
6 | 
7 | found = []
8 | with sync_playwright() as p:
9 |     wait_seconds = 20
10 |     time_out_seconds = min(wait_seconds // 4, 5)
11 |     browser = p.chromium.launch(headless=True)
12 |     page = browser.new_page()
13 | 
14 |     page.set_default_navigation_timeout((wait_seconds + time_out_seconds) * 1000)
15 |     page.set_default_timeout((wait_seconds + time_out_seconds) * 1000)
16 |     page.goto("https://leetgpu.com/challenges", wait_until="domcontentloaded")
17 | 
18 |     time.sleep(wait_seconds)
19 | 
20 |     anchors = page.query_selector_all('a[href^="/challenges/"]')
21 | 
22 |     for a in anchors:
23 |         href = a.get_attribute("href")
24 |         if href:
25 |             name = href.split("/")[-1]
26 |             if name and name not in found:
27 |                 found.append(name)
28 | 
29 |     browser.close()
30 | 
31 | for idx, item in enumerate(found):
32 |     folder_path = os.path.join(parent_dir, f"{idx+1:02d}-{item}")
33 |     if not os.path.exists(folder_path):
34 |         os.mkdir(folder_path)
35 | 
36 | digital_folders = [
37 |     name for name in os.listdir(parent_dir)
38 |     if os.path.isdir(os.path.join(parent_dir, name))
39 |     and len(name) >= 3 and name[0].isdigit() and name[1].isdigit()
40 |     and name[2] == '-'
41 | ]
42 | 
43 | for item in digital_folders:
44 |     folder_path = os.path.join(parent_dir, item, "Triton")
45 |     if not os.path.exists(folder_path):
46 |         os.mkdir(folder_path)
47 |     file_path = os.path.join(folder_path, "native.py")
48 |     if not os.path.exists(file_path):
49 |         open(file_path, 'w').close()
50 | 
--------------------------------------------------------------------------------