├── Makefile
├── README.md
├── swizzle.cu
├── transpose_naive.cu
├── transpose_swizzle.cu
├── transpose_swizzle_batched.cu
├── transpose_swizzle_batched_for_profile.cu
└── utils.h

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
NVCC_FLAGS = -std=c++17 -O3 -DNDEBUG -w
NVCC_LDFLAGS = -lcublas -lcuda
OUT_DIR = out
PROFILE_DIR = profile

CUDA_OUTPUT_FILE = -o $(OUT_DIR)/$@
NCU_PATH := $(shell which ncu)
NCU_COMMAND = $(NCU_PATH) --set full --import-source yes

NVCC_FLAGS += --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -Xcompiler=-fPIE -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing
NVCC_FLAGS += -arch=sm_90a

NVCC_BASE = nvcc $(NVCC_FLAGS) $(NVCC_LDFLAGS) -lineinfo

swizzle: swizzle.cu
	$(NVCC_BASE) $^ $(CUDA_OUTPUT_FILE)

transpose_naive: transpose_naive.cu
	$(NVCC_BASE) $^ $(CUDA_OUTPUT_FILE)

transpose_swizzle: transpose_swizzle.cu
	$(NVCC_BASE) $^ $(CUDA_OUTPUT_FILE)

transpose_swizzle_batched: transpose_swizzle_batched.cu
	$(NVCC_BASE) $^ $(CUDA_OUTPUT_FILE)

transpose_swizzle_batched_for_profile: transpose_swizzle_batched_for_profile.cu
	$(NVCC_BASE) $^ $(CUDA_OUTPUT_FILE)

transpose_swizzle_batched_profile: transpose_swizzle_batched_for_profile
	$(NCU_COMMAND) -o $(PROFILE_DIR)/$@ -f $(OUT_DIR)/$^

compile_all:
	make swizzle
	make transpose_naive
	make transpose_swizzle
	make transpose_swizzle_batched

run_all: compile_all
	./$(OUT_DIR)/swizzle
	./$(OUT_DIR)/transpose_naive
	./$(OUT_DIR)/transpose_swizzle
	./$(OUT_DIR)/transpose_swizzle_batched

clean:
	rm $(OUT_DIR)/*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Effective matrix transpose

Improve matrix transpose on Hopper GPUs step by step, using TMA bulk tensor copies and shared memory swizzling.

Please see [my blogpost](http://veitner.bearblog.dev/making-matrix-transpose-really-fast-on-hopper-gpus/) for a detailed explanation.

## Performance Comparison

| Kernel | Bandwidth (GB/s) | % of Max Bandwidth | Implementation |
|--------|------------------|--------------------|----------------|
| transpose_naive | 875.46 | 26.5291% | Custom |
| transpose_swizzle | 1251.76 | 37.9323% | Custom |
| transpose_swizzle_batched | 2771.35 | 83.9804% | Custom |
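
## Building and running

The Makefile builds every kernel with `nvcc -arch=sm_90a` and places the binaries in `out/` (create `out/` and `profile/` first if they do not exist). `make compile_all` builds all kernels, `make run_all` executes them, and `make transpose_swizzle_batched_profile` captures an `ncu` profile of the batched kernel.

## How the swizzle works

The swizzled kernels use the TMA `CU_TENSOR_MAP_SWIZZLE_128B` mode and mirror its index math in `calculate_col_swizzle` / `calculate_row_swizzle`. The sketch below is not part of the build and the name `col_swizzle` is only illustrative; it reproduces the same arithmetic on the host for `BLOCK_SIZE = 32` so the pattern can be inspected without a GPU. Each printed value is the swizzled column of a 4-byte element, which is also the shared memory bank it lands in, so every row covers all 32 banks.

```cpp
#include <cstdio>

constexpr int BLOCK_SIZE = 32;

// Same arithmetic as calculate_col_swizzle in swizzle.cu: XOR the index of the
// 16-byte chunk within a 128-byte line with the index of that line.
int col_swizzle(int row, int col) {
  int i16 = (row * BLOCK_SIZE + col) * 4 >> 4;  // 16-byte chunk index
  int y16 = i16 >> 3;                           // 128-byte line index
  int x16 = i16 & 7;                            // chunk within the line
  int x16_swz = y16 ^ x16;                      // swizzled chunk index
  return ((x16_swz * 4) & (BLOCK_SIZE - 1)) + (col & 3);
}

int main() {
  for (int row = 0; row < 8; ++row) {
    for (int col = 0; col < BLOCK_SIZE; ++col) printf("%2d ", col_swizzle(row, col));
    printf("\n");
  }
}
```

`swizzle.cu` visualizes the same pattern on the device by writing through the swizzled index and printing the result as a heatmap.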

--------------------------------------------------------------------------------
/swizzle.cu:
--------------------------------------------------------------------------------
#include <cuda.h>  // CUtensormap
#include <cuda/barrier>

#include <cassert>
#include <cstdlib>
#include <iostream>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

#include "utils.h"

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const *func, char const *file, int line) {
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(char const *file, int line) {
  cudaError_t const err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

template <int BLOCK_SIZE>
__device__ int calculate_col_swizzle(int row, int col) {
  int i16 = (row * BLOCK_SIZE + col) * 4 >> 4;
  int y16 = i16 >> 3;
  int x16 = i16 & 7;
  int x16_swz = y16 ^ x16;
  return ((x16_swz * 4) & (BLOCK_SIZE - 1)) + (col & 3);
}

template <int BLOCK_SIZE>
__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map) {
  // The destination shared memory buffer of a bulk tensor operation should be
  // 128 byte aligned.
  __shared__ alignas(1024) int smem_buffer[BLOCK_SIZE * BLOCK_SIZE];

  // Coordinates for upper left tile in GMEM.
  int x = blockIdx.x * BLOCK_SIZE;
  int y = blockIdx.y * BLOCK_SIZE;

  int col = threadIdx.x % BLOCK_SIZE;
  int row = threadIdx.x / BLOCK_SIZE;

  int col_swizzle = calculate_col_swizzle<BLOCK_SIZE>(row, col);

  // Initialize shared memory barrier with the number of threads participating in
  // the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    // Initialize barrier. All `blockDim.x` threads in block participate.
    init(&bar, blockDim.x);
    // Make initialized barrier visible in async proxy.
    cde::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // Initiate bulk tensor copy.
    cde::cp_async_bulk_tensor_2d_global_to_shared(&smem_buffer, &tensor_map, x,
                                                  y, bar);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
  } else {
    // Other threads just arrive.
    token = bar.arrive();
  }
  // Wait for the data to have arrived.
  bar.wait(std::move(token));

  // Symbolically modify a value in shared memory.
  smem_buffer[row * BLOCK_SIZE + col_swizzle] = (row * BLOCK_SIZE + col) % 32;
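  // With BLOCK_SIZE = 32 the stored value (row * BLOCK_SIZE + col) % 32 is simply
  // col, so after the tile is copied back to global memory each output row shows
  // which original column ended up in each swizzled slot: a permutation of 0..31
  // per row. Since the bank of a 4-byte element at word index row * 32 + c is c,
  // the heatmap printed by main() directly visualizes the bank assignment.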

  // Wait for shared memory writes to be visible to TMA engine.
  cde::fence_proxy_async_shared_cta();
  __syncthreads();
  // After syncthreads, writes by all threads are visible to TMA engine.

  // Initiate TMA transfer to copy shared memory to global memory
  if (threadIdx.x == 0) {
    cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map, x, y,
                                                  &smem_buffer);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cde::cp_async_bulk_commit_group();
    // Wait for the group to have completed reading from shared memory.
    cde::cp_async_bulk_wait_group_read<0>();
  }

  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (threadIdx.x == 0) {
    (&bar)->~barrier();
  }
}

int main() {
  const int GMEM_WIDTH = 32;
  const int GMEM_HEIGHT = 32;
  const int BLOCK_SIZE = 32;
  const int SMEM_WIDTH = BLOCK_SIZE;
  const int SMEM_HEIGHT = BLOCK_SIZE;
  const size_t SIZE = GMEM_HEIGHT * GMEM_WIDTH * sizeof(int);
  std::cout << "Shared memory tile size (bytes): "
            << BLOCK_SIZE * BLOCK_SIZE * sizeof(int) << std::endl;

  int *h_in = new int[GMEM_HEIGHT * GMEM_WIDTH];
  int *h_out = new int[GMEM_HEIGHT * GMEM_WIDTH];

  srand(42);
  for (int i = 0; i < GMEM_HEIGHT * GMEM_WIDTH; ++i) {
    h_in[i] = rand() % 9;
  }

  // std::cout << "Initial matrix:" << std::endl;
  // utils::printMatrix(h_in, GMEM_HEIGHT, GMEM_WIDTH);
  // std::cout << std::endl;

  int *d;
  CHECK_CUDA_ERROR(cudaMalloc(&d, SIZE));
  CHECK_CUDA_ERROR(cudaMemcpy(d, h_in, SIZE, cudaMemcpyHostToDevice));
  void *tensor_ptr = (void *)d;

  CUtensorMap tensor_map{};
  // rank is the number of dimensions of the array.
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
  // The stride is the number of bytes to traverse from the first element of one
  // row to the next. It must be a multiple of 16.
  uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(int)};
  // The box_size is the size of the shared memory buffer that is used as the
  // destination of a TMA transfer.
  uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
  // The distance between elements in units of sizeof(element). A stride of 2
  // can be used to load only the real component of a complex-valued tensor, for
  // instance.
  uint32_t elem_stride[rank] = {1, 1};

  // Create the tensor descriptor.
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT32,
      rank,         // cuuint32_t tensorRank,
      tensor_ptr,   // void *globalAddress,
      size,         // const cuuint64_t *globalDim,
      stride,       // const cuuint64_t *globalStrides,
      box_size,     // const cuuint32_t *boxDim,
      elem_stride,  // const cuuint32_t *elementStrides,
      // Interleave patterns can be used to accelerate loading of values that
      // are less than 4 bytes long.
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      // Swizzling can be used to avoid shared memory bank conflicts.
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      // L2 Promotion can be used to widen the effect of a cache-policy to a
      // wider set of L2 cache lines.
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      // Any element that is outside of bounds will be set to zero by the TMA
      // transfer.
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res == CUDA_SUCCESS);

  dim3 blockDim(SMEM_WIDTH * SMEM_HEIGHT, 1, 1);
  dim3 gridDim(GMEM_WIDTH / SMEM_WIDTH, GMEM_HEIGHT / SMEM_HEIGHT, 1);

  kernel<BLOCK_SIZE><<<gridDim, blockDim>>>(tensor_map);

  CHECK_LAST_CUDA_ERROR();
  CHECK_CUDA_ERROR(cudaMemcpy(h_out, d, SIZE, cudaMemcpyDeviceToHost));

  std::cout << "Visualize Bank assignment:" << std::endl;
  // utils::printMatrix(h_out, GMEM_HEIGHT, GMEM_WIDTH);
  utils::printMatrixHeatmap32(h_out, GMEM_HEIGHT, GMEM_WIDTH);
  std::cout << std::endl;

  CHECK_CUDA_ERROR(cudaFree(d));
  delete[] h_in;
  delete[] h_out;
}

--------------------------------------------------------------------------------
/transpose_naive.cu:
--------------------------------------------------------------------------------
#include <cuda.h>  // CUtensormap
#include <cuda/barrier>

#include <cassert>
#include <cmath>
#include <iostream>
#include <random>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

#include "utils.h"

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const *func, char const *file, int line) {
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(char const *file, int line) {
  cudaError_t const err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

template <int BLOCK_SIZE>
__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map,
                       const __grid_constant__ CUtensorMap tensor_map_tr) {
  // The destination shared memory buffer of a bulk tensor operation should be
  // 128 byte aligned.
  __shared__ alignas(1024) float smem_buffer[BLOCK_SIZE * BLOCK_SIZE];
  __shared__ alignas(1024) float smem_buffer_tr[BLOCK_SIZE * BLOCK_SIZE];
  // Coordinates for upper left tile in GMEM.
  int x = blockIdx.x * BLOCK_SIZE;
  int y = blockIdx.y * BLOCK_SIZE;

  int col = threadIdx.x % BLOCK_SIZE;
  int row = threadIdx.x / BLOCK_SIZE;

  // Initialize shared memory barrier with the number of threads participating in
  // the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    // Initialize barrier. All `blockDim.x` threads in block participate.
    init(&bar, blockDim.x);
    // Make initialized barrier visible in async proxy.
    cde::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // Initiate bulk tensor copy.
    cde::cp_async_bulk_tensor_2d_global_to_shared(&smem_buffer, &tensor_map, x,
                                                  y, bar);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
  } else {
    // Other threads just arrive.
    token = bar.arrive();
  }
  // Wait for the data to have arrived.
  bar.wait(std::move(token));

  // Transpose tile.
  smem_buffer_tr[col * BLOCK_SIZE + row] = smem_buffer[row * BLOCK_SIZE + col];
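  // Note: within a warp, col runs over 0..31 while row is fixed, so the read
  // from smem_buffer touches 32 consecutive 4-byte words (conflict-free), but
  // the write to smem_buffer_tr strides by BLOCK_SIZE * 4 = 128 bytes and all
  // 32 lanes hit the same bank: a 32-way shared memory bank conflict, which is
  // what the swizzled variants address.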

  // Wait for shared memory writes to be visible to TMA engine.
  cde::fence_proxy_async_shared_cta();
  __syncthreads();
  // After syncthreads, writes by all threads are visible to TMA engine.

  // Initiate TMA transfer to copy shared memory to global memory
  if (threadIdx.x == 0) {
    // Transpose tile inside matrix
    cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map_tr, y, x,
                                                  &smem_buffer_tr);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cde::cp_async_bulk_commit_group();
    // Wait for the group to have completed reading from shared memory.
    cde::cp_async_bulk_wait_group_read<0>();
  }

  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (threadIdx.x == 0) {
    (&bar)->~barrier();
  }
}

int main() {
  const int GMEM_WIDTH = 32768;
  const int GMEM_HEIGHT = 32768;
  const int BLOCK_SIZE = 32;
  const int SMEM_WIDTH = BLOCK_SIZE;
  const int SMEM_HEIGHT = BLOCK_SIZE;
  const size_t SIZE = GMEM_HEIGHT * GMEM_WIDTH * sizeof(float);

  float *h_in = new float[GMEM_HEIGHT * GMEM_WIDTH];
  float *h_out = new float[GMEM_HEIGHT * GMEM_WIDTH];

  // Initialize with normal distribution
  std::default_random_engine generator(42);
  std::normal_distribution<float> distribution(0.0, 1.0);

  for (int i = 0; i < GMEM_HEIGHT * GMEM_WIDTH; ++i) {
    h_in[i] = distribution(generator);
  }

  float *d;
  float *d_tr;
  CHECK_CUDA_ERROR(cudaMalloc(&d, SIZE));
  CHECK_CUDA_ERROR(cudaMemcpy(d, h_in, SIZE, cudaMemcpyHostToDevice));
  void *tensor_ptr = (void *)d;
  CHECK_CUDA_ERROR(cudaMalloc(&d_tr, SIZE));
  void *tensor_ptr_tr = (void *)d_tr;

  CUtensorMap tensor_map{};
  CUtensorMap tensor_map_tr{};
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
  uint64_t size_tr[rank] = {GMEM_HEIGHT, GMEM_WIDTH};
  uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(float)};
  uint64_t stride_tr[rank - 1] = {GMEM_HEIGHT * sizeof(float)};
  uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
  uint32_t box_size_tr[rank] = {SMEM_HEIGHT, SMEM_WIDTH};
  uint32_t elem_stride[rank] = {1, 1};

  // Create the tensor descriptor.
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,         // cuuint32_t tensorRank,
      tensor_ptr,   // void *globalAddress,
      size,         // const cuuint64_t *globalDim,
      stride,       // const cuuint64_t *globalStrides,
      box_size,     // const cuuint32_t *boxDim,
      elem_stride,  // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res == CUDA_SUCCESS);

  CUresult res_tr = cuTensorMapEncodeTiled(
      &tensor_map_tr,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,            // cuuint32_t tensorRank,
      tensor_ptr_tr,   // void *globalAddress,
      size_tr,         // const cuuint64_t *globalDim,
      stride_tr,       // const cuuint64_t *globalStrides,
      box_size_tr,     // const cuuint32_t *boxDim,
      elem_stride,     // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res_tr == CUDA_SUCCESS);

  dim3 blockDim(SMEM_WIDTH * SMEM_HEIGHT, 1, 1);
  dim3 gridDim(GMEM_WIDTH / SMEM_WIDTH, GMEM_HEIGHT / SMEM_HEIGHT, 1);

  kernel<BLOCK_SIZE><<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);

  CHECK_LAST_CUDA_ERROR();
  CHECK_CUDA_ERROR(cudaMemcpy(h_out, d_tr, SIZE, cudaMemcpyDeviceToHost));

  const float epsilon = 1e-5f;
  for (int x = 0; x < GMEM_HEIGHT; x++) {
    for (int y = 0; y < GMEM_WIDTH; y++) {
      float expected = h_in[x * GMEM_WIDTH + y];
      float actual = h_out[y * GMEM_HEIGHT + x];
      if (std::fabs(expected - actual) > epsilon) {
        std::cout << "Error at position (" << x << "," << y << "): expected "
                  << expected << " but got " << actual << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Passed" << std::endl;
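  // Benchmark: each round streams the matrix once from and once to global
  // memory, i.e. 2 * SIZE bytes. cudaEventElapsedTime reports milliseconds, so
  // dividing bytes per millisecond by 1e6 yields GB/s. The 3300 GB/s used as
  // the "% of max" reference appears to correspond to the peak HBM3 bandwidth
  // of an H100 SXM (roughly 3.35 TB/s).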
  int numRounds = 10000;
  size_t numCrossMemoryBound = 2 * SIZE;
  cudaEvent_t start, stop;
  float time;

  CHECK_CUDA_ERROR(cudaEventCreate(&start));
  CHECK_CUDA_ERROR(cudaEventCreate(&stop));

  CHECK_CUDA_ERROR(cudaEventRecord(start));
  for (int i = 0; i < numRounds; i++) {
    kernel<BLOCK_SIZE><<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);
  }
  CHECK_CUDA_ERROR(cudaEventRecord(stop));
  CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
  CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
  CHECK_LAST_CUDA_ERROR();

  float latency = time / numRounds;
  float bandwidth = (numCrossMemoryBound / latency) / 1e6;

  std::cout << "Latency = " << latency << " ms" << std::endl;
  std::cout << "Bandwidth = " << bandwidth << " GB/s" << std::endl;
  std::cout << "% of max = " << bandwidth / 3300 * 100 << " %" << std::endl;

  CHECK_CUDA_ERROR(cudaFree(d));
  CHECK_CUDA_ERROR(cudaFree(d_tr));
  delete[] h_in;
  delete[] h_out;
}

--------------------------------------------------------------------------------
/transpose_swizzle.cu:
--------------------------------------------------------------------------------
#include <cuda.h>  // CUtensormap
#include <cuda/barrier>

#include <cassert>
#include <cmath>
#include <iostream>
#include <random>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

#include "utils.h"

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const *func, char const *file, int line) {
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(char const *file, int line) {
  cudaError_t const err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

template <int BLOCK_SIZE>
__device__ int calculate_col_swizzle(int row, int col) {
  int i16 = (row * BLOCK_SIZE + col) * 4 >> 4;
  int y16 = i16 >> 3;
  int x16 = i16 & 7;
  int x16_swz = y16 ^ x16;
  return ((x16_swz * 4) & (BLOCK_SIZE - 1)) + (col & 3);
}

template <int BLOCK_SIZE>
__device__ int calculate_row_swizzle(int row, int col) {
  int i16_tr = (col * BLOCK_SIZE + row) * 4 >> 4;
  int y16_tr = i16_tr >> 3;
  int x16_tr = i16_tr & 7;
  int x16_swz_tr = y16_tr ^ x16_tr;
  return ((x16_swz_tr * 4) & (BLOCK_SIZE - 1)) + (row & 3);
}

template <int BLOCK_SIZE, int LOG_BLOCK>
__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map,
                       const __grid_constant__ CUtensorMap tensor_map_tr) {
  // The destination shared memory buffer of a bulk tensor operation should be
  // 128 byte aligned.
  __shared__ alignas(1024) float smem_buffer[BLOCK_SIZE * BLOCK_SIZE];
  __shared__ alignas(1024) float smem_buffer_tr[BLOCK_SIZE * BLOCK_SIZE];
  // Coordinates for upper left tile in GMEM.
  int x = blockIdx.x * BLOCK_SIZE;
  int y = blockIdx.y * BLOCK_SIZE;

  int col = threadIdx.x & (BLOCK_SIZE - 1);
  int row = threadIdx.x >> LOG_BLOCK;

  int col_swizzle = calculate_col_swizzle<BLOCK_SIZE>(row, col);

  int row_swizzle = calculate_row_swizzle<BLOCK_SIZE>(row, col);

  // Initialize shared memory barrier with the number of threads participating in
  // the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    // Initialize barrier. All `blockDim.x` threads in block participate.
    init(&bar, blockDim.x);
    // Make initialized barrier visible in async proxy.
    cde::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // Initiate bulk tensor copy.
    cde::cp_async_bulk_tensor_2d_global_to_shared(&smem_buffer, &tensor_map, x,
                                                  y, bar);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
  } else {
    // Other threads just arrive.
    token = bar.arrive();
  }
  // Wait for the data to have arrived.
  bar.wait(std::move(token));

  // Transpose tile.
  smem_buffer_tr[col * BLOCK_SIZE + row_swizzle] =
      smem_buffer[row * BLOCK_SIZE + col_swizzle];
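  // Both indices go through the same 128-byte swizzle that the TMA descriptors
  // request (CU_TENSOR_MAP_SWIZZLE_128B): col_swizzle locates element (row, col)
  // inside the swizzled layout the inbound copy produced, and row_swizzle stores
  // the transposed element where the outbound copy expects it, so the tile
  // written back to d_tr ends up as a plain row-major transpose.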

  // Wait for shared memory writes to be visible to TMA engine.
  cde::fence_proxy_async_shared_cta();
  __syncthreads();
  // After syncthreads, writes by all threads are visible to TMA engine.

  // Initiate TMA transfer to copy shared memory to global memory
  if (threadIdx.x == 0) {
    // Transpose tile inside matrix
    cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map_tr, y, x,
                                                  &smem_buffer_tr);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cde::cp_async_bulk_commit_group();
    // Wait for the group to have completed reading from shared memory.
    cde::cp_async_bulk_wait_group_read<0>();
  }

  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (threadIdx.x == 0) {
    (&bar)->~barrier();
  }
}

int main() {
  const int GMEM_WIDTH = 32768;
  const int GMEM_HEIGHT = 32768;
  const int BLOCK_SIZE = 32;
  const int LOG_BLOCK = 5;
  const int SMEM_WIDTH = BLOCK_SIZE;
  const int SMEM_HEIGHT = BLOCK_SIZE;
  const size_t SIZE = GMEM_HEIGHT * GMEM_WIDTH * sizeof(float);

  float *h_in = new float[GMEM_HEIGHT * GMEM_WIDTH];
  float *h_out = new float[GMEM_HEIGHT * GMEM_WIDTH];

  // Initialize with normal distribution
  std::default_random_engine generator(42);
  std::normal_distribution<float> distribution(0.0, 1.0);

  for (int i = 0; i < GMEM_HEIGHT * GMEM_WIDTH; ++i) {
    h_in[i] = distribution(generator);
  }

  float *d;
  float *d_tr;
  CHECK_CUDA_ERROR(cudaMalloc(&d, SIZE));
  CHECK_CUDA_ERROR(cudaMemcpy(d, h_in, SIZE, cudaMemcpyHostToDevice));
  void *tensor_ptr = (void *)d;
  CHECK_CUDA_ERROR(cudaMalloc(&d_tr, SIZE));
  void *tensor_ptr_tr = (void *)d_tr;

  CUtensorMap tensor_map{};
  CUtensorMap tensor_map_tr{};
  // rank is the number of dimensions of the array.
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
  uint64_t size_tr[rank] = {GMEM_HEIGHT, GMEM_WIDTH};
  // The stride is the number of bytes to traverse from the first element of one
  // row to the next. It must be a multiple of 16.
  uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(float)};
  uint64_t stride_tr[rank - 1] = {GMEM_HEIGHT * sizeof(float)};
  // The box_size is the size of the shared memory buffer that is used as the
  // destination of a TMA transfer.
  uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
  uint32_t box_size_tr[rank] = {SMEM_HEIGHT, SMEM_WIDTH};
  // The distance between elements in units of sizeof(element). A stride of 2
  // can be used to load only the real component of a complex-valued tensor, for
  // instance.
  uint32_t elem_stride[rank] = {1, 1};

  // Create the tensor descriptor.
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,         // cuuint32_t tensorRank,
      tensor_ptr,   // void *globalAddress,
      size,         // const cuuint64_t *globalDim,
      stride,       // const cuuint64_t *globalStrides,
      box_size,     // const cuuint32_t *boxDim,
      elem_stride,  // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res == CUDA_SUCCESS);

  CUresult res_tr = cuTensorMapEncodeTiled(
      &tensor_map_tr,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,            // cuuint32_t tensorRank,
      tensor_ptr_tr,   // void *globalAddress,
      size_tr,         // const cuuint64_t *globalDim,
      stride_tr,       // const cuuint64_t *globalStrides,
      box_size_tr,     // const cuuint32_t *boxDim,
      elem_stride,     // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res_tr == CUDA_SUCCESS);

  dim3 blockDim(SMEM_WIDTH * SMEM_HEIGHT, 1, 1);
  dim3 gridDim(GMEM_WIDTH / SMEM_WIDTH, GMEM_HEIGHT / SMEM_HEIGHT, 1);

  kernel<BLOCK_SIZE, LOG_BLOCK>
      <<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);

  CHECK_LAST_CUDA_ERROR();
  CHECK_CUDA_ERROR(cudaMemcpy(h_out, d_tr, SIZE, cudaMemcpyDeviceToHost));

  const float epsilon = 1e-5f;
  for (int x = 0; x < GMEM_HEIGHT; x++) {
    for (int y = 0; y < GMEM_WIDTH; y++) {
      float expected = h_in[x * GMEM_WIDTH + y];
      float actual = h_out[y * GMEM_HEIGHT + x];
      if (std::fabs(expected - actual) > epsilon) {
        std::cout << "Error at position (" << x << "," << y << "): expected "
                  << expected << " but got " << actual << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Passed" << std::endl;

  int numRounds = 10000;
  size_t numCrossMemoryBound = 2 * SIZE;
  cudaEvent_t start, stop;
  float time;

  CHECK_CUDA_ERROR(cudaEventCreate(&start));
  CHECK_CUDA_ERROR(cudaEventCreate(&stop));

  CHECK_CUDA_ERROR(cudaEventRecord(start));
  for (int i = 0; i < numRounds; i++) {
    kernel<BLOCK_SIZE, LOG_BLOCK>
        <<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);
  }
  CHECK_CUDA_ERROR(cudaEventRecord(stop));
  CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
  CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
  CHECK_LAST_CUDA_ERROR();

  float latency = time / numRounds;
  float bandwidth = (numCrossMemoryBound / latency) / 1e6;

  std::cout << "Latency = " << latency << " ms" << std::endl;
  std::cout << "Bandwidth = " << bandwidth << " GB/s" << std::endl;
  std::cout << "% of max = " << bandwidth / 3300 * 100 << " %" << std::endl;

  CHECK_CUDA_ERROR(cudaFree(d));
  CHECK_CUDA_ERROR(cudaFree(d_tr));
  delete[] h_in;
  delete[] h_out;
}

--------------------------------------------------------------------------------
/transpose_swizzle_batched.cu:
--------------------------------------------------------------------------------
#include <cuda.h>  // CUtensormap
#include <cuda/barrier>

#include <cassert>
#include <cmath>
#include <iostream>
#include <random>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

#include "utils.h"

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const *func, char const *file, int line) {
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(char const *file, int line) {
  cudaError_t const err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

template <int BLOCK_SIZE>
__device__ int calculate_col_swizzle(int row, int col) {
  int i16 = (row * BLOCK_SIZE + col) * 4 >> 4;
  int y16 = i16 >> 3;
  int x16 = i16 & 7;
  int x16_swz = y16 ^ x16;
  return ((x16_swz * 4) & (BLOCK_SIZE - 1)) + (col & 3);
}

template <int BLOCK_SIZE>
__device__ int calculate_row_swizzle(int row, int col) {
  int i16_tr = (col * BLOCK_SIZE + row) * 4 >> 4;
  int y16_tr = i16_tr >> 3;
  int x16_tr = i16_tr & 7;
  int x16_swz_tr = y16_tr ^ x16_tr;
  return ((x16_swz_tr * 4) & (BLOCK_SIZE - 1)) + (row & 3);
}

template <int BLOCK_SIZE, int LOG_BLOCK, int BATCH_SIZE, int LOG_BATCH_SIZE>
__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map,
                       const __grid_constant__ CUtensorMap tensor_map_tr) {
  // The destination shared memory buffer of a bulk tensor operation should be
  // 128 byte aligned.
  __shared__ alignas(1024) float smem_buffer[BLOCK_SIZE * BLOCK_SIZE];
  __shared__ alignas(1024) float smem_buffer_tr[BLOCK_SIZE * BLOCK_SIZE];
  // Coordinates for upper left tile in GMEM.
  int x = blockIdx.x * BLOCK_SIZE;
  int y = blockIdx.y * BLOCK_SIZE;

  int col = (threadIdx.x & (BLOCK_SIZE / BATCH_SIZE - 1)) * BATCH_SIZE;
  int row = threadIdx.x >> (LOG_BLOCK - LOG_BATCH_SIZE);

  // Initialize shared memory barrier with the number of threads participating in
  // the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    // Initialize barrier. All `blockDim.x` threads in block participate.
    init(&bar, blockDim.x);
    // Make initialized barrier visible in async proxy.
    cde::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // Initiate bulk tensor copy.
    cde::cp_async_bulk_tensor_2d_global_to_shared(&smem_buffer, &tensor_map, x,
                                                  y, bar);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
  } else {
    // Other threads just arrive.
    token = bar.arrive();
  }
  // Wait for the data to have arrived.
  bar.wait(std::move(token));

  // Transpose tile.
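  // Each thread now handles BATCH_SIZE = 8 consecutive columns of one row, so a
  // 32x32 tile needs only (32 * 32) / 8 = 128 threads. Within each aligned group
  // of four columns the swizzled read offsets are consecutive 4-byte words, which
  // the unrolled loop lets the compiler combine into wider shared memory accesses,
  // and each thread has more independent work in flight between the TMA fences.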
  #pragma unroll
  for (int j = 0; j < BATCH_SIZE; j++) {
    int col_ = col + j;
    int row_ = row;
    int col_swizzle = calculate_col_swizzle<BLOCK_SIZE>(row_, col_);
    int row_swizzle = calculate_row_swizzle<BLOCK_SIZE>(row_, col_);

    smem_buffer_tr[col_ * BLOCK_SIZE + row_swizzle] =
        smem_buffer[row_ * BLOCK_SIZE + col_swizzle];
  }
  // Wait for shared memory writes to be visible to TMA engine.
  cde::fence_proxy_async_shared_cta();
  __syncthreads();
  // After syncthreads, writes by all threads are visible to TMA engine.

  // Initiate TMA transfer to copy shared memory to global memory
  if (threadIdx.x == 0) {
    // Transpose tile inside matrix
    cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map_tr, y, x,
                                                  &smem_buffer_tr);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cde::cp_async_bulk_commit_group();
    // Wait for the group to have completed reading from shared memory.
    cde::cp_async_bulk_wait_group_read<0>();
  }

  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (threadIdx.x == 0) {
    (&bar)->~barrier();
  }
}

int main() {
  const int GMEM_WIDTH = 32768;
  const int GMEM_HEIGHT = 32768;
  const int BLOCK_SIZE = 32;
  const int LOG_BLOCK = 5;
  const int BATCH_SIZE = 8;
  const int LOG_BATCH_SIZE = 3;
  const int SMEM_WIDTH = BLOCK_SIZE;
  const int SMEM_HEIGHT = BLOCK_SIZE;
  const size_t SIZE = GMEM_HEIGHT * GMEM_WIDTH * sizeof(float);

  float *h_in = new float[GMEM_HEIGHT * GMEM_WIDTH];
  float *h_out = new float[GMEM_HEIGHT * GMEM_WIDTH];

  // Initialize with normal distribution
  std::default_random_engine generator(42);
  std::normal_distribution<float> distribution(0.0, 1.0);

  for (int i = 0; i < GMEM_HEIGHT * GMEM_WIDTH; ++i) {
    h_in[i] = distribution(generator);
  }

  float *d;
  float *d_tr;
  CHECK_CUDA_ERROR(cudaMalloc(&d, SIZE));
  CHECK_CUDA_ERROR(cudaMemcpy(d, h_in, SIZE, cudaMemcpyHostToDevice));
  void *tensor_ptr = (void *)d;
  CHECK_CUDA_ERROR(cudaMalloc(&d_tr, SIZE));
  void *tensor_ptr_tr = (void *)d_tr;

  CUtensorMap tensor_map{};
  CUtensorMap tensor_map_tr{};
  // rank is the number of dimensions of the array.
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
  uint64_t size_tr[rank] = {GMEM_HEIGHT, GMEM_WIDTH};
  // The stride is the number of bytes to traverse from the first element of one
  // row to the next. It must be a multiple of 16.
  uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(float)};
  uint64_t stride_tr[rank - 1] = {GMEM_HEIGHT * sizeof(float)};
  // The box_size is the size of the shared memory buffer that is used as the
  // destination of a TMA transfer.
  uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
  uint32_t box_size_tr[rank] = {SMEM_HEIGHT, SMEM_WIDTH};
  // The distance between elements in units of sizeof(element). A stride of 2
  // can be used to load only the real component of a complex-valued tensor, for
  // instance.
  uint32_t elem_stride[rank] = {1, 1};

  // Create the tensor descriptor.
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,         // cuuint32_t tensorRank,
      tensor_ptr,   // void *globalAddress,
      size,         // const cuuint64_t *globalDim,
      stride,       // const cuuint64_t *globalStrides,
      box_size,     // const cuuint32_t *boxDim,
      elem_stride,  // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res == CUDA_SUCCESS);

  CUresult res_tr = cuTensorMapEncodeTiled(
      &tensor_map_tr,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,            // cuuint32_t tensorRank,
      tensor_ptr_tr,   // void *globalAddress,
      size_tr,         // const cuuint64_t *globalDim,
      stride_tr,       // const cuuint64_t *globalStrides,
      box_size_tr,     // const cuuint32_t *boxDim,
      elem_stride,     // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res_tr == CUDA_SUCCESS);

  dim3 blockDim((SMEM_WIDTH * SMEM_HEIGHT) / BATCH_SIZE, 1, 1);
  dim3 gridDim(GMEM_WIDTH / SMEM_WIDTH, GMEM_HEIGHT / SMEM_HEIGHT, 1);

  kernel<BLOCK_SIZE, LOG_BLOCK, BATCH_SIZE, LOG_BATCH_SIZE>
      <<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);

  CHECK_LAST_CUDA_ERROR();
  CHECK_CUDA_ERROR(cudaMemcpy(h_out, d_tr, SIZE, cudaMemcpyDeviceToHost));

  const float epsilon = 1e-5f;
  for (int x = 0; x < GMEM_HEIGHT; x++) {
    for (int y = 0; y < GMEM_WIDTH; y++) {
      float expected = h_in[x * GMEM_WIDTH + y];
      float actual = h_out[y * GMEM_HEIGHT + x];
      if (std::fabs(expected - actual) > epsilon) {
        std::cout << "Error at position (" << x << "," << y << "): expected "
                  << expected << " but got " << actual << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Passed" << std::endl;

  int numRounds = 10000;
  size_t numCrossMemoryBound = 2 * SIZE;
  cudaEvent_t start, stop;
  float time;

  CHECK_CUDA_ERROR(cudaEventCreate(&start));
  CHECK_CUDA_ERROR(cudaEventCreate(&stop));

  CHECK_CUDA_ERROR(cudaEventRecord(start));
  for (int i = 0; i < numRounds; i++) {
    kernel<BLOCK_SIZE, LOG_BLOCK, BATCH_SIZE, LOG_BATCH_SIZE>
        <<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);
  }
  CHECK_CUDA_ERROR(cudaEventRecord(stop));
  CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
  CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
  CHECK_LAST_CUDA_ERROR();

  float latency = time / numRounds;
  float bandwidth = (numCrossMemoryBound / latency) / 1e6;

  std::cout << "Latency = " << latency << " ms" << std::endl;
  std::cout << "Bandwidth = " << bandwidth << " GB/s" << std::endl;
  std::cout << "% of max = " << bandwidth / 3300 * 100 << " %" << std::endl;

  CHECK_CUDA_ERROR(cudaFree(d));
  CHECK_CUDA_ERROR(cudaFree(d_tr));
  delete[] h_in;
  delete[] h_out;
}

--------------------------------------------------------------------------------
/transpose_swizzle_batched_for_profile.cu:
--------------------------------------------------------------------------------
#include <cuda.h>  // CUtensormap
#include <cuda/barrier>

#include <cassert>
#include <cmath>
#include <iostream>
#include <random>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

#include "utils.h"

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const *func, char const *file, int line) {
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(char const *file, int line) {
  cudaError_t const err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

template <int BLOCK_SIZE>
__device__ int calculate_col_swizzle(int row, int col) {
  int i16 = (row * BLOCK_SIZE + col) * 4 >> 4;
  int y16 = i16 >> 3;
  int x16 = i16 & 7;
  int x16_swz = y16 ^ x16;
  return ((x16_swz * 4) & (BLOCK_SIZE - 1)) + (col & 3);
}

template <int BLOCK_SIZE>
__device__ int calculate_row_swizzle(int row, int col) {
  int i16_tr = (col * BLOCK_SIZE + row) * 4 >> 4;
  int y16_tr = i16_tr >> 3;
  int x16_tr = i16_tr & 7;
  int x16_swz_tr = y16_tr ^ x16_tr;
  return ((x16_swz_tr * 4) & (BLOCK_SIZE - 1)) + (row & 3);
}

template <int BLOCK_SIZE, int LOG_BLOCK, int BATCH_SIZE, int LOG_BATCH_SIZE>
__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map,
                       const __grid_constant__ CUtensorMap tensor_map_tr) {
  // The destination shared memory buffer of a bulk tensor operation should be
  // 128 byte aligned.
  __shared__ alignas(1024) float smem_buffer[BLOCK_SIZE * BLOCK_SIZE];
  __shared__ alignas(1024) float smem_buffer_tr[BLOCK_SIZE * BLOCK_SIZE];
  // Coordinates for upper left tile in GMEM.
  int x = blockIdx.x * BLOCK_SIZE;
  int y = blockIdx.y * BLOCK_SIZE;

  int col = (threadIdx.x & (BLOCK_SIZE / BATCH_SIZE - 1)) * BATCH_SIZE;
  int row = threadIdx.x >> (LOG_BLOCK - LOG_BATCH_SIZE);

  // Initialize shared memory barrier with the number of threads participating in
  // the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    // Initialize barrier. All `blockDim.x` threads in block participate.
    init(&bar, blockDim.x);
    // Make initialized barrier visible in async proxy.
    cde::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // Initiate bulk tensor copy.
    cde::cp_async_bulk_tensor_2d_global_to_shared(&smem_buffer, &tensor_map, x,
                                                  y, bar);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
  } else {
    // Other threads just arrive.
    token = bar.arrive();
  }
  // Wait for the data to have arrived.
  bar.wait(std::move(token));

  // Transpose tile.
  #pragma unroll
  for (int j = 0; j < BATCH_SIZE; j++) {
    int col_ = col + j;
    int row_ = row;
    int col_swizzle = calculate_col_swizzle<BLOCK_SIZE>(row_, col_);
    int row_swizzle = calculate_row_swizzle<BLOCK_SIZE>(row_, col_);

    smem_buffer_tr[col_ * BLOCK_SIZE + row_swizzle] =
        smem_buffer[row_ * BLOCK_SIZE + col_swizzle];
  }
  // Wait for shared memory writes to be visible to TMA engine.
  cde::fence_proxy_async_shared_cta();
  __syncthreads();
  // After syncthreads, writes by all threads are visible to TMA engine.

  // Initiate TMA transfer to copy shared memory to global memory
  if (threadIdx.x == 0) {
    // Transpose tile inside matrix
    cde::cp_async_bulk_tensor_2d_shared_to_global(&tensor_map_tr, y, x,
                                                  &smem_buffer_tr);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cde::cp_async_bulk_commit_group();
    // Wait for the group to have completed reading from shared memory.
    cde::cp_async_bulk_wait_group_read<0>();
  }

  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (threadIdx.x == 0) {
    (&bar)->~barrier();
  }
}

int main() {
  const int GMEM_WIDTH = 32768;
  const int GMEM_HEIGHT = 32768;
  const int BLOCK_SIZE = 32;
  const int LOG_BLOCK = 5;
  const int BATCH_SIZE = 8;
  const int LOG_BATCH_SIZE = 3;
  const int SMEM_WIDTH = BLOCK_SIZE;
  const int SMEM_HEIGHT = BLOCK_SIZE;
  const size_t SIZE = GMEM_HEIGHT * GMEM_WIDTH * sizeof(float);

  float *h_in = new float[GMEM_HEIGHT * GMEM_WIDTH];
  float *h_out = new float[GMEM_HEIGHT * GMEM_WIDTH];

  // Initialize with normal distribution
  std::default_random_engine generator(42);
  std::normal_distribution<float> distribution(0.0, 1.0);

  for (int i = 0; i < GMEM_HEIGHT * GMEM_WIDTH; ++i) {
    h_in[i] = distribution(generator);
  }

  float *d;
  float *d_tr;
  CHECK_CUDA_ERROR(cudaMalloc(&d, SIZE));
  CHECK_CUDA_ERROR(cudaMemcpy(d, h_in, SIZE, cudaMemcpyHostToDevice));
  void *tensor_ptr = (void *)d;
  CHECK_CUDA_ERROR(cudaMalloc(&d_tr, SIZE));
  void *tensor_ptr_tr = (void *)d_tr;

  CUtensorMap tensor_map{};
  CUtensorMap tensor_map_tr{};
  // rank is the number of dimensions of the array.
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
  uint64_t size_tr[rank] = {GMEM_HEIGHT, GMEM_WIDTH};
  // The stride is the number of bytes to traverse from the first element of one
  // row to the next. It must be a multiple of 16.
  uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(float)};
  uint64_t stride_tr[rank - 1] = {GMEM_HEIGHT * sizeof(float)};
  // The box_size is the size of the shared memory buffer that is used as the
  // destination of a TMA transfer.
  uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
  uint32_t box_size_tr[rank] = {SMEM_HEIGHT, SMEM_WIDTH};
  // The distance between elements in units of sizeof(element). A stride of 2
  // can be used to load only the real component of a complex-valued tensor, for
  // instance.
  uint32_t elem_stride[rank] = {1, 1};

  // Create the tensor descriptor.
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,         // cuuint32_t tensorRank,
      tensor_ptr,   // void *globalAddress,
      size,         // const cuuint64_t *globalDim,
      stride,       // const cuuint64_t *globalStrides,
      box_size,     // const cuuint32_t *boxDim,
      elem_stride,  // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res == CUDA_SUCCESS);

  CUresult res_tr = cuTensorMapEncodeTiled(
      &tensor_map_tr,  // CUtensorMap *tensorMap,
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      rank,            // cuuint32_t tensorRank,
      tensor_ptr_tr,   // void *globalAddress,
      size_tr,         // const cuuint64_t *globalDim,
      stride_tr,       // const cuuint64_t *globalStrides,
      box_size_tr,     // const cuuint32_t *boxDim,
      elem_stride,     // const cuuint32_t *elementStrides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  assert(res_tr == CUDA_SUCCESS);

  dim3 blockDim((SMEM_WIDTH * SMEM_HEIGHT) / BATCH_SIZE, 1, 1);
  dim3 gridDim(GMEM_WIDTH / SMEM_WIDTH, GMEM_HEIGHT / SMEM_HEIGHT, 1);

  kernel<BLOCK_SIZE, LOG_BLOCK, BATCH_SIZE, LOG_BATCH_SIZE>
      <<<gridDim, blockDim>>>(tensor_map, tensor_map_tr);

  CHECK_LAST_CUDA_ERROR();
  CHECK_CUDA_ERROR(cudaMemcpy(h_out, d_tr, SIZE, cudaMemcpyDeviceToHost));

  const float epsilon = 1e-5f;
  for (int x = 0; x < GMEM_HEIGHT; x++) {
    for (int y = 0; y < GMEM_WIDTH; y++) {
      float expected = h_in[x * GMEM_WIDTH + y];
      float actual = h_out[y * GMEM_HEIGHT + x];
      if (std::fabs(expected - actual) > epsilon) {
        std::cout << "Error at position (" << x << "," << y << "): expected "
                  << expected << " but got " << actual << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Passed" << std::endl;

  CHECK_CUDA_ERROR(cudaFree(d));
  CHECK_CUDA_ERROR(cudaFree(d_tr));
  delete[] h_in;
  delete[] h_out;
}

--------------------------------------------------------------------------------
/utils.h:
--------------------------------------------------------------------------------
#pragma once

#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace utils {

size_t number_of_digits(double n) {
  std::ostringstream strs;
  strs << n;
  return strs.str().size();
}

// Generate a gradient of colors from green to red with larger transitions
std::vector<std::string> generateColorGradient(size_t num_colors) {
  std::vector<std::string> colors;
  colors.reserve(num_colors);

  // Define distinct color points in the gradient with larger transitions
  struct ColorPoint {
    int r, g, b;
    float position;  // 0.0 to 1.0
  };

  std::vector<ColorPoint> color_points = {
      {0, 128, 0, 0.0f},     // Dark Green
      {0, 255, 0, 0.25f},    // Bright Green
      {255, 255, 0, 0.5f},   // Yellow
      {255, 128, 0, 0.75f},  // Orange
      {255, 0, 0, 1.0f}      // Bright Red
  };

  for (size_t i = 0; i < num_colors; ++i) {
    float t = static_cast<float>(i) / (num_colors - 1);

    // Find the two color points to interpolate between
    ColorPoint* lower = &color_points[0];
    ColorPoint* upper = &color_points[color_points.size() - 1];

    for (size_t j = 0; j < color_points.size() - 1; ++j) {
      if (t >= color_points[j].position && t <= color_points[j + 1].position) {
        lower = &color_points[j];
        upper = &color_points[j + 1];
        break;
      }
    }

    // Interpolate between the two color points
    float local_t = (t - lower->position) / (upper->position - lower->position);
    int r = static_cast<int>(lower->r + local_t * (upper->r - lower->r));
    int g = static_cast<int>(lower->g + local_t * (upper->g - lower->g));
    int b = static_cast<int>(lower->b + local_t * (upper->b - lower->b));

    // Convert RGB to 256-color terminal code with more distinct colors
    int color_code = 16 + (r / 51) * 36 + (g / 51) * 6 + (b / 51);
    colors.push_back("\033[48;5;" + std::to_string(color_code) + "m");
  }

  return colors;
}

// Generate a fixed gradient for values 0-31
std::vector<std::string> generateFixedGradient() {
  std::vector<std::string> colors;
  colors.reserve(32);

  // Define distinct color points in the gradient
  struct ColorPoint {
    int r, g, b;
    float position;  // 0.0 to 1.0
  };

  std::vector<ColorPoint> color_points = {
      {0, 128, 0, 0.0f},     // Dark Green
      {0, 255, 0, 0.25f},    // Bright Green
      {255, 255, 0, 0.5f},   // Yellow
      {255, 128, 0, 0.75f},  // Orange
      {255, 0, 0, 1.0f}      // Bright Red
  };

  for (size_t i = 0; i < 32; ++i) {
    float t = static_cast<float>(i) / 31.0f;

    // Find the two color points to interpolate between
    ColorPoint* lower = &color_points[0];
    ColorPoint* upper = &color_points[color_points.size() - 1];

    for (size_t j = 0; j < color_points.size() - 1; ++j) {
      if (t >= color_points[j].position && t <= color_points[j + 1].position) {
        lower = &color_points[j];
        upper = &color_points[j + 1];
        break;
      }
    }

    // Interpolate between the two color points
    float local_t = (t - lower->position) / (upper->position - lower->position);
    int r = static_cast<int>(lower->r + local_t * (upper->r - lower->r));
    int g = static_cast<int>(lower->g + local_t * (upper->g - lower->g));
    int b = static_cast<int>(lower->b + local_t * (upper->b - lower->b));

    // Convert RGB to 256-color terminal code
    int color_code = 16 + (r / 51) * 36 + (g / 51) * 6 + (b / 51);
    colors.push_back("\033[48;5;" + std::to_string(color_code) + "m");
  }

  return colors;
}

template <size_t N, size_t M>
void printMatrix(const double matrix[N][M], size_t n, size_t m) {
  size_t max_len_per_column[M] = {0};

  // Find maximum length for each column
  for (size_t j = 0; j < m; ++j) {
    size_t max_len = 0;
    for (size_t i = 0; i < n; ++i) {
      if (const auto num_length = number_of_digits(matrix[i][j]);
          num_length > max_len) {
        max_len = num_length;
      }
    }
    max_len_per_column[j] = max_len;
  }

  // Print the matrix
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < m; ++j) {
      std::cout << (j == 0 ? "\n| " : "") << std::setw(max_len_per_column[j])
                << matrix[i][j] << (j == m - 1 ? " |" : " ");
    }
  }
  std::cout << '\n';
}

// Print matrix as a heatmap with gradient colors scaled to block size
void printMatrixHeatmap(const int* matrix, size_t height, size_t width,
                        size_t block_size) {
  const std::string reset_color = "\033[0m";
  auto colors = generateColorGradient(block_size * block_size);

  // Print the heatmap
  for (size_t i = 0; i < height; ++i) {
    std::cout << "\n| ";
    for (size_t j = 0; j < width; ++j) {
      int value = matrix[i * width + j];
      // Scale the value to the block's color range (0 to block_size^2 - 1)
      size_t color_idx = static_cast<size_t>(value) % colors.size();
      std::cout << colors[color_idx] << std::setw(2) << value << reset_color
                << " ";
    }
    std::cout << "|";
  }
  std::cout << "\n\nColor Scale: Green -> Yellow -> Orange -> Red\n";
  std::cout << "Colors scaled to block size: " << block_size << "x"
            << block_size << "\n";
}

// Print matrix as a heatmap with fixed 0-31 color range
void printMatrixHeatmap32(const int* matrix, size_t height, size_t width) {
  const std::string reset_color = "\033[0m";
  static auto colors = generateFixedGradient();  // Generate once and reuse

  // Print the heatmap
  for (size_t i = 0; i < height; ++i) {
    std::cout << "\n| ";
    for (size_t j = 0; j < width; ++j) {
      int value = matrix[i * width + j];
      // Direct mapping to 0-31 range
      size_t color_idx = static_cast<size_t>(value) % 32;
      std::cout << colors[color_idx] << std::setw(2) << value << reset_color
                << " ";
    }
    std::cout << "|";
  }
  std::cout << "\n\nColor Scale: Green -> Yellow -> Orange -> Red\n";
  std::cout << "Values range: 0-31\n";
}

// Overload for 1D arrays treated as 2D matrices
void printMatrix(const int* matrix, size_t height, size_t width) {
  std::vector<size_t> max_len_per_column(width, 0);

  // Find maximum length for each column
  for (size_t j = 0; j < width; ++j) {
    size_t max_len = 0;
    for (size_t i = 0; i < height; ++i) {
      if (const auto num_length = number_of_digits(matrix[i * width + j]);
          num_length > max_len) {
        max_len = num_length;
      }
    }
    max_len_per_column[j] = max_len;
  }

  // Print the matrix
  for (size_t i = 0; i < height; ++i) {
    for (size_t j = 0; j < width; ++j) {
      std::cout << (j == 0 ? "\n| " : "") << std::setw(max_len_per_column[j])
                << matrix[i * width + j] << (j == width - 1 ? " |" : " ");
    }
  }
  std::cout << '\n';
}

}  // namespace utils
--------------------------------------------------------------------------------