├── readme.md
├── .clang-format
├── makefile
├── headers
│   ├── host
│   │   ├── profile_utilities.cuh
│   │   ├── tma_tensor_map.cuh
│   │   └── matrix_utilities.cuh
│   └── device
│       ├── tma.cuh
│       └── wgmma.cuh
├── examples
│   ├── 1_cluster.cu
│   ├── 5_tma_2d.cu
│   ├── 7_reduce_store.cu
│   ├── 4_tma_1d.cu
│   ├── 2_wgmma_dense.cu
│   ├── 10_tma_and_wgmma.cu
│   ├── test.cu
│   ├── 6_multicast.cu
│   ├── 3_wgmma_sparse.cu
│   ├── 9_swizzle.cu
│   └── 8_swizzle_manual.cu
├── dense
│   ├── 2_m64_n16_k16.cu
│   ├── 1_m64_n8_k32.cu
│   ├── 4_m64_n16_k64.cu
│   └── 3_m64_n8_k64.cu
└── sparse
    ├── 1_m64_n8_k32.cu
    ├── 3_m256_n8_k64.cu
    ├── 2_m64_n8_k64.cu
    ├── 4_m256_n16_k64.cu
    └── 5_m256_n32_k64.cu

/readme.md:
--------------------------------------------------------------------------------
1 | # Compile and run:
2 | 
3 | `make KERNEL=<path to a kernel .cu file>` (default: `dense/2_m64_n8_k64.cu`)
4 | 
5 | `./run`
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
1 | UseTab: Always   # Use tabs for indentation
2 | TabWidth: 4      # Number of spaces for a tab
3 | IndentWidth: 4   # Indentation size
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | sm_version=90a
2 | NVCC=/usr/local/cuda-12.4/bin/nvcc
3 | INCLUDES=-I./headers/device/ -I./headers/host/
4 | OPTIMIZATION=-O0
5 | LINKS=-lcudart -lcuda
6 | OUTPUT=run
7 | KERNEL=dense/2_m64_n8_k64.cu
8 | COMMENT=update
9 | 
10 | all:
11 | 	make kernel
12 | 	make run
13 | 
14 | kernel:
15 | 	${NVCC} -arch=sm_${sm_version} ${OPTIMIZATION} ${INCLUDES} ${LINKS} -o ${OUTPUT} ${KERNEL}
16 | 
17 | push:
18 | 	git add .
19 | 	git commit -m "${COMMENT}"
20 | 	git push
21 | 
22 | run:
23 | 	./${OUTPUT}
24 | 
25 | clean:
26 | 	rm -f ${OUTPUT}
--------------------------------------------------------------------------------
/headers/host/profile_utilities.cuh:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 | 
3 | #define CHECK_CUDA(func)                                                \
4 | 	{                                                                   \
5 | 		cudaError_t status = (func);                                    \
6 | 		if (status != cudaSuccess)                                      \
7 | 		{                                                               \
8 | 			printf("CUDA API failed at line %d with error: %s (%d)\n", \
9 | 				   __LINE__, cudaGetErrorString(status), status);       \
10 | 			return EXIT_FAILURE;                                        \
11 | 		}                                                               \
12 | 	}
13 | 
14 | void cuda_check_error()
15 | {
16 | 	cudaDeviceSynchronize();
17 | 
18 | 	cudaError_t err = cudaGetLastError();
19 | 	if (err != cudaSuccess)
20 | 	{
21 | 		printf("CUDA error: %s\n", cudaGetErrorString(err));
22 | 	}
23 | }
24 | 
25 | class cuda_timer
26 | {
27 | private:
28 | 	cudaEvent_t start, stop;
29 | 
30 | public:
31 | 	cuda_timer()
32 | 	{
33 | 		cudaEventCreate(&start);
34 | 		cudaEventCreate(&stop);
35 | 	}
36 | 
37 | 	~cuda_timer()
38 | 	{
39 | 		cudaEventDestroy(start);
40 | 		cudaEventDestroy(stop);
41 | 	}
42 | 
43 | 	void start_timer()
44 | 	{
45 | 		cudaEventRecord(start);
46 | 	}
47 | 
48 | 	void stop_timer()
49 | 	{
50 | 		cudaEventRecord(stop);
51 | 		cudaEventSynchronize(stop);
52 | 	}
53 | 
54 | 	float get_time()
55 | 	{
56 | 		float time;
57 | 		cudaEventElapsedTime(&time, start, stop);
58 | 		return time;
59 | 	}
60 | };
--------------------------------------------------------------------------------
/examples/1_cluster.cu:
--------------------------------------------------------------------------------
1 | // This code demonstrates, on an sm_90 GPU,
2 | // how to create a cluster of thread blocks
3 | // and how blocks in a cluster can interact
4 | // using distributed shared memory.
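//
// An alternative to the compile-time __cluster_dims__ attribute used below is
// to pick the cluster shape at launch time. A minimal sketch with
// cudaLaunchKernelEx (assuming the kernel is then compiled WITHOUT the
// __cluster_dims__ attribute):
//
//   cudaLaunchConfig_t config = {};
//   config.gridDim = dim3(2);            // total number of blocks
//   config.blockDim = dim3(32);          // threads per block
//   cudaLaunchAttribute attrs[1];
//   attrs[0].id = cudaLaunchAttributeClusterDimension;
//   attrs[0].val.clusterDim.x = 2;       // 2 blocks per cluster
//   attrs[0].val.clusterDim.y = 1;
//   attrs[0].val.clusterDim.z = 1;
//   config.attrs = attrs;
//   config.numAttrs = 1;
//   cudaLaunchKernelEx(&config, cluster_kernel);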
5 | 6 | #include 7 | #include 8 | 9 | #include "profile_utilities.cuh" 10 | 11 | __global__ void __cluster_dims__(2, 1, 1) cluster_kernel() { 12 | // printf("blockIdx.x: %d, threadIdx.x: %d\n", blockIdx.x, threadIdx.x); 13 | 14 | __shared__ int smem[32]; 15 | namespace cg = cooperative_groups; 16 | 17 | // tid is the thread index within the cluster, not block. 18 | int tid = cg::this_grid().thread_rank(); 19 | 20 | cg::cluster_group cluster = cg::this_cluster(); 21 | unsigned int clusterBlockRank = cluster.block_rank(); 22 | int cluster_size = cluster.dim_blocks().x; 23 | 24 | // cluster size = nubmer of blocks in the cluster 25 | if (tid == 0) { 26 | printf("cluster_size: %d\n", cluster_size); 27 | } 28 | 29 | // initialize shared memory, block 1 has one value higher than block 0 30 | smem[threadIdx.x] = blockIdx.x + threadIdx.x; 31 | 32 | cluster.sync(); 33 | 34 | // get the shared memory of the other block 35 | int *other_block_smem = cluster.map_shared_rank(smem, 1 - clusterBlockRank); 36 | 37 | // get the value from the other block 38 | int value = other_block_smem[threadIdx.x]; 39 | 40 | cluster.sync(); 41 | 42 | // print the value 43 | printf("blockIdx.x: %d, threadIdx.x: %d, value: %d\n", blockIdx.x, 44 | threadIdx.x, value); 45 | } 46 | 47 | int main() { 48 | 49 | // two blocks in a cluster 50 | cluster_kernel<<<2, 32>>>(); 51 | 52 | cuda_check_error(); 53 | } 54 | -------------------------------------------------------------------------------- /headers/device/tma.cuh: -------------------------------------------------------------------------------- 1 | // TMA api wrappers 2 | 3 | #pragma once 4 | 5 | // Suppress warning about barrier in shared memory 6 | #pragma nv_diag_suppress static_var_with_dynamic_init 7 | 8 | #include 9 | #include // PFN_cuTensorMapEncodeTiled, CUtensorMap 10 | 11 | enum Cache_Policy 12 | { 13 | evict_normal, 14 | evict_first, 15 | evict_last, 16 | evict_unchanged, 17 | no_allocate, 18 | }; 19 | 20 | /* 21 | // global -> shared::cluster: 22 | cp.async.bulk.prefetch.tensor.dim.L2.src{.load_mode}{.level::cache_hint} [tensorMap, tensorCoords] 23 | {, im2colOffsets } {, cache-policy} 24 | 25 | .src = { .global } 26 | .dim = { .1d, .2d, .3d, .4d, .5d } 27 | .load_mode = { .tile, .im2col } 28 | .level::cache_hint = { .L2::cache_hint } 29 | */ 30 | 31 | // 1d prefetch 32 | // src align to 16, size multiple of 16 33 | __device__ void copy_async_1d_prefetch(const CUtensorMap *__tensor_map, int coordinate) 34 | { 35 | asm volatile( 36 | "cp.async.bulk.prefetch.tensor.dim.L2.src.global.tile" 37 | " [%0, {%1}];" 38 | : 39 | : "l"(__tensor_map), 40 | "r"(coordinate) 41 | : "memory"); 42 | } 43 | 44 | // 2d prefetch 45 | // src align to 16, size multiple of 16 46 | __device__ void copy_async_2d_prefetch(const CUtensorMap *__tensor_map, int coordinate1, int coordinate2) 47 | { 48 | asm volatile( 49 | "cp.async.bulk.prefetch.tensor.2d.L2.global.tile" 50 | " [%0, {%1, %2}];" 51 | : 52 | : "l"(__tensor_map), 53 | "r"(coordinate1), 54 | "r"(coordinate2) 55 | : "memory"); 56 | } 57 | 58 | // inline _LIBCUDACXX_DEVICE void cp_async_bulk_tensor_1d_shared_to_global_multicast( 59 | // void *__dest, 60 | // const CUtensorMap *__tensor_map, 61 | // int __c0, 62 | // ::cuda::barrier<::cuda::thread_scope_block> &__bar, 63 | // uint16_t __ctaMask) 64 | // { 65 | // asm volatile( 66 | // "cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster " 67 | // "[%0], [%1, {%2}], [%3], %4;\n" 68 | // : 69 | // : 
"r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), 70 | // "l"(&__tensor_map), 71 | // "r"(__c0), 72 | // "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))), 73 | // "h"(__ctaMask) 74 | // : "memory"); 75 | // } -------------------------------------------------------------------------------- /examples/5_tma_2d.cu: -------------------------------------------------------------------------------- 1 | // This code uses TMA's 2d load to load a matrix's tile to 2 | // shared memory and then change the value in the 3 | // shared memory and uses TMA's store to store the 4 | // tile back to global memory. We print the result matrix to prove the 5 | // changes are done 6 | 7 | // note very carefully the order of the m and k coordinate in the api calls 8 | // and note the alignment requirement of the coordinatess 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "matrix_utilities.cuh" 15 | #include "profile_utilities.cuh" 16 | #include "tma.cuh" 17 | #include "tma_tensor_map.cuh" 18 | 19 | // Suppress warning about barrier in shared memory 20 | #pragma nv_diag_suppress static_var_with_dynamic_init 21 | 22 | using barrier = cuda::barrier; 23 | namespace cde = cuda::device::experimental; 24 | 25 | constexpr size_t M = 64; // Number of rows of matrix 26 | constexpr size_t K = 32; // Number of columns of matrix 27 | constexpr size_t gmem_len = M * K; 28 | 29 | constexpr int m = 16; // subtile rows 30 | constexpr int k = 8; // subtile columns 31 | 32 | static constexpr int buf_len = k * m; 33 | 34 | __global__ void test(const __grid_constant__ CUtensorMap tensor_map, int x, 35 | int y) { 36 | __shared__ alignas(128) int smem_buffer[buf_len]; 37 | __shared__ barrier bar; 38 | 39 | if (threadIdx.x == 0) { 40 | init(&bar, blockDim.x); 41 | } 42 | __syncthreads(); 43 | 44 | // Load data: 45 | uint64_t token; 46 | if (threadIdx.x == 0) { 47 | // just to demonstrate using prefetch, completely unnecessary here 48 | copy_async_2d_prefetch(&tensor_map, x, y); 49 | // call the loading api 50 | cde::cp_async_bulk_tensor_2d_global_to_shared(smem_buffer, &tensor_map, 51 | x, y, bar); 52 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer)); 53 | } else { 54 | token = bar.arrive(); 55 | } 56 | 57 | bar.wait(cuda::std::move(token)); 58 | 59 | __syncthreads(); 60 | 61 | // Update subtile, + 1 62 | for (int i = threadIdx.x; i < buf_len; i += blockDim.x) { 63 | smem_buffer[i] += 1; 64 | } 65 | 66 | cde::fence_proxy_async_shared_cta(); 67 | __syncthreads(); 68 | 69 | // Write back to global memory: 70 | if (threadIdx.x == 0) { 71 | cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map, x, y, 72 | smem_buffer); 73 | cde::cp_async_bulk_commit_group(); 74 | cde::cp_async_bulk_wait_group_read<0>(); 75 | } 76 | __threadfence(); 77 | __syncthreads(); 78 | } 79 | 80 | int main() { 81 | // fill the host matrix 82 | int host_tensor[gmem_len]; 83 | fill_tilewise(host_tensor, M, K, m, k); 84 | 85 | print_matrix(host_tensor, M, K); 86 | 87 | // copy host matrix to device 88 | int *tensor_ptr = nullptr; 89 | cudaMalloc(&tensor_ptr, gmem_len * sizeof(int)); 90 | cudaMemcpy(tensor_ptr, host_tensor, gmem_len * sizeof(int), 91 | cudaMemcpyHostToDevice); 92 | 93 | // create tensor map for the matrix 94 | CUtensorMap tensor_map = create_2d_tensor_map(M, K, m, k, tensor_ptr); 95 | 96 | // launch kernel, select a tile coordinate 97 | // x (0 16 32 48) y (0 8 16 24) must be aligned with m and k 98 | int coordinate_m = 48; 99 | 
int coordinate_k = 24; 100 | test<<<1, 128>>>(tensor_map, coordinate_k, coordinate_m); 101 | 102 | cuda_check_error(); 103 | 104 | // copy device matrix to host 105 | int host_gmem_tensor[gmem_len]; 106 | cudaMemcpy(host_gmem_tensor, tensor_ptr, gmem_len * sizeof(int), 107 | cudaMemcpyDeviceToHost); 108 | 109 | // verify the results 110 | print_matrix(host_gmem_tensor, M, K); 111 | 112 | return 0; 113 | } 114 | -------------------------------------------------------------------------------- /examples/7_reduce_store.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code uses TMA's 1d tensor load to load 3 | a portion of an array to shared memory and then 4 | change the value in the shared memory and uses TMA's store 5 | to store the portion back to global memory. We print the result 6 | to show the changes are done. 7 | */ 8 | 9 | // supress warning about barrier in shared memory on line 32 10 | #pragma nv_diag_suppress static_var_with_dynamic_init 11 | 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | 19 | using barrier = cuda::barrier; 20 | namespace cde = cuda::device::experimental; 21 | 22 | const int array_size = 128; 23 | const int tile_size = 16; 24 | 25 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map, 26 | int coordinate) { 27 | // Shared memory buffers for tile. The destination shared memory buffer of 28 | // a bulk operations should be 16 byte aligned. 29 | __shared__ alignas(16) int tile_shared[tile_size]; 30 | 31 | // 4. change the value in shared memory 32 | for (int i = threadIdx.x; i < array_size; i += blockDim.x) { 33 | if (i < tile_size) { 34 | tile_shared[i] = 3; 35 | } 36 | } 37 | 38 | // 5. Wait for shared memory writes to be visible to TMA engine. 39 | cde::fence_proxy_async_shared_cta(); 40 | __syncthreads(); 41 | // After syncthreads, writes by all threads are visible to TMA engine. 42 | 43 | // 6. Initiate TMA transfer to copy shared memory to global memory 44 | if (threadIdx.x == 0) { 45 | // .add, .min, .max, .inc, .dec, .and, .or, .xor 46 | // in here we use .add to demonstrate 47 | // so this instruction will do element wise addition 48 | // with thet tile in the global memory and shared memory and store the 49 | // result back to global memory tile 50 | asm volatile("cp.reduce.async.bulk.tensor.1d.global.shared::cta.add." 51 | "tile.bulk_group " 52 | "[%0, {%1}], [%2];\n" 53 | : 54 | : "l"(&tensor_map), "r"(coordinate), 55 | "r"(static_cast<_CUDA_VSTD::uint32_t>( 56 | __cvta_generic_to_shared(tile_shared))) 57 | : "memory"); 58 | // 7. Wait for TMA transfer to have finished reading shared memory. 59 | // Create a "bulk async-group" out of the previous bulk copy operation. 60 | cde::cp_async_bulk_commit_group(); 61 | // Wait for the group to have completed reading from shared memory. 
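// (The template argument of cp_async_bulk_wait_group_read<N> is the number of
// bulk async-groups allowed to remain pending; <0> therefore blocks until every
// committed group, including the cp.reduce above, has finished reading its
// shared-memory source.)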
62 | cde::cp_async_bulk_wait_group_read<0>(); 63 | } 64 | 65 | __threadfence(); 66 | __syncthreads(); 67 | } 68 | 69 | int main() { 70 | // initialize array and fill it with values 71 | int h_data[array_size]; 72 | for (size_t i = 0; i < array_size; ++i) { 73 | h_data[i] = 2; 74 | } 75 | 76 | // print the array before the kernel 77 | // one tile per line 78 | print_matrix(h_data, array_size / tile_size, tile_size); 79 | 80 | // transfer array to device 81 | int *d_data = nullptr; 82 | cudaMalloc(&d_data, array_size * sizeof(int)); 83 | cudaMemcpy(d_data, h_data, array_size * sizeof(int), 84 | cudaMemcpyHostToDevice); 85 | 86 | // create tensor map 87 | CUtensorMap tensor_map = 88 | create_1d_tensor_map(array_size, tile_size, d_data); 89 | 90 | size_t offset = 91 | tile_size * 3; // select the second tile of the array to change 92 | kernel<<<1, 128>>>(tensor_map, offset); 93 | 94 | cuda_check_error(); 95 | 96 | cudaMemcpy(h_data, d_data, array_size * sizeof(int), 97 | cudaMemcpyDeviceToHost); 98 | cudaFree(d_data); 99 | 100 | // print the array after the kernel 101 | print_matrix(h_data, array_size / tile_size, tile_size); 102 | 103 | return 0; 104 | } 105 | -------------------------------------------------------------------------------- /examples/4_tma_1d.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code uses TMA's 1d tensor load to load 3 | a portion of an array to shared memory and then 4 | change the value in the shared memory and uses TMA's store 5 | to store the portion back to global memory. We print the result 6 | to show the changes are done. 7 | */ 8 | 9 | // supress warning about barrier in shared memory on line 32 10 | #pragma nv_diag_suppress static_var_with_dynamic_init 11 | 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | 19 | using barrier = cuda::barrier; 20 | namespace cde = cuda::device::experimental; 21 | 22 | const int array_size = 128; 23 | const int tile_size = 16; 24 | 25 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map, 26 | int coordinate) { 27 | // Shared memory buffers for tile. The destination shared memory buffer of 28 | // a bulk operations should be 16 byte aligned. 29 | __shared__ alignas(16) int tile_shared[tile_size]; 30 | 31 | // 1. a) Initialize shared memory barrier with the number of threads 32 | // participating in the barrier. 33 | // b) Make initialized barrier visible in async proxy. 34 | __shared__ barrier bar; 35 | if (threadIdx.x == 0) { 36 | init(&bar, blockDim.x); // a) 37 | cde::fence_proxy_async_shared_cta(); // b) 38 | } 39 | __syncthreads(); 40 | 41 | // 2. Initiate TMA transfer to copy global to shared memory. 42 | barrier::arrival_token token; 43 | if (threadIdx.x == 0) { 44 | cde::cp_async_bulk_tensor_1d_global_to_shared(tile_shared, &tensor_map, 45 | coordinate, bar); 46 | // 3a. Arrive on the barrier and tell how many bytes are expected to 47 | // come in (the transaction count) 48 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(tile_shared)); 49 | } else { 50 | // 3b. Rest of threads just arrive 51 | token = bar.arrive(); 52 | } 53 | 54 | // 3c. Wait for the data to have arrived. 55 | bar.wait(std::move(token)); 56 | 57 | // 4. change the value in shared memory 58 | for (int i = threadIdx.x; i < array_size; i += blockDim.x) { 59 | if (i < tile_size) { 60 | tile_shared[i] += 1; 61 | } 62 | } 63 | 64 | // 5. 
Wait for shared memory writes to be visible to TMA engine. 65 | cde::fence_proxy_async_shared_cta(); 66 | __syncthreads(); 67 | // After syncthreads, writes by all threads are visible to TMA engine. 68 | 69 | // 6. Initiate TMA transfer to copy shared memory to global memory 70 | if (threadIdx.x == 0) { 71 | cde::cp_async_bulk_tensor_1d_shared_to_global(&tensor_map, coordinate, 72 | tile_shared); 73 | // 7. Wait for TMA transfer to have finished reading shared memory. 74 | // Create a "bulk async-group" out of the previous bulk copy operation. 75 | cde::cp_async_bulk_commit_group(); 76 | // Wait for the group to have completed reading from shared memory. 77 | cde::cp_async_bulk_wait_group_read<0>(); 78 | } 79 | 80 | __threadfence(); 81 | __syncthreads(); 82 | } 83 | 84 | int main() { 85 | // initialize array and fill it with values 86 | int h_data[array_size]; 87 | for (size_t i = 0; i < array_size; ++i) { 88 | h_data[i] = i; 89 | } 90 | 91 | // print the array before the kernel 92 | // one tile per line 93 | print_matrix(h_data, array_size / tile_size, tile_size); 94 | 95 | // transfer array to device 96 | int *d_data = nullptr; 97 | cudaMalloc(&d_data, array_size * sizeof(int)); 98 | cudaMemcpy(d_data, h_data, array_size * sizeof(int), 99 | cudaMemcpyHostToDevice); 100 | 101 | // create tensor map 102 | CUtensorMap tensor_map = 103 | create_1d_tensor_map(array_size, tile_size, d_data); 104 | 105 | size_t offset = 106 | tile_size * 3; // select the second tile of the array to change 107 | kernel<<<1, 128>>>(tensor_map, offset); 108 | 109 | cuda_check_error(); 110 | 111 | cudaMemcpy(h_data, d_data, array_size * sizeof(int), 112 | cudaMemcpyDeviceToHost); 113 | cudaFree(d_data); 114 | 115 | // print the array after the kernel 116 | print_matrix(h_data, array_size / tile_size, tile_size); 117 | 118 | return 0; 119 | } 120 | -------------------------------------------------------------------------------- /examples/2_wgmma_dense.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "matrix_utilities.cuh" 15 | #include "profile_utilities.cuh" 16 | #include "wgmma.cuh" 17 | 18 | const int M = 64; 19 | const int N = 8; 20 | const int K = 16; 21 | 22 | const int threads_per_block = 32 * 4; // 4 warps 23 | const int blocks = 1; 24 | 25 | __global__ void kernel(half *A, half *B, half *C) { 26 | // metadata 27 | const int tid = threadIdx.x; 28 | const int warp_id = tid / 32; 29 | const int lane_id = tid % 32; 30 | const int group_id = lane_id >> 2; 31 | const int lane_in_group = lane_id & 3; 32 | 33 | __syncthreads(); 34 | 35 | __align__(16) __shared__ half A_shared[M * K]; 36 | __align__(16) __shared__ half B_shared[K * N]; 37 | 38 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async 39 | // 8x8 core blocks, we use one thread here to 40 | // easy demonstrate the required layout 41 | if (tid == 0) { 42 | for (int i = 0; i < M; i++) { 43 | for (int j = 0; j < K; j++) { 44 | int block_x = i / 8; 45 | int block_row = i % 8; 46 | int block_y = j / 8; 47 | int block_col = j % 8; 48 | int block_id = block_x * 2 + block_y; 49 | int offset = block_id * 64 + block_row * 8 + block_col; 50 | A_shared[offset] = A[i * K + j]; 51 | } 52 | } 53 | 54 | for (int i = 0; i < 
K; i++) { 55 | for (int j = 0; j < N; j++) { 56 | int block_x = i / 8; 57 | int block_row = i % 8; 58 | int block_y = j / 8; 59 | int block_col = j % 8; 60 | int block_id = block_x * 1 + block_y; 61 | int offset = block_id * 64 + block_row * 8 + block_col; 62 | B_shared[offset] = B[i * N + j]; 63 | } 64 | } 65 | } 66 | 67 | __syncthreads(); 68 | 69 | // create descriptors for the matrices 70 | GmmaDescriptor desc_a = make_desc_a(A_shared); 71 | GmmaDescriptor desc_b = make_desc_b(B_shared); 72 | 73 | // accumulator 74 | uint32_t c[2] = {}; 75 | 76 | // called whenever the accumulator is accessed 77 | warpgroup_arrive(); 78 | 79 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 80 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 81 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 82 | // imm-scale-a, imme-scale-b, imm-trans-b; 83 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 84 | "{%0, %1}, " // accumulator 85 | "%2, %3, " // matrix a descriptor 86 | "1, " // 0 => D = A*B, 1 => D = D + A*B 87 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 88 | // -1 to a or b 89 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 90 | : "+r"(c[0]), "+r"(c[1]) 91 | : "l"(desc_a), "l"(desc_b)); 92 | 93 | // commit, start the computation 94 | warpgroup_commit_batch(); 95 | 96 | // wait for the previous commit to finish 97 | warpgroup_wait<0>(); 98 | 99 | // thread fence needed for async operations 100 | __threadfence(); 101 | 102 | warpgroup_arrive(); 103 | 104 | uint32_t *C_ptr = reinterpret_cast(C); 105 | 106 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 107 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 108 | 109 | // write back to global memory 110 | C_ptr[offset1] = c[0]; 111 | C_ptr[offset2] = c[1]; 112 | } 113 | 114 | int main() { 115 | 116 | half *d_C; 117 | half h_C[M * N]; 118 | half h_CPU[M * N]; 119 | half h_A[M * K]; 120 | half h_B[K * N]; 121 | 122 | fill_fixed(h_C, M, N, 0); 123 | 124 | fill_random(h_A, M, K); 125 | fill_random(h_B, K, N); 126 | 127 | half *d_A, *d_B; 128 | 129 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 130 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 131 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 132 | 133 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 134 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 135 | 136 | kernel<<>>(d_A, d_B, d_C); 137 | 138 | cuda_check_error(); 139 | 140 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 141 | 142 | // print_matrix(h_C, M, N); 143 | 144 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 145 | 146 | compare_matrices(h_CPU, h_C, M, N); 147 | 148 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 149 | 150 | return 0; 151 | } 152 | -------------------------------------------------------------------------------- /headers/device/wgmma.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | __device__ void warpgroup_fence_operand(uint32_t ®) { 5 | asm volatile("" : "+r"(reg)::"memory"); 6 | } 7 | 8 | __device__ void warpgroup_arrive() { 9 | asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); 10 | } 11 | 12 | __device__ void warpgroup_commit_batch() { 13 | asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); 14 | } 15 | 16 | template 17 | __device__ void warpgroup_wait() { 18 | static_assert(N >= 0 && N <= 7, "WGMMA wait: N 
must be in range [0, 7]"); 19 | asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); 20 | } 21 | 22 | // wgmma tensor descriptor 23 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor 24 | // taken from https://github.com/KnowingNothing/MatmulTutorial/blob/18366a51005c3b3395449d5eb5da02ec56198b65/examples/atom/single-wgmma-f8.cu#L169 25 | union GmmaDescriptor 26 | { 27 | __device__ constexpr GmmaDescriptor() noexcept : desc_(0) {} 28 | __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} 29 | __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept 30 | : desc_(t.desc_) {} 31 | __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept 32 | : desc_(t.desc_) {} 33 | 34 | __device__ constexpr GmmaDescriptor & 35 | operator=(GmmaDescriptor const &t) noexcept 36 | { 37 | desc_ = t.desc_; 38 | return *this; 39 | } 40 | 41 | __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept 42 | { 43 | desc_ = t.desc_; 44 | return *this; 45 | } 46 | 47 | uint64_t desc_; 48 | uint32_t reg32_[2]; 49 | uint16_t reg16_[4]; 50 | 51 | // Bitfield implementation avoids the need for shifts in assignment 52 | struct 53 | { 54 | // start_address, bit [0,14), 4LSB not included 55 | uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused 56 | // leading dimension byte offset, bit [16,30), 4LSB not included 57 | // For N: This is the stride from the first col to the second col of the 8x2 58 | // brick in INTERLEAVED 59 | // Unused for all SWIZZLE_* layouts (and assumed to be 1) 60 | // For T: This is the stride from the first 8 rows to the next 8 rows. 61 | uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused 62 | // stride dimension byte offset, bit [32,46), 4LSB not included 63 | // For N: This is the stride from the first 8 rows to the next 8 rows. 64 | // For T: This is the stride fro mthe first 8 cols to the next 8 cols. 
65 | uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused 66 | // base_offset, bit [49,52) 67 | // Valid only for SWIZZLE_128B and SWIZZLE_64B 68 | uint8_t : 1, 69 | base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused 70 | // layout type, bit [62,64) 71 | // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 72 | uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) 73 | } bitfield; 74 | 75 | // Decay to a uint64_t 76 | __device__ constexpr operator uint64_t() const noexcept { return desc_; } 77 | 78 | // Printer 79 | // __device__ friend void print(GmmaDescriptor const& t) 80 | // { 81 | // #if !defined(__CUDACC_RTC__) 82 | // printf("GmmaDescriptor: 0x%016 %lli\n", static_cast(t.desc_)); printf(" start_addr : 0x%04x\n", 84 | // t.bitfield.start_address_); printf(" leading_off: 0x%04x (%d)\n", 85 | // t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); 86 | // printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, 87 | // t.bitfield.stride_byte_offset_); printf(" base_offset: 0x%01x\n", 88 | // t.bitfield.base_offset_); printf(" layout_type: 0x%01x (%s)\n", 89 | // t.bitfield.layout_type_, 90 | // to_string(static_cast(t.bitfield.layout_type_))); 91 | // #endif 92 | // } 93 | }; 94 | 95 | template 96 | __device__ GmmaDescriptor make_desc(PointerType smem_ptr) 97 | { 98 | GmmaDescriptor desc; 99 | uint32_t uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 100 | desc.bitfield.start_address_ = uint_ptr >> 4; 101 | desc.bitfield.layout_type_ = swizzle; // no swizzle 102 | desc.bitfield.leading_byte_offset_ = leading_offset; // 16 bytes 103 | desc.bitfield.stride_byte_offset_ = stride_offset; // 8 bytes 104 | /// base_offset_ is not valid for non-swizzle 105 | desc.bitfield.base_offset_ = 0; 106 | return desc; 107 | } -------------------------------------------------------------------------------- /headers/host/tma_tensor_map.cuh: -------------------------------------------------------------------------------- 1 | // apis for host code to initailize tensor map for tma apis 2 | 3 | #include // PFN_cuTensorMapEncodeTiled, CUtensorMap 4 | #include 5 | #include 6 | 7 | PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() 8 | { 9 | void *driver_ptr = nullptr; 10 | cudaDriverEntryPointQueryResult driver_status; 11 | auto code = cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status); 12 | assert(code == cudaSuccess && "Could not get driver API"); 13 | return reinterpret_cast(driver_ptr); 14 | } 15 | 16 | // create a 1d tensor map 17 | CUtensorMap create_1d_tensor_map(uint64_t tensor_dim, uint32_t tile_dim, void *tensor_ptr) 18 | { 19 | // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html 20 | CUtensorMap local_tensor_map{}; 21 | // rank is the number of dimensions of the array. 22 | constexpr uint32_t rank = 1; 23 | uint64_t size[rank] = {tensor_dim}; 24 | // The stride is the number of bytes to traverse from the first element of one row to the next. 25 | // It must be a multiple of 16. 26 | uint64_t stride[rank] = {tensor_dim * sizeof(int)}; 27 | // The box_size is the size of the shared memory buffer that is used as the 28 | // destination of a TMA transfer. 29 | uint32_t box_size[rank] = {tile_dim}; 30 | // The distance between elements in units of sizeof(element). A stride of 2 31 | // can be used to load only the real component of a complex-valued tensor, for instance. 
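// Here every element of the boxed tile is transferred, so the element stride is 1.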
32 | uint32_t elem_stride[rank] = {1}; 33 | 34 | // Get a function pointer to the cuTensorMapEncodeTiled driver API. 35 | auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); 36 | 37 | // Create the tensor descriptor. 38 | CUresult res = cuTensorMapEncodeTiled( 39 | &local_tensor_map, // CUtensorMap *tensorMap, 40 | CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT32, 41 | rank, // cuuint32_t tensorRank, 42 | tensor_ptr, // void *globalAddress, 43 | size, // const cuuint64_t *globalDim, 44 | stride, // const cuuint64_t *globalStrides, 45 | box_size, // const cuuint32_t *boxDim, 46 | elem_stride, // const cuuint32_t *elementStrides, 47 | CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, 48 | CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, 49 | CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, 50 | CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 51 | 52 | assert(res == CUDA_SUCCESS && "tensormap creation failed."); 53 | 54 | return local_tensor_map; 55 | } 56 | 57 | 58 | // create a 2d tensor map 59 | // for a matrix, row number is tensor_dim1, column number is tensor_dim2 60 | // assuming row major 61 | template 62 | CUtensorMap create_2d_tensor_map(uint64_t tensor_dim1, uint64_t tensor_dim2, uint32_t tile_dim1, uint32_t tile_dim2, void *tensor_ptr) 63 | { 64 | // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html 65 | CUtensorMap local_tensor_map{}; 66 | // rank is the number of dimensions of the array. 67 | constexpr uint32_t rank = 2; 68 | uint64_t size[rank] = {tensor_dim2, tensor_dim1}; 69 | // The stride is the number of bytes to traverse from the first element of one row to the next. 70 | // It must be a multiple of 16. 71 | uint64_t stride[rank - 1] = {tensor_dim2 * sizeof(T)}; 72 | // The box_size is the size of the shared memory buffer that is used as the 73 | // destination of a TMA transfer. 74 | uint32_t box_size[rank] = {tile_dim2, tile_dim1}; 75 | // The distance between elements in units of sizeof(element). A stride of 2 76 | // can be used to load only the real component of a complex-valued tensor, for instance. 77 | uint32_t elem_stride[rank] = {1, 1}; 78 | 79 | // Get a function pointer to the cuTensorMapEncodeTiled driver API. 80 | auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); 81 | 82 | // Create the tensor descriptor. 
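// (Typical use in this repo, as a sketch: the dense kernels describe a whole
// row-major operand as a single box, e.g. create_2d_tensor_map(M, K, M, K, d_A)
// for an M x K matrix, while the half-precision examples go through the
// create_2d_tensor_map_half<> wrapper, whose template argument appears to
// select the swizzle mode.)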
83 | CUresult res = cuTensorMapEncodeTiled( 84 | &local_tensor_map, // CUtensorMap *tensorMap, 85 | type, 86 | rank, // cuuint32_t tensorRank, 87 | tensor_ptr, // void *globalAddress, 88 | size, // const cuuint64_t *globalDim, 89 | stride, // const cuuint64_t *globalStrides, 90 | box_size, // const cuuint32_t *boxDim, 91 | elem_stride, // const cuuint32_t *elementStrides, 92 | CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, 93 | swizzle, 94 | CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, 95 | CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); 96 | 97 | assert(res == CUDA_SUCCESS && "tensormap creation failed."); 98 | 99 | return local_tensor_map; 100 | } -------------------------------------------------------------------------------- /examples/10_tma_and_wgmma.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 8; 28 | const int K = 16; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | // Load A 57 | uint64_t token; 58 | if (tid == 0) { 59 | // call the loading api 60 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map, 0, 61 | 0, bar); 62 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 63 | 0, 0, bar); 64 | token = cuda::device::barrier_arrive_tx( 65 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 66 | } else { 67 | token = bar.arrive(); 68 | } 69 | 70 | bar.wait(cuda::std::move(token)); 71 | 72 | __syncthreads(); 73 | 74 | // create descriptors for the matrices 75 | GmmaDescriptor desc_a = make_desc_a(A_shared); 76 | GmmaDescriptor desc_b = make_desc_b(B_shared); 77 | 78 | // accumulator 79 | uint32_t c[2] = {}; 80 | 81 | // called whenever the accumulator is accessed 82 | warpgroup_arrive(); 83 | 84 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 85 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 86 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 87 | // imm-scale-a, imme-scale-b, imm-trans-b; 88 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 89 | "{%0, %1}, " // accumulator 90 | "%2, %3, " // matrix a descriptor 91 | "1, " // 0 => D = A*B, 1 => D = D + A*B 92 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means 
times 93 | // -1 to a or b 94 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 95 | : "+r"(c[0]), "+r"(c[1]) 96 | : "l"(desc_a), "l"(desc_b)); 97 | 98 | // commit, start the computation 99 | warpgroup_commit_batch(); 100 | 101 | // wait for the previous commit to finish 102 | warpgroup_wait<0>(); 103 | 104 | // thread fence needed for async operations 105 | __threadfence(); 106 | 107 | warpgroup_arrive(); 108 | 109 | uint32_t *C_ptr = reinterpret_cast(C); 110 | 111 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 112 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 113 | 114 | // write back to global memory 115 | C_ptr[offset1] = c[0]; 116 | C_ptr[offset2] = c[1]; 117 | } 118 | 119 | int main() { 120 | 121 | half *d_C; 122 | half h_C[M * N]; 123 | half h_CPU[M * N]; 124 | half h_A[M * K]; 125 | half h_B[K * N]; 126 | 127 | fill_fixed(h_C, M, N, 0); 128 | 129 | fill_random(h_A, M, K); 130 | fill_random(h_B, K, N); 131 | 132 | half *d_A, *d_B; 133 | 134 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 135 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 136 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 137 | 138 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 139 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 140 | 141 | CUtensorMap tensor_map = create_2d_tensor_map_half<1>(M, K, M, K, d_A); 142 | CUtensorMap tensor_map_b = create_2d_tensor_map_half<0>(K, N, K, N, d_B); 143 | 144 | kernel<<>>(tensor_map, tensor_map_b, d_C); 145 | 146 | cuda_check_error(); 147 | 148 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 149 | 150 | // print_matrix(h_C, M, N); 151 | 152 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 153 | 154 | compare_matrices(h_CPU, h_C, M, N); 155 | 156 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 157 | 158 | return 0; 159 | } 160 | -------------------------------------------------------------------------------- /examples/test.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 16; 28 | const int K = 16; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | // Load A 57 | uint64_t token; 58 | if (tid == 0) { 59 | // call the loading api 60 | 
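// Both bulk copies below are tracked by the same mbarrier, so the expected
// transaction byte count handed to barrier_arrive_tx afterwards is the sum
// sizeof(A_shared) + sizeof(B_shared).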
cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map, 0, 61 | 0, bar); 62 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 63 | 0, 0, bar); 64 | token = cuda::device::barrier_arrive_tx( 65 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 66 | } else { 67 | token = bar.arrive(); 68 | } 69 | 70 | bar.wait(cuda::std::move(token)); 71 | 72 | __syncthreads(); 73 | 74 | // create descriptors for the matrices 75 | GmmaDescriptor desc_a = make_desc_a(A_shared); 76 | GmmaDescriptor desc_b = make_desc_b(B_shared); 77 | 78 | // accumulator 79 | uint32_t c[4] = {}; 80 | 81 | // called whenever the accumulator is accessed 82 | warpgroup_arrive(); 83 | 84 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 85 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 86 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 87 | // imm-scale-a, imme-scale-b, imm-trans-b; 88 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 89 | "{%0, %1, %2, %3}, " // accumulator 90 | "%4, %5, " // matrix a descriptor 91 | "1, " // 0 => D = A*B, 1 => D = D + A*B 92 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 93 | // -1 to a or b 94 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 95 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 96 | : "l"(desc_a), "l"(desc_b)); 97 | 98 | // commit, start the computation 99 | warpgroup_commit_batch(); 100 | 101 | // wait for the previous commit to finish 102 | warpgroup_wait<0>(); 103 | 104 | // thread fence needed for async operations 105 | __threadfence(); 106 | 107 | warpgroup_arrive(); 108 | 109 | uint32_t *C_ptr = reinterpret_cast(C); 110 | 111 | int offset1 = warp_id * 16 * 8 + group_id * 8 + lane_in_group; 112 | int offset2 = warp_id * 16 * 8 + (group_id + 8) * 8 + lane_in_group; 113 | 114 | // write back to global memory 115 | C_ptr[offset1] = c[0]; 116 | C_ptr[offset2] = c[1]; 117 | C_ptr[offset1 + 4] = c[2]; 118 | C_ptr[offset2 + 4] = c[3]; 119 | } 120 | 121 | int main() { 122 | 123 | half *d_C; 124 | half h_C[M * N]; 125 | half h_CPU[M * N]; 126 | half h_A[M * K]; 127 | half h_B[K * N]; 128 | 129 | fill_fixed(h_C, M, N, 0); 130 | 131 | fill_random(h_A, M, K); 132 | // fill_tilewise(h_A, M, K, 8, 8); 133 | // fill_fixed(h_B, K, N, 1); 134 | fill_random(h_B, K, N); 135 | 136 | half *d_A, *d_B; 137 | 138 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 139 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 140 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 141 | 142 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 143 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 144 | 145 | CUtensorMap tensor_map = create_2d_tensor_map_half<1>(M, K, M, K, d_A); 146 | CUtensorMap tensor_map_b = create_2d_tensor_map_half<1>(K, N, K, N, d_B); 147 | 148 | kernel<<>>(tensor_map, tensor_map_b, d_C); 149 | 150 | cuda_check_error(); 151 | 152 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 153 | 154 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 155 | 156 | print_differnce(h_C, h_CPU, M, N, 0.0f); 157 | 158 | print_matrix(h_C, M, N); 159 | 160 | compare_matrices(h_CPU, h_C, M, N); 161 | 162 | return 0; 163 | } 164 | -------------------------------------------------------------------------------- /examples/6_multicast.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code uses TMA's 1d tensor load to load 3 | a portion of an array to shared 
memory and then 4 | change the value in the shared memory and uses TMA's store 5 | to store the portion back to global memory. We print the result 6 | to show the changes are done. 7 | */ 8 | 9 | // supress warning about barrier in shared memory on line 32 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma.cuh" 18 | #include "tma_tensor_map.cuh" 19 | 20 | using barrier = cuda::barrier; 21 | namespace cde = cuda::device::experimental; 22 | 23 | namespace cg = cooperative_groups; 24 | 25 | const int array_size = 128; 26 | const int tile_size = 16; 27 | const int cluster_size = 4; // we use 4 blocks in a cluster 28 | 29 | __global__ void __cluster_dims__(cluster_size, 1, 1) 30 | kernel(const __grid_constant__ CUtensorMap tensor_map, int coordinate, 31 | int *result) { 32 | // cluster metadata 33 | cg::cluster_group cluster = cg::this_cluster(); 34 | unsigned int clusterBlockRank = cluster.block_rank(); 35 | 36 | __shared__ alignas(16) int tile_shared[tile_size]; 37 | 38 | // we let the first block in the cluster to load a 39 | // tile to the shared memory of all 4 blocks 40 | if (clusterBlockRank == 0) { 41 | __shared__ barrier bar; 42 | 43 | if (threadIdx.x == 0) { 44 | init(&bar, blockDim.x); 45 | cde::fence_proxy_async_shared_cta(); 46 | } 47 | __syncthreads(); 48 | 49 | barrier::arrival_token token; 50 | if (threadIdx.x == 0) { 51 | /* 52 | each bit represents a block in the cluster, starting from the least 53 | significant bit (the right side) 54 | 55 | here we use block mask 1011, which means 56 | blocks 0, 1, and 3 will recieve the data from multicast 57 | whereas block 2 will not 58 | 59 | we will verify this by printing the result 60 | */ 61 | uint16_t ctaMask = 0b1011; 62 | asm volatile( 63 | "cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::" 64 | "complete_tx::bytes.multicast::cluster " 65 | "[%0], [%1, {%2}], [%3], %4;\n" 66 | : 67 | : "r"(static_cast<_CUDA_VSTD::uint32_t>( 68 | __cvta_generic_to_shared(tile_shared))), 69 | "l"(&tensor_map), "r"(coordinate), 70 | "r"(static_cast<_CUDA_VSTD::uint32_t>( 71 | __cvta_generic_to_shared( 72 | ::cuda::device::barrier_native_handle(bar)))), 73 | "h"(ctaMask) 74 | : "memory"); 75 | 76 | token = 77 | cuda::device::barrier_arrive_tx(bar, 1, sizeof(tile_shared)); 78 | } else { 79 | token = bar.arrive(); 80 | } 81 | 82 | bar.wait(std::move(token)); 83 | } 84 | 85 | // rest of the clusters needs to wait for cluster 0 to load the data 86 | cluster.sync(); 87 | 88 | // put the results back 89 | if (clusterBlockRank == 0 && threadIdx.x == 0) { 90 | for (int i = 0; i < tile_size; ++i) { 91 | result[clusterBlockRank * tile_size + i] = tile_shared[i]; 92 | } 93 | } 94 | 95 | if (clusterBlockRank == 1 && threadIdx.x == 0) { 96 | for (int i = 0; i < tile_size; ++i) { 97 | result[clusterBlockRank * tile_size + i] = tile_shared[i]; 98 | } 99 | } 100 | 101 | if (clusterBlockRank == 2 && threadIdx.x == 0) { 102 | for (int i = 0; i < tile_size; ++i) { 103 | result[clusterBlockRank * tile_size + i] = tile_shared[i]; 104 | } 105 | } 106 | 107 | if (clusterBlockRank == 3 && threadIdx.x == 0) { 108 | for (int i = 0; i < tile_size; ++i) { 109 | result[clusterBlockRank * tile_size + i] = tile_shared[i]; 110 | } 111 | } 112 | } 113 | 114 | int main() { 115 | // initialize array and fill it with values 116 | int h_data[array_size]; 117 | for (size_t i = 0; i < array_size; ++i) { 118 | h_data[i] = i; 119 | } 120 | 121 | // print the array before the kernel 
122 | // one tile per line 123 | print_matrix(h_data, array_size / tile_size, tile_size); 124 | 125 | // transfer array to device 126 | int *d_data = nullptr; 127 | cudaMalloc(&d_data, array_size * sizeof(int)); 128 | cudaMemcpy(d_data, h_data, array_size * sizeof(int), 129 | cudaMemcpyHostToDevice); 130 | 131 | // create tensor map 132 | CUtensorMap tensor_map = 133 | create_1d_tensor_map(array_size, tile_size, d_data); 134 | 135 | // a 2d array that will be used to store the tile loaded to each block 136 | int *d_result = nullptr; 137 | cudaMalloc(&d_result, tile_size * cluster_size * sizeof(int)); 138 | 139 | size_t offset = 140 | tile_size * 3; // select the second tile of the array to change 141 | kernel<<>>(tensor_map, offset, d_result); 142 | 143 | cuda_check_error(); 144 | 145 | // transfer the result back to host 146 | int h_result[tile_size * cluster_size]; 147 | cudaMemcpy(h_result, d_result, tile_size * cluster_size * sizeof(int), 148 | cudaMemcpyDeviceToHost); 149 | 150 | // print the result for each block 151 | print_matrix(h_result, cluster_size, tile_size); 152 | 153 | cudaFree(d_data); 154 | 155 | return 0; 156 | } 157 | -------------------------------------------------------------------------------- /examples/3_wgmma_sparse.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "matrix_utilities.cuh" 17 | #include "profile_utilities.cuh" 18 | #include "wgmma.cuh" 19 | 20 | const int M = 64; 21 | const int N = 8; 22 | const int K = 32; 23 | 24 | // 2:4 format 25 | const int K2 = 16; 26 | 27 | __global__ void kernel(half *A, half *B, half *C, u_int32_t *metadata_array) { 28 | const int tid = threadIdx.x; 29 | const int warp_id = tid / 32; 30 | const int lane_id = tid % 32; 31 | const int group_id = lane_id >> 2; 32 | const int lane_in_group = lane_id & 3; 33 | const int lane_in_work_group = lane_in_group % 2; 34 | 35 | __align__(16) __shared__ half A_shared[M * K2]; 36 | __align__(16) __shared__ half B_shared[K * N]; 37 | 38 | // use one thread to load so it's easier to tell the layout 39 | // refer to the ptx menu for the layout of the shared memory 40 | if (tid == 0) { 41 | for (int i = 0; i < M; i++) { 42 | for (int j = 0; j < K2; j++) { 43 | int block_x = i / 8; 44 | int block_row = i % 8; 45 | int block_y = j / 8; 46 | int block_col = j % 8; 47 | int block_id = block_x * 2 + block_y; 48 | int offset = block_id * 64 + block_row * 8 + block_col; 49 | A_shared[offset] = A[i * K2 + j]; 50 | } 51 | } 52 | 53 | for (int i = 0; i < K; i++) { 54 | for (int j = 0; j < N; j++) { 55 | int block_x = i / 8; 56 | int block_row = i % 8; 57 | int block_y = j / 8; 58 | int block_col = j % 8; 59 | int block_id = block_x * 1 + block_y; 60 | int offset = block_id * 64 + block_row * 8 + block_col; 61 | B_shared[offset] = B[i * N + j]; 62 | } 63 | } 64 | } 65 | 66 | __syncthreads(); 67 | 68 | // load metadata 69 | u_int32_t metadata; 70 | uint metadata_offset = warp_id * 16 + lane_in_work_group * 8 + group_id; 71 | metadata = metadata_array[metadata_offset]; 72 | 73 | __syncthreads(); 74 | 75 | // create descriptors 76 | GmmaDescriptor desc_a = make_desc_a(A_shared); 77 | GmmaDescriptor desc_b = make_desc_b(B_shared); 78 | 79 | // accumulator 80 | uint32_t c[2] = {}; 81 | 82 | 
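// Sketch of the metadata encoding for the f16 2:4 sparse wgmma (see the PTX
// ISA for the authoritative layout): each 32-bit metadata word packs sixteen
// 2-bit indices, and every consecutive pair of indices records which 2 of 4
// adjacent K-elements of A were kept in the compressed operand; the
// "thread selection" (sparsity selector) operand below chooses which threads
// contribute their metadata word.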
warpgroup_arrive(); 83 | 84 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " 85 | "{%0, %1}, " // c 86 | "%2, %3, " // desc A, B 87 | "%4, " // meta 88 | "0, " // thread selection 89 | "1, " // scale D 90 | "%7, %8, " // +/- scale A, B 91 | "%9, %10;" // transpose A, B 92 | : "+r"(c[0]), "+r"(c[1]) 93 | : "l"(desc_a), "l"(desc_b), 94 | "r"(metadata), // metadata 95 | "r"(0), // thread selection 96 | "r"(1), // scale D 97 | "n"(1), "n"(1), // +- scale A, B 98 | "n"(0), "n"(1)); // transpose A, B 99 | 100 | // commit, start the computation 101 | warpgroup_commit_batch(); 102 | 103 | // wait for the previous commit to finish 104 | warpgroup_wait<0>(); 105 | 106 | // thread fence needed for async operations 107 | __threadfence(); 108 | 109 | warpgroup_arrive(); 110 | 111 | // store the result 112 | uint32_t *C_ptr = reinterpret_cast(C); 113 | 114 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 115 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 116 | 117 | C_ptr[offset1] = c[0]; 118 | C_ptr[offset2] = c[1]; 119 | } 120 | 121 | int main() { 122 | 123 | half *d_C; 124 | half h_C[M * N]; 125 | half h_CPU[M * N]; 126 | half h_A[M * K]; 127 | half h_A2[M * K2]; 128 | half h_B[K * N]; 129 | 130 | fill_24(h_A, M, K); 131 | fill_random(h_B, K, N); 132 | 133 | // print_matrix(h_A, M, K); 134 | 135 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 136 | compress24(h_A, h_A2, M, K); 137 | 138 | // print_matrix(h_A2, M, K2); 139 | 140 | half *d_A, *d_B; 141 | 142 | cudaMalloc((void **)&d_A, M * K2 * sizeof(half)); 143 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 144 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 145 | 146 | cudaMemcpy(d_A, h_A2, M * K2 * sizeof(half), cudaMemcpyHostToDevice); 147 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 148 | 149 | int metadata_size = (M / 16) * (K / 16) * 8; 150 | 151 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 152 | inspect_metadata(h_A, metadata_array, M, K); 153 | 154 | u_int32_t *d_metadata; 155 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 156 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 157 | cudaMemcpyHostToDevice); 158 | 159 | kernel<<<1, 128>>>(d_A, d_B, d_C, d_metadata); 160 | 161 | cuda_check_error(); 162 | 163 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 164 | 165 | // print_matrix(h_C, M, N); 166 | 167 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 168 | 169 | compare_matrices(h_CPU, h_C, M, N); 170 | 171 | // print_differnce(h_CPU, h_C, M, N, 0); 172 | 173 | return 0; 174 | } 175 | -------------------------------------------------------------------------------- /dense/2_m64_n16_k16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 16; 28 | const int K = 16; 29 | 30 | const int threads_per_block = 32 * 
4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map_a, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | // Load A 57 | uint64_t token; 58 | if (tid == 0) { 59 | // call the loading api 60 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 61 | 0, bar); 62 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 63 | 0, 0, bar); 64 | token = cuda::device::barrier_arrive_tx( 65 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 66 | } else { 67 | token = bar.arrive(); 68 | } 69 | 70 | bar.wait(cuda::std::move(token)); 71 | 72 | __syncthreads(); 73 | 74 | // create descriptors for the matrices 75 | GmmaDescriptor desc_a = make_desc(A_shared); 76 | GmmaDescriptor desc_b = make_desc(B_shared); 77 | 78 | // accumulator 79 | uint32_t c[4] = {}; 80 | 81 | // called whenever the accumulator is accessed 82 | warpgroup_arrive(); 83 | 84 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 85 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 86 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 87 | // imm-scale-a, imme-scale-b, imm-trans-b; 88 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 89 | "{%0, %1, %2, %3}, " // accumulator 90 | "%4, %5, " // matrix a descriptor 91 | "1, " // 0 => D = A*B, 1 => D = D + A*B 92 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 93 | // -1 to a or b 94 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 95 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 96 | : "l"(desc_a), "l"(desc_b)); 97 | 98 | // commit, start the computation 99 | warpgroup_commit_batch(); 100 | 101 | // wait for the previous commit to finish 102 | warpgroup_wait<0>(); 103 | 104 | // thread fence needed for async operations 105 | __threadfence(); 106 | 107 | warpgroup_arrive(); 108 | 109 | uint32_t *C_ptr = reinterpret_cast(C); 110 | 111 | int offset1 = warp_id * 16 * 8 + group_id * 8 + lane_in_group; 112 | int offset2 = warp_id * 16 * 8 + (group_id + 8) * 8 + lane_in_group; 113 | 114 | // write back to global memory 115 | C_ptr[offset1] = c[0]; 116 | C_ptr[offset2] = c[1]; 117 | C_ptr[offset1 + 4] = c[2]; 118 | C_ptr[offset2 + 4] = c[3]; 119 | } 120 | 121 | int main() { 122 | 123 | half *d_C; 124 | half h_C[M * N]; 125 | half h_CPU[M * N]; 126 | half h_A[M * K]; 127 | half h_B[K * N]; 128 | 129 | fill_fixed(h_C, M, N, 0); 130 | 131 | fill_random(h_A, M, K); 132 | // fill_tilewise(h_A, M, K, 8, 8); 133 | // fill_fixed(h_B, K, N, 1); 134 | fill_random(h_B, K, N); 135 | 136 | half *d_A, *d_B; 137 | 138 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 139 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 140 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 141 | 142 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 143 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 144 | 145 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K, M, K, 
d_A); 146 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 147 | 148 | kernel<<>>(tensor_map_a, tensor_map_b, d_C); 149 | 150 | cuda_check_error(); 151 | 152 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 153 | 154 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 155 | 156 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 157 | 158 | // print_matrix(h_C, M, N); 159 | 160 | compare_matrices(h_CPU, h_C, M, N); 161 | 162 | return 0; 163 | } -------------------------------------------------------------------------------- /examples/9_swizzle.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 8; 28 | const int K = 16; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map, half *B, 34 | half *C) { 35 | // metadata 36 | const int tid = threadIdx.x; 37 | const int warp_id = tid / 32; 38 | const int lane_id = tid % 32; 39 | const int group_id = lane_id >> 2; 40 | const int lane_in_group = lane_id & 3; 41 | 42 | __syncthreads(); 43 | 44 | __align__(128) __shared__ half A_shared[M * K]; 45 | __align__(16) __shared__ half B_shared[K * N]; 46 | 47 | __shared__ barrier bar; 48 | 49 | if (threadIdx.x == 0) { 50 | init(&bar, blockDim.x); 51 | } 52 | __syncthreads(); 53 | 54 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async 55 | // 8x8 core blocks, we use one thread here to 56 | // easy demonstrate the required layout 57 | if (tid == 0) { 58 | // load B 59 | for (int i = 0; i < K; i++) { 60 | for (int j = 0; j < N; j++) { 61 | int block_x = i / 8; 62 | int block_row = i % 8; 63 | int block_y = j / 8; 64 | int block_col = j % 8; 65 | int block_id = block_x * 1 + block_y; 66 | int offset = block_id * 64 + block_row * 8 + block_col; 67 | B_shared[offset] = B[i * N + j]; 68 | } 69 | } 70 | } 71 | 72 | // Load A 73 | uint64_t token; 74 | if (tid == 0) { 75 | // call the loading api 76 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map, 0, 77 | 0, bar); 78 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared)); 79 | } else { 80 | token = bar.arrive(); 81 | } 82 | 83 | bar.wait(cuda::std::move(token)); 84 | 85 | __syncthreads(); 86 | 87 | // create descriptors for the matrices 88 | GmmaDescriptor desc_a = make_desc_a(A_shared); 89 | GmmaDescriptor desc_b = make_desc_b(B_shared); 90 | 91 | // accumulator 92 | uint32_t c[2] = {}; 93 | 94 | // called whenever the accumulator is accessed 95 | warpgroup_arrive(); 96 | 97 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 98 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 99 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 100 | // imm-scale-a, imme-scale-b, imm-trans-b; 101 | asm 
volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 102 | "{%0, %1}, " // accumulator 103 | "%2, %3, " // matrix a descriptor 104 | "1, " // 0 => D = A*B, 1 => D = D + A*B 105 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 106 | // -1 to a or b 107 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 108 | : "+r"(c[0]), "+r"(c[1]) 109 | : "l"(desc_a), "l"(desc_b)); 110 | 111 | // commit, start the computation 112 | warpgroup_commit_batch(); 113 | 114 | // wait for the previous commit to finish 115 | warpgroup_wait<0>(); 116 | 117 | // thread fence needed for async operations 118 | __threadfence(); 119 | 120 | warpgroup_arrive(); 121 | 122 | uint32_t *C_ptr = reinterpret_cast(C); 123 | 124 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 125 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 126 | 127 | // write back to global memory 128 | C_ptr[offset1] = c[0]; 129 | C_ptr[offset2] = c[1]; 130 | } 131 | 132 | int main() { 133 | 134 | half *d_C; 135 | half h_C[M * N]; 136 | half h_CPU[M * N]; 137 | half h_A[M * K]; 138 | half h_B[K * N]; 139 | 140 | fill_fixed(h_C, M, N, 0); 141 | 142 | fill_random(h_A, M, K); 143 | fill_random(h_B, K, N); 144 | 145 | half *d_A, *d_B; 146 | 147 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 148 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 149 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 150 | 151 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 152 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 153 | 154 | CUtensorMap tensor_map = create_2d_tensor_map_half<1>(M, K, M, K, d_A); 155 | 156 | kernel<<>>(tensor_map, d_B, d_C); 157 | 158 | cuda_check_error(); 159 | 160 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 161 | 162 | // print_matrix(h_C, M, N); 163 | 164 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 165 | 166 | compare_matrices(h_CPU, h_C, M, N); 167 | 168 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 169 | 170 | return 0; 171 | } 172 | -------------------------------------------------------------------------------- /dense/1_m64_n8_k32.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 8; 28 | const int K = 32; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map_a, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 
52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | // Load A 57 | uint64_t token; 58 | if (tid == 0) { 59 | // call the loading api 60 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 61 | 0, bar); 62 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 63 | 0, 0, bar); 64 | token = cuda::device::barrier_arrive_tx( 65 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 66 | } else { 67 | token = bar.arrive(); 68 | } 69 | 70 | bar.wait(cuda::std::move(token)); 71 | 72 | __syncthreads(); 73 | 74 | // create descriptors for the matrices 75 | GmmaDescriptor desc_a = make_desc(A_shared); 76 | GmmaDescriptor desc_b = make_desc(B_shared); 77 | 78 | // accumulator 79 | uint32_t c[2] = {}; 80 | 81 | // called whenever the accumulator is accessed 82 | warpgroup_arrive(); 83 | 84 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 85 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 86 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 87 | // imm-scale-a, imme-scale-b, imm-trans-b; 88 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 89 | "{%0, %1}, " // accumulator 90 | "%2, %3, " // matrix a descriptor 91 | "1, " // 0 => D = A*B, 1 => D = D + A*B 92 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 93 | // -1 to a or b 94 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 95 | : "+r"(c[0]), "+r"(c[1]) 96 | : "l"(desc_a), "l"(desc_b)); 97 | 98 | // second step 99 | desc_a = make_desc(A_shared + 16); 100 | desc_b = make_desc(B_shared + 16 * 8); 101 | 102 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 103 | "{%0, %1}, " // accumulator 104 | "%2, %3, " // matrix a descriptor 105 | "1, " // 0 => D = A*B, 1 => D = D + A*B 106 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 107 | // -1 to a or b 108 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 109 | : "+r"(c[0]), "+r"(c[1]) 110 | : "l"(desc_a), "l"(desc_b)); 111 | 112 | warpgroup_arrive(); 113 | 114 | // commit, start the computation 115 | warpgroup_commit_batch(); 116 | 117 | // wait for the previous commit to finish 118 | warpgroup_wait<0>(); 119 | 120 | // thread fence needed for async operations 121 | __threadfence(); 122 | 123 | warpgroup_arrive(); 124 | 125 | uint32_t *C_ptr = reinterpret_cast(C); 126 | 127 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 128 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 129 | 130 | // write back to global memory 131 | C_ptr[offset1] = c[0]; 132 | C_ptr[offset2] = c[1]; 133 | } 134 | 135 | int main() { 136 | 137 | half *d_C; 138 | half h_C[M * N]; 139 | half h_CPU[M * N]; 140 | half h_A[M * K]; 141 | half h_B[K * N]; 142 | 143 | fill_fixed(h_C, M, N, 0); 144 | 145 | fill_random(h_A, M, K); 146 | fill_random(h_B, K, N); 147 | 148 | half *d_A, *d_B; 149 | 150 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 151 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 152 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 153 | 154 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 155 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 156 | 157 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K, M, K, d_A); 158 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 159 | 160 | kernel<<>>(tensor_map_a, tensor_map_b, d_C); 161 | 162 | cuda_check_error(); 163 | 164 | cudaMemcpy(h_C, d_C, M * N * 
sizeof(half), cudaMemcpyDeviceToHost); 165 | 166 | // print_matrix(h_C, M, N); 167 | 168 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 169 | 170 | compare_matrices(h_CPU, h_C, M, N); 171 | 172 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 173 | 174 | return 0; 175 | } -------------------------------------------------------------------------------- /sparse/1_m64_n8_k32.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix_utilities.cuh" 18 | #include "profile_utilities.cuh" 19 | #include "tma_tensor_map.cuh" 20 | #include "wgmma.cuh" 21 | 22 | #pragma nv_diag_suppress static_var_with_dynamic_init 23 | 24 | const int M = 64; 25 | const int N = 8; 26 | const int K = 32; 27 | 28 | // 2:4 format 29 | const int K_A = 16; 30 | 31 | const int threads_per_block = 32 * 4; // 4 warps 32 | const int blocks = 1; 33 | 34 | using barrier = cuda::barrier; 35 | namespace cde = cuda::device::experimental; 36 | 37 | __global__ void kernel( 38 | const __grid_constant__ CUtensorMap tensor_map_a, 39 | const __grid_constant__ CUtensorMap tensor_map_b, 40 | half *C, 41 | u_int32_t *metadata_array) { 42 | 43 | const int tid = threadIdx.x; 44 | const int warp_id = tid / 32; 45 | const int lane_id = tid % 32; 46 | const int group_id = lane_id >> 2; 47 | const int lane_in_group = lane_id & 3; 48 | const int lane_in_work_group = lane_in_group % 2; 49 | 50 | __align__(128) __shared__ half A_shared[M * K_A]; 51 | __align__(16) __shared__ half B_shared[K * N]; 52 | 53 | __shared__ barrier bar; 54 | 55 | if (threadIdx.x == 0) { 56 | init(&bar, blockDim.x); 57 | } 58 | __syncthreads(); 59 | 60 | uint64_t token; 61 | if (tid == 0) { 62 | // call the loading api 63 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 64 | 0, bar); 65 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 66 | 0, 0, bar); 67 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared) + sizeof(B_shared)); 68 | } else { 69 | token = bar.arrive(); 70 | } 71 | 72 | bar.wait(cuda::std::move(token)); 73 | 74 | __syncthreads(); 75 | 76 | // load metadata 77 | u_int32_t metadata; 78 | uint metadata_offset = warp_id * 16 + lane_in_work_group * 8 + group_id; 79 | metadata = metadata_array[metadata_offset]; 80 | 81 | __syncthreads(); 82 | 83 | // create descriptors 84 | GmmaDescriptor desc_a = make_desc(A_shared); 85 | GmmaDescriptor desc_b = make_desc(B_shared); 86 | 87 | // accumulator 88 | uint32_t c[2] = {}; 89 | 90 | warpgroup_arrive(); 91 | 92 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " 93 | "{%0, %1}, " // c 94 | "%2, %3, " // desc A, B 95 | "%4, " // meta 96 | "0, " // thread selection 97 | "1, " // scale D 98 | "%7, %8, " // +/- scale A, B 99 | "%9, %10;" // transpose A, B 100 | : "+r"(c[0]), "+r"(c[1]) 101 | : "l"(desc_a), "l"(desc_b), 102 | "r"(metadata), // metadata 103 | "r"(0), // thread selection 104 | "r"(1), // scale D 105 | "n"(1), "n"(1), // +- scale A, B 106 | "n"(0), "n"(1)); // transpose A, B 107 | 108 | // commit, start the computation 109 | warpgroup_commit_batch(); 110 | 111 | // wait for the previous commit to finish 112 | warpgroup_wait<0>(); 113 | 114 | // thread fence needed for async operations 115 | __threadfence(); 116 | 
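    // Result layout: c[0] and c[1] each pack two consecutive half values of
    // the m64n8 accumulator. With C viewed as uint32_t (N/2 = 4 words per
    // row), warp_id selects a 16-row slab, group_id a row in the slab's first
    // eight rows, lane_in_group the packed column; c[1] goes eight rows lower.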
117 | warpgroup_arrive(); 118 | 119 | // store the result 120 | uint32_t *C_ptr = reinterpret_cast(C); 121 | 122 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 123 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 124 | 125 | C_ptr[offset1] = c[0]; 126 | C_ptr[offset2] = c[1]; 127 | } 128 | 129 | int main() { 130 | 131 | half *d_C; 132 | half h_C[M * N]; 133 | half h_CPU[M * N]; 134 | half h_A[M * K]; 135 | half h_A2[M * K_A]; 136 | half h_B[K * N]; 137 | 138 | fill_24(h_A, M, K); 139 | fill_random(h_B, K, N); 140 | 141 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 142 | compress24(h_A, h_A2, M, K); 143 | 144 | // print_matrix(h_A2, M, K_A); 145 | 146 | half *d_A, *d_B; 147 | 148 | cudaMalloc((void **)&d_A, M * K_A * sizeof(half)); 149 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 150 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 151 | 152 | cudaMemcpy(d_A, h_A2, M * K_A * sizeof(half), cudaMemcpyHostToDevice); 153 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 154 | 155 | int metadata_size = (M / 16) * (K / 16) * 8; 156 | 157 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 158 | inspect_metadata(h_A, metadata_array, M, K); 159 | 160 | u_int32_t *d_metadata; 161 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 162 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 163 | cudaMemcpyHostToDevice); 164 | 165 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K_A, M, K_A, d_A); 166 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 167 | 168 | kernel<<>>(tensor_map_a, tensor_map_b, d_C, d_metadata); 169 | 170 | cuda_check_error(); 171 | 172 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 173 | 174 | // print_matrix<5>(h_A2, M, K_A); 175 | 176 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 177 | 178 | compare_matrices(h_CPU, h_C, M, N); 179 | 180 | // print_differnce(h_CPU, h_C, M, N, 0); 181 | 182 | return 0; 183 | } 184 | -------------------------------------------------------------------------------- /examples/8_swizzle_manual.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "matrix_utilities.cuh" 15 | #include "profile_utilities.cuh" 16 | #include "wgmma.cuh" 17 | 18 | const int M = 64; 19 | const int N = 8; 20 | const int K = 16; 21 | 22 | const int threads_per_block = 32 * 4; // 4 warps 23 | const int blocks = 1; 24 | 25 | __global__ void kernel(half *A, half *B, half *C) { 26 | // metadata 27 | const int tid = threadIdx.x; 28 | const int warp_id = tid / 32; 29 | const int lane_id = tid % 32; 30 | const int group_id = lane_id >> 2; 31 | const int lane_in_group = lane_id & 3; 32 | 33 | __syncthreads(); 34 | 35 | __align__(16) __shared__ half A_shared[M * K]; 36 | __align__(16) __shared__ half B_shared[K * N]; 37 | 38 | __align__(16) __shared__ half buffer[2 * 64]; 39 | 40 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async 41 | // 8x8 core blocks, we use one thread here to 42 | // easy demonstrate the required layout 43 | if (tid == 0) { 44 | for (int i = 0; i < M; i++) { 45 | for (int j = 0; j < K; j++) { 46 | int block_x = i / 8; 47 | int block_row = i % 8; 
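                // (i, j) is decomposed into 8x8 core-matrix coordinates:
                // block_x/block_y select the core matrix, block_row/block_col
                // the element inside it, and each core matrix is written out
                // as a contiguous run of 64 half values, which is the layout
                // the wgmma descriptors expect (see the PTX link above).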
48 | int block_y = j / 8; 49 | int block_col = j % 8; 50 | int block_id = block_x * 2 + block_y; 51 | int offset = block_id * 64 + block_row * 8 + block_col; 52 | A_shared[offset] = A[i * K + j]; 53 | } 54 | } 55 | 56 | // swizzle A 57 | for (int pair = 0; pair < 8; pair++) { 58 | 59 | for (int i = 0; i < 8; i++) { 60 | if (i % 2 == 0) { 61 | for (int j = 0; j < 8; j++) { 62 | buffer[i * 8 + j] = 63 | A_shared[pair * 128 + i / 2 * 8 + j]; 64 | } 65 | } else { 66 | for (int j = 0; j < 8; j++) { 67 | buffer[i * 8 + j] = 68 | A_shared[pair * 128 + 64 + i / 2 * 8 + j]; 69 | } 70 | } 71 | } 72 | 73 | for (int i = 0; i < 8; i++) { 74 | if (i % 2 == 0) { 75 | for (int j = 0; j < 8; j++) { 76 | buffer[64 + i * 8 + j] = 77 | A_shared[pair * 128 + 64 + 32 + i / 2 * 8 + j]; 78 | } 79 | } else { 80 | for (int j = 0; j < 8; j++) { 81 | buffer[64 + i * 8 + j] = 82 | A_shared[pair * 128 + 32 + i / 2 * 8 + j]; 83 | } 84 | } 85 | } 86 | 87 | // write back to A_shared 88 | for (int row = 0; row < 16; row++) { 89 | for (int col = 0; col < 8; col++) { 90 | A_shared[pair * 128 + row * 8 + col] = 91 | buffer[row * 8 + col]; 92 | } 93 | } 94 | } 95 | 96 | for (int i = 0; i < K; i++) { 97 | for (int j = 0; j < N; j++) { 98 | int block_x = i / 8; 99 | int block_row = i % 8; 100 | int block_y = j / 8; 101 | int block_col = j % 8; 102 | int block_id = block_x * 1 + block_y; 103 | int offset = block_id * 64 + block_row * 8 + block_col; 104 | B_shared[offset] = B[i * N + j]; 105 | } 106 | } 107 | } 108 | 109 | __syncthreads(); 110 | 111 | // create descriptors for the matrices 112 | GmmaDescriptor desc_a = make_desc_a(A_shared); 113 | GmmaDescriptor desc_b = make_desc_b(B_shared); 114 | 115 | // accumulator 116 | uint32_t c[2] = {}; 117 | 118 | // called whenever the accumulator is accessed 119 | warpgroup_arrive(); 120 | 121 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 122 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 123 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 124 | // imm-scale-a, imme-scale-b, imm-trans-b; 125 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 126 | "{%0, %1}, " // accumulator 127 | "%2, %3, " // matrix a descriptor 128 | "1, " // 0 => D = A*B, 1 => D = D + A*B 129 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 130 | // -1 to a or b 131 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 132 | : "+r"(c[0]), "+r"(c[1]) 133 | : "l"(desc_a), "l"(desc_b)); 134 | 135 | // commit, start the computation 136 | warpgroup_commit_batch(); 137 | 138 | // wait for the previous commit to finish 139 | warpgroup_wait<0>(); 140 | 141 | // thread fence needed for async operations 142 | __threadfence(); 143 | 144 | warpgroup_arrive(); 145 | 146 | uint32_t *C_ptr = reinterpret_cast(C); 147 | 148 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 149 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 150 | 151 | // write back to global memory 152 | C_ptr[offset1] = c[0]; 153 | C_ptr[offset2] = c[1]; 154 | } 155 | 156 | int main() { 157 | 158 | half *d_C; 159 | half h_C[M * N]; 160 | half h_CPU[M * N]; 161 | half h_A[M * K]; 162 | half h_B[K * N]; 163 | 164 | fill_fixed(h_C, M, N, 0); 165 | 166 | fill_random(h_A, M, K); 167 | fill_random(h_B, K, N); 168 | 169 | half *d_A, *d_B; 170 | 171 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 172 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 173 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 174 | 
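    // Unlike the TMA-based examples, A and B are copied to the device in
    // plain row-major order; the kernel then rebuilds the core-block layout
    // and applies the swizzle to A by hand, single-threaded, purely to make
    // the required shared-memory layout easy to see.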
175 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 176 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 177 | 178 | kernel<<>>(d_A, d_B, d_C); 179 | 180 | cuda_check_error(); 181 | 182 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 183 | 184 | // print_matrix(h_C, M, N); 185 | 186 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 187 | 188 | compare_matrices(h_CPU, h_C, M, N); 189 | 190 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 191 | 192 | return 0; 193 | } 194 | -------------------------------------------------------------------------------- /sparse/3_m256_n8_k64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix_utilities.cuh" 18 | #include "profile_utilities.cuh" 19 | #include "tma_tensor_map.cuh" 20 | #include "wgmma.cuh" 21 | 22 | #pragma nv_diag_suppress static_var_with_dynamic_init 23 | 24 | const int M = 256; 25 | const int N = 8; 26 | const int K = 64; 27 | 28 | // 2:4 format 29 | const int K_A = 32; 30 | 31 | const int threads_per_block = 32 * 4; // 4 warps 32 | const int blocks = 1; 33 | 34 | using barrier = cuda::barrier; 35 | namespace cde = cuda::device::experimental; 36 | 37 | __device__ void MMA_SP_WRAPPER(uint32_t * c, GmmaDescriptor desc_a, GmmaDescriptor desc_b, uint32_t metadata) { 38 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " 39 | "{%0, %1}, " // c 40 | "%2, %3, " // desc A, B 41 | "%4, " // meta 42 | "%5, " // thread selection 43 | "%6, " // scale D 44 | "%7, %8, " // +/- scale A, B 45 | "%9, %10;" // transpose A, B 46 | : "+r"(c[0]), "+r"(c[1]) 47 | : "l"(desc_a), "l"(desc_b), 48 | "r"(metadata), // metadata 49 | "n"(0), // thread selection 50 | "n"(1), // scale D 51 | "n"(1), "n"(1), // +- scale A, B 52 | "n"(0), "n"(1)); // transpose A, B 53 | } 54 | 55 | 56 | __global__ void kernel( 57 | const __grid_constant__ CUtensorMap tensor_map_a, 58 | const __grid_constant__ CUtensorMap tensor_map_b, 59 | half *C, 60 | u_int32_t *metadata_array) { 61 | 62 | const int tid = threadIdx.x; 63 | const int warp_id = tid / 32; 64 | const int lane_id = tid % 32; 65 | const int group_id = lane_id >> 2; 66 | const int lane_in_group = lane_id & 3; 67 | const int lane_in_work_group = lane_in_group % 2; 68 | 69 | __align__(128) __shared__ half A_shared[M * K_A]; 70 | __align__(16) __shared__ half B_shared[K * N]; 71 | 72 | __shared__ barrier bar; 73 | 74 | if (threadIdx.x == 0) { 75 | init(&bar, blockDim.x); 76 | } 77 | __syncthreads(); 78 | 79 | uint64_t token; 80 | if (tid == 0) { 81 | // call the loading api 82 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 83 | 0, bar); 84 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 85 | 0, 0, bar); 86 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared) + sizeof(B_shared)); 87 | } else { 88 | token = bar.arrive(); 89 | } 90 | 91 | bar.wait(cuda::std::move(token)); 92 | __syncthreads(); 93 | 94 | 95 | u_int32_t metadata; 96 | uint metadata_offset; 97 | GmmaDescriptor desc_a, desc_b; 98 | 99 | // divide the 256x64 of A into 4 64x32 tiles and multiply them with the B divided into 2 32x8 tiles 100 | // accumulator 101 | uint32_t c[4][2] = {}; 102 | 103 | 
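    // Two passes over K: each pass fixes one 32x8 slice of B via desc_b and
    // loops over the four 64-row tiles of the compressed A, so every
    // wgmma.sp call consumes a 64x16 slice of A_shared (half of K_A), this
    // thread's 32-bit metadata word, and accumulates into c[m2].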
desc_b = make_desc(B_shared); 104 | #pragma unroll 105 | for (int m2 = 0; m2 < 4; m2++) { 106 | warpgroup_arrive(); 107 | desc_a = make_desc(A_shared + m2 * 64 * K_A); 108 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + lane_in_work_group * 8 + group_id; 109 | metadata = metadata_array[metadata_offset]; 110 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 111 | } 112 | 113 | desc_b = make_desc(B_shared + 32 * N); 114 | #pragma unroll 115 | for (int m2 = 0; m2 < 4; m2++) { 116 | warpgroup_arrive(); 117 | desc_a = make_desc(A_shared + m2 * 64 * K_A + K_A / 2); 118 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + 8 * 2 + lane_in_work_group * 8 + group_id; 119 | metadata = metadata_array[metadata_offset]; 120 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 121 | } 122 | 123 | // commit, start the computation 124 | warpgroup_commit_batch(); 125 | 126 | // wait for the previous commit to finish 127 | warpgroup_wait<0>(); 128 | 129 | // thread fence needed for async operations 130 | __threadfence(); 131 | 132 | warpgroup_arrive(); 133 | 134 | // store the result 135 | uint32_t *C_ptr = reinterpret_cast(C); 136 | 137 | for (int m2 = 0; m2 < 4; m2++) { 138 | int offset1 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + group_id * N / 2 + lane_in_group; 139 | int offset2 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + (group_id + 8) * N / 2 + lane_in_group; 140 | C_ptr[offset1] = c[m2][0]; 141 | C_ptr[offset2] = c[m2][1]; 142 | } 143 | } 144 | 145 | int main() { 146 | 147 | half *d_C; 148 | half h_C[M * N]; 149 | half h_CPU[M * N]; 150 | half h_A[M * K]; 151 | half h_A2[M * K_A]; 152 | half h_B[K * N]; 153 | 154 | fill_24(h_A, M, K); 155 | fill_random(h_B, K, N); 156 | 157 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 158 | compress24(h_A, h_A2, M, K); 159 | 160 | // print_matrix(h_A2, M, K_A); 161 | 162 | half *d_A, *d_B; 163 | 164 | cudaMalloc((void **)&d_A, M * K_A * sizeof(half)); 165 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 166 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 167 | 168 | cudaMemcpy(d_A, h_A2, M * K_A * sizeof(half), cudaMemcpyHostToDevice); 169 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 170 | 171 | int metadata_size = (M / 16) * (K / 16) * 8; 172 | 173 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 174 | inspect_metadata(h_A, metadata_array, M, K); 175 | 176 | u_int32_t *d_metadata; 177 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 178 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 179 | cudaMemcpyHostToDevice); 180 | 181 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K_A, M, K_A, d_A); 182 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 183 | 184 | kernel<<>>(tensor_map_a, tensor_map_b, d_C, d_metadata); 185 | 186 | cuda_check_error(); 187 | 188 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 189 | 190 | // print_matrix<5>(h_A2, M, K_A); 191 | 192 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 193 | 194 | compare_matrices(h_CPU, h_C, M, N); 195 | 196 | // print_differnce(h_CPU, h_C, M, N, 0); 197 | 198 | return 0; 199 | } 200 | -------------------------------------------------------------------------------- /sparse/2_m64_n8_k64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | 
#include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix_utilities.cuh" 18 | #include "profile_utilities.cuh" 19 | #include "tma_tensor_map.cuh" 20 | #include "wgmma.cuh" 21 | 22 | #pragma nv_diag_suppress static_var_with_dynamic_init 23 | 24 | const int M = 64; 25 | const int N = 8; 26 | const int K = 64; 27 | 28 | // 2:4 format 29 | const int K_A = 32; 30 | 31 | const int threads_per_block = 32 * 4; // 4 warps 32 | const int blocks = 1; 33 | 34 | using barrier = cuda::barrier; 35 | namespace cde = cuda::device::experimental; 36 | 37 | __global__ void kernel( 38 | const __grid_constant__ CUtensorMap tensor_map_a, 39 | const __grid_constant__ CUtensorMap tensor_map_b, 40 | half *C, 41 | u_int32_t *metadata_array) { 42 | 43 | const int tid = threadIdx.x; 44 | const int warp_id = tid / 32; 45 | const int lane_id = tid % 32; 46 | const int group_id = lane_id >> 2; 47 | const int lane_in_group = lane_id & 3; 48 | const int lane_in_work_group = lane_in_group % 2; 49 | 50 | __align__(128) __shared__ half A_shared[M * K_A]; 51 | __align__(16) __shared__ half B_shared[K * N]; 52 | 53 | __shared__ barrier bar; 54 | 55 | if (threadIdx.x == 0) { 56 | init(&bar, blockDim.x); 57 | } 58 | __syncthreads(); 59 | 60 | uint64_t token; 61 | if (tid == 0) { 62 | // call the loading api 63 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 64 | 0, bar); 65 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 66 | 0, 0, bar); 67 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared) + sizeof(B_shared)); 68 | } else { 69 | token = bar.arrive(); 70 | } 71 | 72 | bar.wait(cuda::std::move(token)); 73 | 74 | __syncthreads(); 75 | 76 | // load metadata 77 | u_int32_t metadata; 78 | uint metadata_offset = warp_id * 8 * 4 + lane_in_work_group * 8 + group_id; 79 | metadata = metadata_array[metadata_offset]; 80 | 81 | __syncthreads(); 82 | 83 | // create descriptors 84 | GmmaDescriptor desc_a = make_desc(A_shared); 85 | GmmaDescriptor desc_b = make_desc(B_shared); 86 | 87 | // accumulator 88 | uint32_t c[2] = {}; 89 | 90 | warpgroup_arrive(); 91 | 92 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " 93 | "{%0, %1}, " // c 94 | "%2, %3, " // desc A, B 95 | "%4, " // meta 96 | "0, " // thread selection 97 | "1, " // scale D 98 | "%7, %8, " // +/- scale A, B 99 | "%9, %10;" // transpose A, B 100 | : "+r"(c[0]), "+r"(c[1]) 101 | : "l"(desc_a), "l"(desc_b), 102 | "r"(metadata), // metadata 103 | "r"(0), // thread selection 104 | "r"(1), // scale D 105 | "n"(1), "n"(1), // +- scale A, B 106 | "n"(0), "n"(1)); // transpose A, B 107 | 108 | desc_a = make_desc(A_shared + K_A / 2); 109 | desc_b = make_desc(B_shared + 32 * N); 110 | 111 | warpgroup_arrive(); 112 | 113 | metadata_offset = warp_id * 8 * 4 + 8 * 2 + lane_in_work_group * 8 + group_id; 114 | metadata = metadata_array[metadata_offset]; 115 | 116 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " 117 | "{%0, %1}, " // c 118 | "%2, %3, " // desc A, B 119 | "%4, " // meta 120 | "0, " // thread selection 121 | "1, " // scale D 122 | "%7, %8, " // +/- scale A, B 123 | "%9, %10;" // transpose A, B 124 | : "+r"(c[0]), "+r"(c[1]) 125 | : "l"(desc_a), "l"(desc_b), 126 | "r"(metadata), // metadata 127 | "r"(0), // thread selection 128 | "r"(1), // scale D 129 | "n"(1), "n"(1), // +- scale A, B 130 | "n"(0), "n"(1)); // transpose A, B 131 | 132 | // commit, start the computation 133 | warpgroup_commit_batch(); 134 | 135 | 
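    // warpgroup_commit_batch() groups the two wgmma.sp issues above into a
    // single batch; warpgroup_wait<0>() below blocks until that batch has
    // drained, after which c[0]/c[1] can be read safely. Both k-steps
    // accumulate into the same registers because scale-D is set to 1.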
// wait for the previous commit to finish 136 | warpgroup_wait<0>(); 137 | 138 | // thread fence needed for async operations 139 | __threadfence(); 140 | 141 | warpgroup_arrive(); 142 | 143 | // store the result 144 | uint32_t *C_ptr = reinterpret_cast(C); 145 | 146 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 147 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 148 | 149 | C_ptr[offset1] = c[0]; 150 | C_ptr[offset2] = c[1]; 151 | } 152 | 153 | int main() { 154 | 155 | half *d_C; 156 | half h_C[M * N]; 157 | half h_CPU[M * N]; 158 | half h_A[M * K]; 159 | half h_A2[M * K_A]; 160 | half h_B[K * N]; 161 | 162 | fill_24(h_A, M, K); 163 | fill_random(h_B, K, N); 164 | 165 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 166 | compress24(h_A, h_A2, M, K); 167 | 168 | // print_matrix(h_A2, M, K_A); 169 | 170 | half *d_A, *d_B; 171 | 172 | cudaMalloc((void **)&d_A, M * K_A * sizeof(half)); 173 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 174 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 175 | 176 | cudaMemcpy(d_A, h_A2, M * K_A * sizeof(half), cudaMemcpyHostToDevice); 177 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 178 | 179 | int metadata_size = (M / 16) * (K / 16) * 8; 180 | 181 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 182 | inspect_metadata(h_A, metadata_array, M, K); 183 | 184 | u_int32_t *d_metadata; 185 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 186 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 187 | cudaMemcpyHostToDevice); 188 | 189 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K_A, M, K_A, d_A); 190 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 191 | 192 | kernel<<>>(tensor_map_a, tensor_map_b, d_C, d_metadata); 193 | 194 | cuda_check_error(); 195 | 196 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 197 | 198 | // print_matrix<5>(h_A2, M, K_A); 199 | 200 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 201 | 202 | compare_matrices(h_CPU, h_C, M, N); 203 | 204 | // print_differnce(h_CPU, h_C, M, N, 0); 205 | 206 | return 0; 207 | } 208 | -------------------------------------------------------------------------------- /sparse/4_m256_n16_k64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix_utilities.cuh" 18 | #include "profile_utilities.cuh" 19 | #include "tma_tensor_map.cuh" 20 | #include "wgmma.cuh" 21 | 22 | #pragma nv_diag_suppress static_var_with_dynamic_init 23 | 24 | const int M = 256; 25 | const int N = 16; 26 | const int K = 64; 27 | 28 | // 2:4 format 29 | const int K_A = 32; 30 | 31 | const int threads_per_block = 32 * 4; // 4 warps 32 | const int blocks = 1; 33 | 34 | using barrier = cuda::barrier; 35 | namespace cde = cuda::device::experimental; 36 | 37 | __device__ void MMA_SP_WRAPPER(uint32_t * c, GmmaDescriptor desc_a, GmmaDescriptor desc_b, uint32_t metadata) { 38 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " 39 | "{%0, %1, %2, %3}, " // c 40 | "%4, %5, " // desc A, B 41 | "%6, " // meta 42 | "%7, " // thread selection 43 | "%8, " // scale D 44 | "%9, %10, " // +/- scale A, B 45 | "%11, %12;" 
// transpose A, B 46 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 47 | : "l"(desc_a), "l"(desc_b), 48 | "r"(metadata), // metadata 49 | "n"(0), // thread selection 50 | "n"(1), // scale D 51 | "n"(1), "n"(1), // +- scale A, B 52 | "n"(0), "n"(1)); // transpose A, B 53 | } 54 | 55 | 56 | __global__ void kernel( 57 | const __grid_constant__ CUtensorMap tensor_map_a, 58 | const __grid_constant__ CUtensorMap tensor_map_b, 59 | half *C, 60 | u_int32_t *metadata_array) { 61 | 62 | const int tid = threadIdx.x; 63 | const int warp_id = tid / 32; 64 | const int lane_id = tid % 32; 65 | const int group_id = lane_id >> 2; 66 | const int lane_in_group = lane_id & 3; 67 | const int lane_in_work_group = lane_in_group % 2; 68 | 69 | __align__(128) __shared__ half A_shared[M * K_A]; 70 | __align__(16) __shared__ half B_shared[K * N]; 71 | 72 | __shared__ barrier bar; 73 | 74 | if (threadIdx.x == 0) { 75 | init(&bar, blockDim.x); 76 | } 77 | __syncthreads(); 78 | 79 | uint64_t token; 80 | if (tid == 0) { 81 | // call the loading api 82 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 83 | 0, bar); 84 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 85 | 0, 0, bar); 86 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared) + sizeof(B_shared)); 87 | } else { 88 | token = bar.arrive(); 89 | } 90 | 91 | bar.wait(cuda::std::move(token)); 92 | __syncthreads(); 93 | 94 | 95 | u_int32_t metadata; 96 | uint metadata_offset; 97 | GmmaDescriptor desc_a, desc_b; 98 | 99 | // divide the 256x64 of A into 4 64x32 tiles and multiply them with the B divided into 2 32x16 tiles 100 | // accumulator 101 | uint32_t c[4][4] = {}; 102 | 103 | desc_b = make_desc(B_shared); 104 | #pragma unroll 105 | for (int m2 = 0; m2 < 4; m2++) { 106 | warpgroup_arrive(); 107 | desc_a = make_desc(A_shared + m2 * 64 * K_A); 108 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + lane_in_work_group * 8 + group_id; 109 | metadata = metadata_array[metadata_offset]; 110 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 111 | } 112 | 113 | desc_b = make_desc(B_shared + 32 * N); 114 | #pragma unroll 115 | for (int m2 = 0; m2 < 4; m2++) { 116 | warpgroup_arrive(); 117 | desc_a = make_desc(A_shared + m2 * 64 * K_A + K_A / 2); 118 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + 8 * 2 + lane_in_work_group * 8 + group_id; 119 | metadata = metadata_array[metadata_offset]; 120 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 121 | } 122 | 123 | // commit, start the computation 124 | warpgroup_commit_batch(); 125 | 126 | // wait for the previous commit to finish 127 | warpgroup_wait<0>(); 128 | 129 | // thread fence needed for async operations 130 | __threadfence(); 131 | 132 | warpgroup_arrive(); 133 | 134 | // store the result 135 | uint32_t *C_ptr = reinterpret_cast(C); 136 | 137 | for (int m2 = 0; m2 < 4; m2++) { 138 | int offset1 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + group_id * N / 2 + lane_in_group; 139 | int offset2 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + (group_id + 8) * N / 2 + lane_in_group; 140 | C_ptr[offset1] = c[m2][0]; 141 | C_ptr[offset2] = c[m2][1]; 142 | C_ptr[offset1 + 4] = c[m2][2]; 143 | C_ptr[offset2 + 4] = c[m2][3]; 144 | } 145 | } 146 | 147 | int main() { 148 | 149 | half *d_C; 150 | half h_C[M * N]; 151 | half h_CPU[M * N]; 152 | half h_A[M * K]; 153 | half h_A2[M * K_A]; 154 | half h_B[K * N]; 155 | 156 | fill_24(h_A, M, K); 157 | fill_random(h_B, K, N); 158 | 159 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 160 | 
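    // For example, a 1x4 group (0, a, 0, b) keeps only (a, b) in h_A2, which
    // halves K down to K_A; inspect_metadata() later encodes the surviving
    // positions (pattern "0101" maps to the nibble 0xD in
    // matrix_utilities.cuh) so the sparse wgmma can expand each value back to
    // its original column.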
compress24(h_A, h_A2, M, K); 161 | 162 | // print_matrix(h_A2, M, K_A); 163 | 164 | half *d_A, *d_B; 165 | 166 | cudaMalloc((void **)&d_A, M * K_A * sizeof(half)); 167 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 168 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 169 | 170 | cudaMemcpy(d_A, h_A2, M * K_A * sizeof(half), cudaMemcpyHostToDevice); 171 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 172 | 173 | int metadata_size = (M / 16) * (K / 16) * 8; 174 | 175 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 176 | inspect_metadata(h_A, metadata_array, M, K); 177 | 178 | u_int32_t *d_metadata; 179 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 180 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 181 | cudaMemcpyHostToDevice); 182 | 183 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K_A, M, K_A, d_A); 184 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 185 | 186 | kernel<<>>(tensor_map_a, tensor_map_b, d_C, d_metadata); 187 | 188 | cuda_check_error(); 189 | 190 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 191 | 192 | // print_matrix<5>(h_A2, M, K_A); 193 | 194 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 195 | 196 | compare_matrices(h_CPU, h_C, M, N); 197 | 198 | // print_differnce(h_CPU, h_C, M, N, 0); 199 | 200 | return 0; 201 | } 202 | -------------------------------------------------------------------------------- /dense/4_m64_n16_k64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 16; 28 | const int K = 64; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map_a, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | // Load A 57 | uint64_t token; 58 | if (tid == 0) { 59 | // call the loading api 60 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 61 | 0, bar); 62 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 63 | 0, 0, bar); 64 | token = cuda::device::barrier_arrive_tx( 65 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 66 | } else { 67 | token = bar.arrive(); 68 | } 69 | 70 | bar.wait(cuda::std::move(token)); 71 | 72 | __syncthreads(); 73 | 74 | // create descriptors for the matrices 75 | GmmaDescriptor desc_a = make_desc(A_shared); 76 | 
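    // make_desc (defined in wgmma.cuh) builds the 64-bit shared-memory matrix
    // descriptor consumed by wgmma; the "second step" and "third step" blocks
    // below rebuild the descriptors at offsets of 16 half elements into
    // A_shared and 16 * N half elements into B_shared to step through K in
    // k=16 chunks, all accumulating into the same c[] registers.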
GmmaDescriptor desc_b = make_desc(B_shared); 77 | 78 | // accumulator 79 | uint32_t c[4] = {}; 80 | 81 | // called whenever the accumulator is accessed 82 | warpgroup_arrive(); 83 | 84 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 85 | "{%0, %1, %2, %3}, " // accumulator 86 | "%4, %5, " // matrix a descriptor 87 | "1, " // 0 => D = A*B, 1 => D = D + A*B 88 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 89 | // -1 to a or b 90 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 91 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 92 | : "l"(desc_a), "l"(desc_b)); 93 | 94 | // second step 95 | desc_a = make_desc(A_shared + 16); 96 | desc_b = make_desc(B_shared + 16 * N); 97 | 98 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 99 | "{%0, %1, %2, %3}, " // accumulator 100 | "%4, %5, " // matrix a descriptor 101 | "1, " // 0 => D = A*B, 1 => D = D + A*B 102 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 103 | // -1 to a or b 104 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 105 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 106 | : "l"(desc_a), "l"(desc_b)); 107 | 108 | // third step 109 | desc_a = make_desc(A_shared + 16 + 16); 110 | desc_b = make_desc(B_shared + 16 * N + 16 * N); 111 | 112 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 113 | "{%0, %1, %2, %3}, " // accumulator 114 | "%4, %5, " // matrix a descriptor 115 | "1, " // 0 => D = A*B, 1 => D = D + A*B 116 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 117 | // -1 to a or b 118 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 119 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 120 | : "l"(desc_a), "l"(desc_b)); 121 | 122 | desc_a = make_desc(A_shared + 16 + 16 + 16); 123 | desc_b = make_desc(B_shared + 16 * N + 16 * N + 16 * N); 124 | 125 | asm volatile("wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " 126 | "{%0, %1, %2, %3}, " // accumulator 127 | "%4, %5, " // matrix a descriptor 128 | "1, " // 0 => D = A*B, 1 => D = D + A*B 129 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 130 | // -1 to a or b 131 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 132 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) 133 | : "l"(desc_a), "l"(desc_b)); 134 | 135 | warpgroup_arrive(); 136 | 137 | // commit, start the computation 138 | warpgroup_commit_batch(); 139 | 140 | // wait for the previous commit to finish 141 | warpgroup_wait<0>(); 142 | 143 | // thread fence needed for async operations 144 | __threadfence(); 145 | 146 | warpgroup_arrive(); 147 | 148 | uint32_t *C_ptr = reinterpret_cast(C); 149 | 150 | int offset1 = warp_id * 16 * 8 + group_id * 8 + lane_in_group; 151 | int offset2 = warp_id * 16 * 8 + (group_id + 8) * 8 + lane_in_group; 152 | 153 | // write back to global memory 154 | C_ptr[offset1] = c[0]; 155 | C_ptr[offset2] = c[1]; 156 | C_ptr[offset1 + 4] = c[2]; 157 | C_ptr[offset2 + 4] = c[3]; 158 | } 159 | 160 | int main() { 161 | 162 | half *d_C; 163 | half h_C[M * N]; 164 | half h_CPU[M * N]; 165 | half h_A[M * K]; 166 | half h_B[K * N]; 167 | 168 | fill_fixed(h_C, M, N, 0); 169 | 170 | fill_random(h_A, M, K); 171 | fill_random(h_B, K, N); 172 | 173 | half *d_A, *d_B; 174 | 175 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 176 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 177 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 178 | 179 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 
180 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 181 | 182 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K, M, K, d_A); 183 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 184 | 185 | kernel<<>>(tensor_map_a, tensor_map_b, d_C); 186 | 187 | cuda_check_error(); 188 | 189 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 190 | 191 | // print_matrix(h_C, M, N); 192 | 193 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 194 | 195 | compare_matrices(h_CPU, h_C, M, N); 196 | 197 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 198 | 199 | return 0; 200 | } -------------------------------------------------------------------------------- /dense/3_m64_n8_k64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the dense wgmma instructions 3 | to perform matrix multiplication 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix_utilities.cuh" 16 | #include "profile_utilities.cuh" 17 | #include "tma_tensor_map.cuh" 18 | #include "wgmma.cuh" 19 | 20 | // Suppress warning about barrier in shared memory 21 | #pragma nv_diag_suppress static_var_with_dynamic_init 22 | 23 | using barrier = cuda::barrier; 24 | namespace cde = cuda::device::experimental; 25 | 26 | const int M = 64; 27 | const int N = 8; 28 | const int K = 64; 29 | 30 | const int threads_per_block = 32 * 4; // 4 warps 31 | const int blocks = 1; 32 | 33 | __global__ void kernel(const __grid_constant__ CUtensorMap tensor_map_a, 34 | const __grid_constant__ CUtensorMap tensor_map_b, 35 | half *C) { 36 | 37 | // metadata 38 | const int tid = threadIdx.x; 39 | const int warp_id = tid / 32; 40 | const int lane_id = tid % 32; 41 | const int group_id = lane_id >> 2; 42 | const int lane_in_group = lane_id & 3; 43 | 44 | __syncthreads(); 45 | 46 | __align__(128) __shared__ half A_shared[M * K]; 47 | __align__(16) __shared__ half B_shared[K * N]; 48 | 49 | __shared__ barrier bar; 50 | 51 | if (threadIdx.x == 0) { 52 | init(&bar, blockDim.x); 53 | } 54 | __syncthreads(); 55 | 56 | uint64_t token; 57 | if (tid == 0) { 58 | // call the loading api 59 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 60 | 0, bar); 61 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 62 | 0, 0, bar); 63 | token = cuda::device::barrier_arrive_tx( 64 | bar, 1, sizeof(A_shared) + sizeof(B_shared)); 65 | } else { 66 | token = bar.arrive(); 67 | } 68 | 69 | bar.wait(cuda::std::move(token)); 70 | 71 | __syncthreads(); 72 | 73 | // create descriptors for the matrices 74 | GmmaDescriptor desc_a = make_desc(A_shared); 75 | GmmaDescriptor desc_b = make_desc(B_shared); 76 | 77 | // accumulator 78 | uint32_t c[2] = {}; 79 | 80 | // called whenever the accumulator is accessed 81 | warpgroup_arrive(); 82 | 83 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a-desc, b-desc, 84 | // scale-d, imm-scale-a, imme-scale-b, imm-trans-a, imm-trans-b; 85 | // wgmma.mma_async.sync.aligned.shape.dtype.f16.f16 d, a, b-desc, scale-d, 86 | // imm-scale-a, imme-scale-b, imm-trans-b; 87 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 88 | "{%0, %1}, " // accumulator 89 | "%2, %3, " // matrix a descriptor 90 | "1, " // 0 => D = A*B, 1 => D = D + A*B 91 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 92 | // -1 to a or b 93 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 94 
| : "+r"(c[0]), "+r"(c[1]) 95 | : "l"(desc_a), "l"(desc_b)); 96 | 97 | // second step 98 | desc_a = make_desc(A_shared + 16); 99 | desc_b = make_desc(B_shared + 16 * 8); 100 | 101 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 102 | "{%0, %1}, " // accumulator 103 | "%2, %3, " // matrix a descriptor 104 | "1, " // 0 => D = A*B, 1 => D = D + A*B 105 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 106 | // -1 to a or b 107 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 108 | : "+r"(c[0]), "+r"(c[1]) 109 | : "l"(desc_a), "l"(desc_b)); 110 | 111 | // third step 112 | desc_a = make_desc(A_shared + 16 + 16); 113 | desc_b = make_desc(B_shared + 16 * 8 + 16 * 8); 114 | 115 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 116 | "{%0, %1}, " // accumulator 117 | "%2, %3, " // matrix a descriptor 118 | "1, " // 0 => D = A*B, 1 => D = D + A*B 119 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 120 | // -1 to a or b 121 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 122 | : "+r"(c[0]), "+r"(c[1]) 123 | : "l"(desc_a), "l"(desc_b)); 124 | 125 | desc_a = make_desc(A_shared + 16 + 16 + 16); 126 | desc_b = make_desc(B_shared + 16 * 8 + 16 * 8 + 16 * 8); 127 | 128 | asm volatile("wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " 129 | "{%0, %1}, " // accumulator 130 | "%2, %3, " // matrix a descriptor 131 | "1, " // 0 => D = A*B, 1 => D = D + A*B 132 | "1, 1, " // 0 => no scaling, 1 => scaling, scaling means times 133 | // -1 to a or b 134 | "0, 1;" // transpose a and b, 0 => no transpose, 1 => transpose 135 | : "+r"(c[0]), "+r"(c[1]) 136 | : "l"(desc_a), "l"(desc_b)); 137 | 138 | warpgroup_arrive(); 139 | 140 | // commit, start the computation 141 | warpgroup_commit_batch(); 142 | 143 | // wait for the previous commit to finish 144 | warpgroup_wait<0>(); 145 | 146 | // thread fence needed for async operations 147 | __threadfence(); 148 | 149 | warpgroup_arrive(); 150 | 151 | uint32_t *C_ptr = reinterpret_cast(C); 152 | 153 | int offset1 = warp_id * 16 * 4 + group_id * 4 + lane_in_group; 154 | int offset2 = warp_id * 16 * 4 + (group_id + 8) * 4 + lane_in_group; 155 | 156 | // write back to global memory 157 | C_ptr[offset1] = c[0]; 158 | C_ptr[offset2] = c[1]; 159 | } 160 | 161 | int main() { 162 | 163 | half *d_C; 164 | half h_C[M * N]; 165 | half h_CPU[M * N]; 166 | half h_A[M * K]; 167 | half h_B[K * N]; 168 | 169 | fill_fixed(h_C, M, N, 0); 170 | 171 | fill_random(h_A, M, K); 172 | fill_random(h_B, K, N); 173 | 174 | half *d_A, *d_B; 175 | 176 | cudaMalloc((void **)&d_A, M * K * sizeof(half)); 177 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 178 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 179 | 180 | cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice); 181 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 182 | 183 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K, M, K, d_A); 184 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 185 | 186 | kernel<<>>(tensor_map_a, tensor_map_b, d_C); 187 | 188 | cuda_check_error(); 189 | 190 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 191 | 192 | // print_matrix<5>(h_A, M, K); 193 | 194 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 195 | 196 | compare_matrices(h_CPU, h_C, M, N); 197 | 198 | // print_differnce(h_C, h_CPU, M, N, 0.0f); 199 | 200 | return 0; 201 | } -------------------------------------------------------------------------------- /sparse/5_m256_n32_k64.cu: 
-------------------------------------------------------------------------------- 1 | /* 2 | This code demonstrates how to use the sparse wgmma instructions 3 | to perform matrix multiplication 4 | 5 | Sparse means matrix A follows a 2:4 format 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix_utilities.cuh" 18 | #include "profile_utilities.cuh" 19 | #include "tma_tensor_map.cuh" 20 | #include "wgmma.cuh" 21 | 22 | #pragma nv_diag_suppress static_var_with_dynamic_init 23 | 24 | const int M = 256; 25 | const int N = 32; 26 | const int K = 64; 27 | 28 | // 2:4 format 29 | const int K_A = 32; 30 | 31 | const int threads_per_block = 32 * 4; // 4 warps 32 | const int blocks = 1; 33 | 34 | using barrier = cuda::barrier; 35 | namespace cde = cuda::device::experimental; 36 | 37 | __device__ void MMA_SP_WRAPPER(uint32_t * c, GmmaDescriptor desc_a, GmmaDescriptor desc_b, uint32_t metadata) { 38 | asm volatile("wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " 39 | "{%0, %1, %2, %3, %4, %5, %6, %7}, " // c 40 | "%8, %9, " // desc A, B 41 | "%10, " // meta 42 | "%11, " // thread selection 43 | "%12, " // scale D 44 | "%13, %14, " // +/- scale A, B 45 | "%15, %16;" // transpose A, B 46 | : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) 47 | : "l"(desc_a), "l"(desc_b), 48 | "r"(metadata), // metadata 49 | "n"(0), // thread selection 50 | "n"(1), // scale D 51 | "n"(1), "n"(1), // +- scale A, B 52 | "n"(0), "n"(1)); // transpose A, B 53 | } 54 | 55 | 56 | __global__ void kernel( 57 | const __grid_constant__ CUtensorMap tensor_map_a, 58 | const __grid_constant__ CUtensorMap tensor_map_b, 59 | half *C, 60 | u_int32_t *metadata_array) { 61 | 62 | const int tid = threadIdx.x; 63 | const int warp_id = tid / 32; 64 | const int lane_id = tid % 32; 65 | const int group_id = lane_id >> 2; 66 | const int lane_in_group = lane_id & 3; 67 | const int lane_in_work_group = lane_in_group % 2; 68 | 69 | __align__(128) __shared__ half A_shared[M * K_A]; 70 | __align__(16) __shared__ half B_shared[K * N]; 71 | 72 | __shared__ barrier bar; 73 | 74 | if (threadIdx.x == 0) { 75 | init(&bar, blockDim.x); 76 | } 77 | __syncthreads(); 78 | 79 | uint64_t token; 80 | if (tid == 0) { 81 | // call the loading api 82 | cde::cp_async_bulk_tensor_2d_global_to_shared(A_shared, &tensor_map_a, 0, 83 | 0, bar); 84 | cde::cp_async_bulk_tensor_2d_global_to_shared(B_shared, &tensor_map_b, 85 | 0, 0, bar); 86 | token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(A_shared) + sizeof(B_shared)); 87 | } else { 88 | token = bar.arrive(); 89 | } 90 | 91 | bar.wait(cuda::std::move(token)); 92 | __syncthreads(); 93 | 94 | 95 | u_int32_t metadata; 96 | uint metadata_offset; 97 | GmmaDescriptor desc_a, desc_b; 98 | 99 | // divide the 256x64 of A into 4 64x32 tiles and multiply them with the B divided into 2 32x32 tiles 100 | // accumulator 101 | uint32_t c[4][8] = {}; 102 | 103 | desc_b = make_desc(B_shared); 104 | #pragma unroll 105 | for (int m2 = 0; m2 < 4; m2++) { 106 | warpgroup_arrive(); 107 | desc_a = make_desc(A_shared + m2 * 64 * K_A); 108 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + lane_in_work_group * 8 + group_id; 109 | metadata = metadata_array[metadata_offset]; 110 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 111 | } 112 | 113 | desc_b = make_desc(B_shared + 32 * N); 114 | #pragma unroll 115 | for (int m2 = 0; m2 < 4; m2++) { 116 | warpgroup_arrive(); 117 | desc_a = 
make_desc(A_shared + m2 * 64 * K_A + K_A / 2); 118 | metadata_offset = m2 * 8 * 4 * 4 + warp_id * 8 * 4 + 8 * 2 + lane_in_work_group * 8 + group_id; 119 | metadata = metadata_array[metadata_offset]; 120 | MMA_SP_WRAPPER(c[m2], desc_a, desc_b, metadata); 121 | } 122 | 123 | // commit, start the computation 124 | warpgroup_commit_batch(); 125 | 126 | // wait for the previous commit to finish 127 | warpgroup_wait<0>(); 128 | 129 | // thread fence needed for async operations 130 | __threadfence(); 131 | 132 | warpgroup_arrive(); 133 | 134 | // store the result 135 | uint32_t *C_ptr = reinterpret_cast(C); 136 | 137 | for (int m2 = 0; m2 < 4; m2++) { 138 | int offset1 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + group_id * N / 2 + lane_in_group; 139 | int offset2 = m2 * 64 * N / 2 + warp_id * 16 * N / 2 + (group_id + 8) * N / 2 + lane_in_group; 140 | C_ptr[offset1] = c[m2][0]; 141 | C_ptr[offset2] = c[m2][1]; 142 | C_ptr[offset1 + 4] = c[m2][2]; 143 | C_ptr[offset2 + 4] = c[m2][3]; 144 | C_ptr[offset1 + 4 + 4] = c[m2][4]; 145 | C_ptr[offset2 + 4 + 4] = c[m2][5]; 146 | C_ptr[offset1 + 4 + 4 + 4] = c[m2][6]; 147 | C_ptr[offset2 + 4 + 4 + 4] = c[m2][7]; 148 | } 149 | } 150 | 151 | int main() { 152 | 153 | half *d_C; 154 | half h_C[M * N]; 155 | half h_CPU[M * N]; 156 | half h_A[M * K]; 157 | half h_A2[M * K_A]; 158 | half h_B[K * N]; 159 | 160 | fill_24(h_A, M, K); 161 | fill_random(h_B, K, N); 162 | 163 | // extract the non-zeros in each 2:4 tile to a compressed matrix A2 164 | compress24(h_A, h_A2, M, K); 165 | 166 | // print_matrix(h_A2, M, K_A); 167 | 168 | half *d_A, *d_B; 169 | 170 | cudaMalloc((void **)&d_A, M * K_A * sizeof(half)); 171 | cudaMalloc((void **)&d_B, K * N * sizeof(half)); 172 | cudaMalloc((void **)&d_C, M * N * sizeof(half)); 173 | 174 | cudaMemcpy(d_A, h_A2, M * K_A * sizeof(half), cudaMemcpyHostToDevice); 175 | cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice); 176 | 177 | int metadata_size = (M / 16) * (K / 16) * 8; 178 | 179 | u_int32_t *metadata_array = new u_int32_t[metadata_size]; 180 | inspect_metadata(h_A, metadata_array, M, K); 181 | 182 | u_int32_t *d_metadata; 183 | cudaMalloc((void **)&d_metadata, metadata_size * sizeof(u_int32_t)); 184 | cudaMemcpy(d_metadata, metadata_array, metadata_size * sizeof(u_int32_t), 185 | cudaMemcpyHostToDevice); 186 | 187 | CUtensorMap tensor_map_a = create_2d_tensor_map(M, K_A, M, K_A, d_A); 188 | CUtensorMap tensor_map_b = create_2d_tensor_map(K, N, K, N, d_B); 189 | 190 | kernel<<>>(tensor_map_a, tensor_map_b, d_C, d_metadata); 191 | 192 | cuda_check_error(); 193 | 194 | cudaMemcpy(h_C, d_C, M * N * sizeof(half), cudaMemcpyDeviceToHost); 195 | 196 | // print_matrix<5>(h_A2, M, K_A); 197 | 198 | CPU_gemm(h_A, h_B, h_CPU, M, N, K); 199 | 200 | compare_matrices(h_CPU, h_C, M, N); 201 | 202 | // print_differnce(h_CPU, h_C, M, N, 0); 203 | 204 | return 0; 205 | } 206 | -------------------------------------------------------------------------------- /headers/host/matrix_utilities.cuh: -------------------------------------------------------------------------------- 1 | // helper functions for matrix 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #ifndef CUDA_CHECK 11 | #define CUDA_CHECK(callstr) \ 12 | { \ 13 | cudaError_t error_code = callstr; \ 14 | if (error_code != cudaSuccess) \ 15 | { \ 16 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ 17 | assert(0); \ 18 | } \ 19 | } 20 | #endif 21 | 22 | half rand_half() 23 | { 24 | return 
__float2half(5.0f * rand() / RAND_MAX); 25 | } 26 | 27 | int rand_int(int max) 28 | { 29 | return rand() % max; 30 | } 31 | 32 | template 33 | void print_matrix(half *matrix, int rows, int cols) 34 | { 35 | const std::string str = "%." + std::to_string(decimal) + "f "; 36 | for (int i = 0; i < rows; i++) 37 | { 38 | for (int j = 0; j < cols; j++) 39 | { 40 | printf(str.c_str(), __half2float(matrix[i * cols + j])); 41 | } 42 | printf("\n"); 43 | } 44 | printf("\n"); 45 | } 46 | 47 | void print_matrix(int *matrix, int rows, int cols) 48 | { 49 | for (int i = 0; i < rows; i++) 50 | { 51 | for (int j = 0; j < cols; j++) 52 | { 53 | printf("%d ", matrix[i * cols + j]); 54 | } 55 | printf("\n"); 56 | } 57 | printf("\n"); 58 | } 59 | 60 | void fill_random(half *matrix, int rows, int cols) 61 | { 62 | for (int i = 0; i < rows; i++) 63 | { 64 | for (int j = 0; j < cols; j++) 65 | { 66 | float value = 0.1f * rand() / RAND_MAX; 67 | matrix[i * cols + j] = __float2half(value); 68 | } 69 | } 70 | } 71 | 72 | void fill_fixed(half *matrix, int rows, int cols, float value) 73 | { 74 | for (int i = 0; i < rows; i++) 75 | { 76 | for (int j = 0; j < cols; j++) 77 | { 78 | matrix[i * cols + j] = __float2half(value); 79 | } 80 | } 81 | } 82 | 83 | void fill_tile(half *matrix, int rows, int cols) 84 | { 85 | for (int i = 0; i < rows; i++) 86 | { 87 | for (int j = 0; j < cols; j++) 88 | { 89 | if (i / 8 == 0 && j / 8 == 0) 90 | { 91 | matrix[i * cols + j] = __float2half(1.0f); 92 | } 93 | else 94 | { 95 | matrix[i * cols + j] = __float2half(0.0f); 96 | } 97 | } 98 | } 99 | } 100 | 101 | void fill_rowwise(int *matrix, int rows, int cols) { 102 | for (int r = 0; r < rows; r++) { 103 | for (int c = 0; c < cols; c++) { 104 | matrix[r * cols + c] = r * rows + c; 105 | } 106 | } 107 | } 108 | 109 | void transpose(half *matrix, int rows, int cols) { 110 | // Create a temporary matrix to store the result 111 | half* temp = new half[rows * cols]; 112 | 113 | // Transpose the matrix 114 | for (int i = 0; i < rows; ++i) { 115 | for (int j = 0; j < cols; ++j) { 116 | // Swap element at (i, j) to (j, i) 117 | temp[j * rows + i] = matrix[i * cols + j]; 118 | } 119 | } 120 | 121 | // Copy the transposed matrix back to the original matrix 122 | for (int i = 0; i < rows * cols; ++i) { 123 | matrix[i] = temp[i]; 124 | } 125 | 126 | // Free the temporary matrix 127 | delete[] temp; 128 | } 129 | 130 | 131 | // element in each subtile has the same value, 132 | // which is their tile number in row major order 133 | void fill_tilewise(int *matrix, int rows, int cols, int tile_size_row, int tile_size_col) 134 | { 135 | for (int i = 0; i < rows; i++) 136 | { 137 | for (int j = 0; j < cols; j++) 138 | { 139 | int id = (i / tile_size_row) * (cols / tile_size_col) + j / tile_size_col; 140 | matrix[i * cols + j] = id % 10; 141 | } 142 | } 143 | } 144 | 145 | void fill_tilewise(half *matrix, int rows, int cols, int tile_size_row, int tile_size_col) 146 | { 147 | for (int i = 0; i < rows; i++) 148 | { 149 | for (int j = 0; j < cols; j++) 150 | { 151 | int id = (i / tile_size_row) * (cols / tile_size_col) + j / tile_size_col; 152 | matrix[i * cols + j] = __float2half(id % 10); 153 | } 154 | } 155 | } 156 | 157 | void CPU_gemm(half *A, half *B, half *C, int M, int N, int K) 158 | { 159 | for (int i = 0; i < M; i++) 160 | { 161 | for (int j = 0; j < N; j++) 162 | { 163 | C[i * N + j] = 0; 164 | for (int k = 0; k < K; k++) 165 | { 166 | float a = __half2float(A[i * K + k]); 167 | float b = __half2float(B[k * N + j]); 168 | float c = 
169 |                 float new_c = a * b + c;
170 |                 C[i * N + j] = __float2half(new_c);
171 |             }
172 |         }
173 |     }
174 | }
175 | 
176 | void compare_matrices(half *A, half *B, int rows, int cols)
177 | {
178 |     float total_diff = 0.0;
179 |     int total_elements = rows * cols;
180 | 
181 |     for (int i = 0; i < rows; i++)
182 |     {
183 |         for (int j = 0; j < cols; j++)
184 |         {
185 |             float a = __half2float(A[i * cols + j]);
186 |             float b = __half2float(B[i * cols + j]);
187 |             total_diff += fabs((a - b) / a);
188 |         }
189 |     }
190 | 
191 |     float percentage_diff = (total_diff / total_elements) * 100;
192 |     printf("Average relative error: %.2f%%\n", percentage_diff);
193 | }
194 | 
195 | void print_differnce(half *A, half *B, int rows, int cols, float tolerance)
196 | {
197 |     for (int i = 0; i < rows; i++)
198 |     {
199 |         for (int j = 0; j < cols; j++)
200 |         {
201 |             float a = __half2float(A[i * cols + j]);
202 |             float b = __half2float(B[i * cols + j]);
203 |             bool is_same = fabs(a - b) <= tolerance;
204 |             if (!is_same)
205 |             {
206 |                 printf("Error at (%d, %d) : %f != %f\n", i, j, a, b);
207 |             }
208 |         }
209 |     }
210 | }
211 | // pack each 1x4 group (at most 2 non-zeros) into the 2 kept values of a 2:4-compressed matrix
212 | void compress24(half *dense, half *sparse, int rows, int cols)
213 | {
214 |     assert(rows * cols % 4 == 0);
215 | 
216 |     memset(sparse, 0, rows * cols / 2 * sizeof(half));
217 | 
218 |     int counter;
219 | 
220 |     for (int i = 0; i < rows * cols; i += 4)
221 |     {
222 |         int sparse_offset = i / 2;
223 | 
224 |         counter = 0;
225 | 
226 |         for (int j = 0; j < 4; j++)
227 |         {
228 |             float value = __half2float(dense[i + j]);
229 |             if (value != 0)
230 |             {
231 |                 assert(counter < 2);
232 |                 sparse[sparse_offset + counter] = dense[i + j];
233 |                 counter++;
234 |             }
235 |         }
236 |     }
237 | }
238 | // fill a matrix with a random 2:4 pattern: at most 2 non-zeros in every group of 4
239 | void fill_24(half *matrix, int rows, int cols)
240 | {
241 |     assert(rows * cols % 4 == 0);
242 | 
243 |     for (int i = 0; i < rows * cols; i += 4)
244 |     {
245 |         matrix[i] = 0.0;
246 |         matrix[i + 1] = 0.0;
247 |         matrix[i + 2] = 0.0;
248 |         matrix[i + 3] = 0.0;
249 | 
250 |         int position1 = rand() % 4;
251 |         int position2 = rand() % 4;
252 | 
253 |         // position2 = position2 == position1 ? (position2 + 1) % 4 : position2;
254 | 
255 |         // matrix[i + position1] = __float2half(1.0f);
256 |         // matrix[i + position2] = __float2half(1.0f);
257 | 
258 |         matrix[i + position1] = rand_half();
259 |         matrix[i + position2] = rand_half();
260 |     }
261 | }
262 | // build 2:4 metadata: per 16x16 tile, pack the 2-bit column indices of the kept values into 32-bit words (rows r and r+8 share one word)
263 | __host__ int inspect_metadata(half *mat, u_int32_t *meta, int M, int K)
264 | {
265 |     std::map<std::string, unsigned int> metaMap;
266 | 
267 |     metaMap["1100"] = 0x4;
268 |     metaMap["1010"] = 0x8;
269 |     metaMap["1001"] = 0xC;
270 |     metaMap["0110"] = 0x9;
271 |     metaMap["0101"] = 0xD;
272 |     metaMap["0011"] = 0xE;
273 | 
274 |     metaMap["1000"] = 0x0;
275 |     metaMap["0100"] = 0x1;
276 |     metaMap["0010"] = 0x2;
277 |     metaMap["0001"] = 0x3;
278 | 
279 |     metaMap["0000"] = 0xF;
280 | 
281 |     const int total_size = (M / 16) * (K / 16);
282 | 
283 |     int zero_tile = 0;
284 | 
285 |     for (int m = 0; m < M / 16; m++)
286 |     {
287 |         for (int k = 0; k < K / 16; k++)
288 |         {
289 |             for (int m2 = 0; m2 < 8; m2++)
290 |             {
291 |                 unsigned int metadata = 0;
292 |                 for (int k2 = 0; k2 < 4; k2++)
293 |                 {
294 |                     std::string key = "";
295 |                     int counter = 0;
296 |                     for (int i = 0; i < 4; i++)
297 |                     {
298 |                         int index = (m * 16 + m2) * K + k * 16 + k2 * 4 + i;
299 |                         float value = __half2float(mat[index]);
300 | 
301 |                         if (value != 0.0f)
302 |                         {
303 |                             key += "1";
304 |                             counter++;
305 |                         }
306 |                         else
307 |                         {
308 |                             key += "0";
309 |                         }
310 |                     }
311 | 
312 |                     metadata |= metaMap[key] << (k2 * 4);
313 | 
314 |                     if (counter == 0)
315 |                     {
316 |                         zero_tile++;
317 |                     }
318 |                 }
319 |                 for (int k2 = 0; k2 < 4; k2++)
320 |                 {
321 |                     std::string key = "";
322 |                     int counter = 0;
323 |                     for (int i = 0; i < 4; i++)
324 |                     {
325 |                         int index = (m * 16 + m2 + 8) * K + k * 16 + k2 * 4 + i;
326 |                         float value = __half2float(mat[index]);
327 | 
328 |                         if (value != 0.0f)
329 |                         {
330 |                             key += "1";
331 |                             counter++;
332 |                         }
333 |                         else
334 |                         {
335 |                             key += "0";
336 |                         }
337 |                     }
338 | 
339 |                     metadata |= metaMap[key] << (k2 * 4 + 16);
340 | 
341 |                     if (counter == 0)
342 |                     {
343 |                         zero_tile++;
344 |                     }
345 |                 }
346 |                 int blockId = m * K / 16 + k;
347 | 
348 |                 meta[blockId * 8 + m2] = metadata;
349 |             }
350 |         }
351 |     }
352 | 
353 |     printf("zero tile: %d\n", zero_tile);
354 | 
355 |     double percentage = (double)zero_tile / (double)(M * K / 4) * 100.0;
356 | 
357 |     printf("zero tile percentage: %lf\n", percentage);
358 | 
359 |     return total_size * 8;
360 | }
361 | 
--------------------------------------------------------------------------------
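A note on the 2:4 encoding used above: compress24 keeps the (at most) two non-zeros of every 1x4 group, and inspect_metadata records, per group, a 4-bit nibble holding the two 2-bit column indices of the kept values; the rows r and r+8 of each 16x16 tile share one 32-bit word (lower and upper 16 bits). The standalone host-only sketch below is not part of the repository; its row values are made up purely for illustration. It replays that packing and nibble construction for a single 1x16 row group so the result can be checked by hand against the metaMap table.

// Sanity sketch (illustrative, not from the repo): reproduce the 2:4 packing of
// compress24 and the per-group metadata nibbles of inspect_metadata for one row group.
#include <cstdio>
#include <cstdint>

int main()
{
    // One 1x16 row group of A, already in 2:4 form (exactly 2 non-zeros per group of 4 here;
    // note that inspect_metadata maps an all-zero group to 0xF, which this sketch does not handle).
    float row[16] = {0, 3, 0, 7, 1, 2, 0, 0, 0, 0, 4, 5, 6, 0, 0, 8};

    float packed[8] = {0}; // compressed values, 2 kept per group of 4
    uint16_t meta = 0;     // four 4-bit nibbles, one per group (half of a 32-bit metadata word)

    for (int g = 0; g < 4; g++)
    {
        int kept = 0;
        unsigned nibble = 0;
        for (int i = 0; i < 4; i++)
        {
            if (row[g * 4 + i] != 0.0f && kept < 2)
            {
                packed[g * 2 + kept] = row[g * 4 + i]; // what compress24 would keep
                nibble |= (unsigned)i << (2 * kept);   // 2-bit column index of the kept value
                kept++;
            }
        }
        meta |= nibble << (4 * g); // group g occupies bits [4g+3 : 4g]
    }

    printf("packed:");
    for (int i = 0; i < 8; i++)
        printf(" %.0f", packed[i]);
    printf("\nmetadata: 0x%04x\n", (unsigned)meta);
    return 0;
}

Built with any host C++ compiler, this prints packed: 3 7 1 2 4 5 6 8 and metadata: 0xce4d, which matches the table in inspect_metadata nibble by nibble: "0101" -> 0xD, "1100" -> 0x4, "0011" -> 0xE, "1001" -> 0xC.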