├── .gitignore ├── run.sh ├── README.md ├── src ├── utils.cuh ├── sgemm │ ├── sgemm.cu │ ├── sgemm.cuh │ ├── shared_memory.cu │ ├── float4.cu │ ├── double_buffering.cu │ ├── tile.cu │ └── split_tile.cu └── utils.cu ├── LICENSE ├── CMakeLists.txt └── main.cu /.gitignore: -------------------------------------------------------------------------------- 1 | cmake-build-debug 2 | .idea 3 | *.txt 4 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | for ((i=256; i <= 6400; i+=256)) 2 | do 3 | ./sgemm 10 $i 4 | done 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cuda-sgemm-optimization 2 | CUDA SGEMM optimization note 3 | 4 | You can see details on [my blog](http://linn-ylz.com/Computer-Science/CUDA/CUDA-SGEMM-optimization-notes/). 5 | -------------------------------------------------------------------------------- /src/utils.cuh: -------------------------------------------------------------------------------- 1 | // 2 | // Created by linn on 9/29/23. 
3 | // 4 | 5 | #ifndef SGEMM_UTILS_CUH 6 | #define SGEMM_UTILS_CUH 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | // https://leimao.github.io/blog/Proper-CUDA-Error-Checking/ 14 | void checkLastError(const char* const file, const int line); 15 | void checkError(cudaError_t err, const char* const func, const char* const file, const int line); 16 | bool verify_matrix(const float *mat1, const float *mat2, int N); 17 | void copy_matrix(const float *src, float *dest, int N); 18 | void print_matrix(const float *A, int M, int N); 19 | void randomize_matrix(float *mat, int N); 20 | void CudaDeviceInfo(); 21 | 22 | #define CHECK_CUDA_ERROR(val) checkError((val), #val, __FILE__, __LINE__) 23 | #define CHECK_LAST_CUDA_ERROR() checkLastError(__FILE__, __LINE__) 24 | 25 | #endif //SGEMM_UTILS_CUH 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 yanglinzhuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/src/sgemm/sgemm.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 9/29/23.
//

#include "sgemm.cuh"
#include <cstdio>
#include <cstdlib>

// Reference SGEMM through cuBLAS: C = alpha * A * B + beta * C.
// A: (M, K), B: (K, N), C: (M, N), all row-major device pointers.
//
// cuBLAS is column-major. A row-major (R, S) buffer read as column-major is its (S, R)
// transpose, so we ask cuBLAS for C^T = B^T * A^T: B is the left operand (leading
// dimension N), A the right operand (leading dimension K), and the column-major (N, M)
// result with leading dimension N is exactly row-major C.
// See https://www.cnblogs.com/cuancuancuanhao/p/7763256.html
//
// Exits the process if cuBLAS reports an error, so a failed reference call can neither be
// mistaken for a wrong kernel result nor timed as a "fast" run.
void test_cublas(cublasHandle_t handle, int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    const cublasStatus_t status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
                                              &alpha, B, N, A, K, &beta, C, N);
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::fprintf(stderr, "cublasSgemm failed with status %d\n", static_cast<int>(status));
        std::exit(EXIT_FAILURE);
    }
}


// Naive SGEMM: one thread per element of C, reading A and B directly from global memory.
// Expects a 2-D launch whose x dimension covers the N columns of C and whose y dimension
// covers the M rows; threads outside the matrix do nothing.
__global__ void naive_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row < M && col < N) {
        float val = 0.f;
        for (int k = 0; k < K; ++k) {
            val += A[OFFSET(row, k, K)] * B[OFFSET(k, col, N)];
        }
        C[OFFSET(row, col, N)] = alpha * val + beta * C[OFFSET(row, col, N)];
    }
}

// Launch naive_kernel with 16 x 16 thread blocks. `handle` is unused; it is kept so this
// launcher matches the common test_* signature.
void test_naive_kernel(cublasHandle_t handle, int M, int N, int K,
                       float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;
    dim3 block(size, size);
    dim3 grid(CEIL_DIV(N, size), CEIL_DIV(M, size));  // x covers columns (N), y covers rows (M)
    naive_kernel<<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.22)
project(sgemm LANGUAGES CXX CUDA)

set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -lineinfo")
#set(CMAKE_CUDA_FLAGS "-g -G")

find_package(CUDAToolkit REQUIRED)

add_executable(${PROJECT_NAME} main.cu
        src/utils.cu
        src/sgemm/sgemm.cu
        src/sgemm/shared_memory.cu
        src/sgemm/tile.cu
        src/sgemm/float4.cu
        src/sgemm/double_buffering.cu
        src/sgemm/split_tile.cu)

# Output directory for the executable
# https://gist.github.com/gavinb/c993f71cf33d2354515c4452a3f8ef30
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR})

set_target_properties(${PROJECT_NAME} PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON)
# Look up your GPU's compute capability at https://developer.nvidia.com/cuda-gpus
set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "86")

# Header search paths
# Headers of the CUDA libraries
# Reference:
# https://stackoverflow.com/questions/51756562/obtaining-the-cuda-include-dir-in-c-targets-with-native-cuda-support-cmake
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(${PROJECT_NAME} PUBLIC ${PROJECT_SOURCE_DIR}/src)

# link cudart cublas
target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cublas)

--------------------------------------------------------------------------------
/src/sgemm/sgemm.cuh:
--------------------------------------------------------------------------------
//
// Created by linn on 9/29/23.
//

#ifndef SGEMM_SGEMM_CUH
#define SGEMM_SGEMM_CUH

#include <cuda_runtime.h>
#include <cublas_v2.h>

// Linear index of element (row, col) in a row-major matrix whose rows are `stride` elements long.
#define OFFSET(row, col, stride) ((row) * (stride) + (col))
// Integer ceil(M / N) for positive operands. Every use of N is parenthesized so that
// expression arguments (e.g. `a << 1`, `c ? x : y`) are not split by operator precedence.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// Reinterpret the float lvalue `pointer` (and the three floats after it) as one float4
// lvalue, for 16-byte vectorized loads/stores. The address must be 16-byte aligned.
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4 *>(&(pointer))[0])

// Every launcher computes C = alpha * A * B + beta * C for row-major device matrices
// A: (M, K), B: (K, N), C: (M, N). They share one signature so main.cu can dispatch
// through a function pointer; `handle` is only used by test_cublas.
void test_cublas(cublasHandle_t handle, int M, int N, int K,
                 float alpha, float *A, float *B, float beta, float *C);
void test_naive_kernel(cublasHandle_t handle, int M, int N, int K,
                       float alpha, float *A, float *B, float beta, float *C);
void test_sm_kernel(cublasHandle_t handle, int M, int N, int K,
                    float alpha, float *A, float *B, float beta, float *C);
void test_tile_1d_kernel(cublasHandle_t handle, int M, int N, int K,
                         float alpha, float *A, float *B, float beta, float *C);
void test_tile_2d_kernel(cublasHandle_t handle, int M, int N, int K,
                         float alpha, float *A, float *B, float beta, float *C);
void test_tile_2d_reg_cache_kernel(cublasHandle_t handle, int M, int N, int K,
                                   float alpha, float *A, float *B, float beta, float *C);
void test_tile_2d_float4_kernel(cublasHandle_t handle, int M, int N, int K,
                                float alpha, float *A, float *B, float beta, float *C);
void test_tile_2d_float4_double_buffering_kernel(cublasHandle_t handle, int M, int N, int K,
                                                 float alpha, float *A, float *B, float beta, float *C);
void test_no_share_conflict_kernel(cublasHandle_t handle, int M, int N, int K,
                                   float alpha, float *A, float *B, float beta, float *C);
void test_tile_2d_split_kernel(cublasHandle_t handle, int M, int N, int K,
                               float alpha, float *A, float *B, float beta, float *C);
void test_tile_1d_split_kernel(cublasHandle_t handle, int M, int N, int K,
                               float alpha, float *A, float *B, float beta, float *C);
#endif //SGEMM_SGEMM_CUH

--------------------------------------------------------------------------------
/src/sgemm/shared_memory.cu:
--------------------------------------------------------------------------------
//
// Created by YangLinzhuo on 2023/10/12.
//

#include "sgemm.cuh"

// Shared-memory tiled SGEMM, one output element per thread: C = alpha * A * B + beta * C,
// with A: (M, K), B: (K, N), C: (M, N), row-major.
//
// Each thread stages exactly one element of the A tile and one of the B tile per K-step,
// so the launch must use blockDim = (BK, BK) with BM == BN == BK (the static_assert checks
// the tile sizes) and grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)). Tile elements outside the
// matrices are zero-filled, so any M, N, K is handled.
template <const int BM, const int BN, const int BK>
__global__ void sm_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    static_assert(BM == BK && BN == BK, "sm_kernel maps one thread to one element of each tile");
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    const int row = blockIdx.y * BM + threadIdx.y;  // row of A and C owned by this thread
    const int col = blockIdx.x * BN + threadIdx.x;  // column of B and C owned by this thread
    float val = 0.f;

    for (int i = 0; i < CEIL_DIV(K, BK); ++i) {
        // Copy one BM x BK tile of A and one BK x BN tile of B from global to shared memory.
        const int A_col = i * BK + threadIdx.x;
        As[threadIdx.y][threadIdx.x] = (row < M && A_col < K) ? A[OFFSET(row, A_col, K)] : 0.f;
        const int B_row = i * BK + threadIdx.y;
        Bs[threadIdx.y][threadIdx.x] = (B_row < K && col < N) ? B[OFFSET(B_row, col, N)] : 0.f;
        __syncthreads();  // both tiles are complete before any thread reads them

#pragma unroll
        for (int k = 0; k < BK; ++k) {
            val += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        __syncthreads();  // all reads finish before the next iteration overwrites the tiles
    }
    if (row < M && col < N) {
        C[OFFSET(row, col, N)] = alpha * val + beta * C[OFFSET(row, col, N)];
    }
}

// Same computation and launch requirements as sm_kernel, but the A tile is stored
// transposed in shared memory (As[k][m]) so the inner product walks As along its
// contiguous dimension. Results are identical to sm_kernel; this is an experimental
// layout variant (not launched by default, see test_sm_kernel).
template <const int BM, const int BN, const int BK>
__global__ void sm_transposed_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    static_assert(BM == BK && BN == BK, "sm_transposed_kernel maps one thread to one element of each tile");
    __shared__ float As[BK][BM];  // transposed A tile: As[k][m] holds A(row m, col k)
    __shared__ float Bs[BK][BN];

    const int row = blockIdx.y * BM + threadIdx.y;
    const int col = blockIdx.x * BN + threadIdx.x;
    float val = 0.f;

    for (int i = 0; i < CEIL_DIV(K, BK); ++i) {
        const int A_col = i * BK + threadIdx.x;
        As[threadIdx.x][threadIdx.y] = (row < M && A_col < K) ? A[OFFSET(row, A_col, K)] : 0.f;
        const int B_row = i * BK + threadIdx.y;
        Bs[threadIdx.y][threadIdx.x] = (B_row < K && col < N) ? B[OFFSET(B_row, col, N)] : 0.f;
        __syncthreads();

#pragma unroll
        for (int k = 0; k < BK; ++k) {
            val += As[k][threadIdx.y] * Bs[k][threadIdx.x];
        }
        __syncthreads();
    }
    if (row < M && col < N) {
        C[OFFSET(row, col, N)] = alpha * val + beta * C[OFFSET(row, col, N)];
    }
}


// Launch the shared-memory kernel with 16 x 16 tiles and 16 x 16 thread blocks.
void test_sm_kernel(cublasHandle_t handle, int M, int N, int K,
                    float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;
    dim3 block(size, size);
    dim3 grid(CEIL_DIV(N, size), CEIL_DIV(M, size));  // x covers columns (N), y covers rows (M)
    sm_kernel<size, size, size><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
//    sm_transposed_kernel<size, size, size><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}
--------------------------------------------------------------------------------
/src/utils.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 9/29/23.
//

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <random>
#include "utils.cuh"

// https://leimao.github.io/blog/Proper-CUDA-Error-Checking/
// If `err` is not cudaSuccess, report it together with the failing call text `func` and
// the source location, then terminate the process.
void checkError(cudaError_t err, const char* const func, const char* const file,
                const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
        std::cerr << "[Error]" << cudaGetErrorString(err) << " when call " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

// Read (and clear) the last CUDA runtime error, e.g. from a kernel launch, and terminate
// the process with the source location if there was one.
void checkLastError(const char* const file, const int line)
{
    cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}


// Compare two length-N arrays element-wise with an absolute tolerance of 0.1.
// Prints the first mismatching pair and its index and returns false; returns true if all
// elements match.
bool verify_matrix(const float *mat1, const float *mat2, int N) {
    const double tolerance = 1e-1;
    for (int i = 0; i < N; i++) {
        const double diff = std::fabs(static_cast<double>(mat1[i]) - static_cast<double>(mat2[i]));
        // Written as !(diff <= tolerance) rather than diff > tolerance: every comparison with
        // NaN is false, so the latter would silently accept NaN results (and Inf - Inf).
        if (!(diff <= tolerance)) {
            printf("error. %5.2f,%5.2f,%d\n", mat1[i], mat2[i], i);
            return false;
        }
    }
    return true;
}


// Fill `mat` with N floats drawn uniformly from [-5, 5].
// The engine is seeded once (seed 1101) and persists across calls, so results are
// reproducible from run to run, but successive calls (A, B and C in main) get different
// data. Re-seeding on every call would make equally sized matrices identical (A == B
// for square problems), which hides kernels that swap or transpose their operands.
void randomize_matrix(float *mat, int N) {
    static std::mt19937 engine{1101};
    std::uniform_real_distribution<float> generator{-5.f, 5.f};
    for (int i = 0; i < N; i++) {
        mat[i] = generator(engine);
    }
}


// Copy N floats from src to dest. Prints a failure message (and copies nothing) if N is
// negative or a pointer is null while there is something to copy.
void copy_matrix(const float *src, float *dest, int N) {
    if (N == 0) {
        return;
    }
    if (N < 0 || src == nullptr || dest == nullptr) {
        printf("copy failed at %d while there are %d elements in total.\n", 0, N);
        return;
    }
    std::copy(src, src + N, dest);
}


// Print an M x N row-major matrix, one row per line.
void print_matrix(const float *A, int M, int N) {
    printf("[\n");
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            printf("%5.2f ", A[i * N + j]);
        }
        printf("\n");
    }
    printf("]\n");
}


// Print a summary of the properties of the current CUDA device.
// Exits with a message if the device or its properties cannot be queried.
void CudaDeviceInfo() {
    int deviceId;

    CHECK_CUDA_ERROR(cudaGetDevice(&deviceId));

    cudaDeviceProp props;
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&props, deviceId));

    /*
     * There should be no need to modify the output string below.
     */

    printf("Device ID: %d\n\
*Number of SMs: %d\n\
Compute Capability Major: %d\n\
Compute Capability Minor: %d\n\
memoryBusWidth: %d\n\
*maxThreadsPerBlock: %d\n\
maxThreadsPerMultiProcessor: %d\n\
*totalGlobalMem: %zuM\n\
sharedMemPerBlock: %zuKB\n\
*sharedMemPerMultiprocessor: %zuKB\n\
totalConstMem: %zuKB\n\
*multiProcessorCount: %d\n\
*Warp Size: %d\n",
           deviceId,
           props.multiProcessorCount,
           props.major,
           props.minor,
           props.memoryBusWidth,
           props.maxThreadsPerBlock,
           props.maxThreadsPerMultiProcessor,
           props.totalGlobalMem / 1024 / 1024,
           props.sharedMemPerBlock / 1024,
           props.sharedMemPerMultiprocessor / 1024,
           props.totalConstMem / 1024,
           props.multiProcessorCount,
           props.warpSize);
}
--------------------------------------------------------------------------------
/src/sgemm/float4.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 10/14/23.
//

#include <cstdio>
#include <cstdlib>
#include "sgemm.cuh"

// 2-D register-tiled SGEMM with 128-bit (float4) global and shared-memory accesses:
// C = alpha * A * B + beta * C, with A: (M, K), B: (K, N), C: (M, N), row-major.
//
// Each block computes a BM x BN tile of C with (BN / TN) x (BM / TM) threads
// (blockDim = (BN / TN, BM / TM)); each thread owns a TM x TN sub-tile. The A tile is stored
// transposed in shared memory (As[k][m]) so the inner loop can read TM consecutive A
// values with float4 loads.
//
// There are NO bounds checks: M must be a multiple of BM, N of BN, K of BK, and A, B, C
// must be 16-byte aligned (cudaMalloc allocations are). test_tile_2d_float4_kernel checks
// the shape preconditions before launching.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_2d_float4_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    constexpr int block_row_thread = BN / TN;  // threads along x (columns of the C tile)
    constexpr int block_col_thread = BM / TM;  // threads along y (rows of the C tile)
    constexpr int thread_num = block_row_thread * block_col_thread;
    // Number of float4 loads each thread issues per K-step to fill As / Bs.
    constexpr int load_a_cache_time = (BK * BM) / thread_num / 4;
    constexpr int load_b_cache_time = (BK * BN) / thread_num / 4;
    static_assert(BK % 4 == 0 && BN % 4 == 0 && TM % 4 == 0 && TN % 4 == 0,
                  "float4 accesses need tile widths that are multiples of 4");
    static_assert(load_a_cache_time >= 1 && (BK * BM) % (thread_num * 4) == 0,
                  "the A tile must split evenly into float4 loads across the block");
    static_assert(load_b_cache_time >= 1 && (BK * BN) % (thread_num * 4) == 0,
                  "the B tile must split evenly into float4 loads across the block");

    const int num_shared_block = CEIL_DIV(K, BK);

    __shared__ __align__(16) float As[BK][BM];  // A tile stored transposed (As[k][m])
    __shared__ __align__(16) float Bs[BK][BN];

    float accum[TM][TN] = {0.f};
    // Per-thread register fragments, aligned so they can be written through FETCH_FLOAT4.
    __align__(16) float As_cache[TM] = {0.f};
    __align__(16) float Bs_cache[TN] = {0.f};

    // Move A, B, C to the start of this block's tiles.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    // Global -> shared copy mapping: each thread copies float4 chunks starting at
    // (a_tile_row, a_tile_col) of the A tile and (b_tile_row, b_tile_col) of the B tile,
    // stepping down by the stride until the whole tile is covered.
    const int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
    const int a_tile_row = thread_id / (BK / 4);
    const int a_tile_col = thread_id % (BK / 4) * 4;
    const int a_tile_stride = BM / load_a_cache_time;

    const int b_tile_row = thread_id / (BN / 4);
    const int b_tile_col = thread_id % (BN / 4) * 4;
    const int b_tile_stride = BK / load_b_cache_time;

#pragma unroll
    for (int i = 0; i < num_shared_block; ++i) {
#pragma unroll
        for (int m = 0; m < BM; m += a_tile_stride) {
            // One 16-byte load from A, scattered into a column of the transposed As.
            const float4 a4 = FETCH_FLOAT4(A[OFFSET(a_tile_row + m, a_tile_col, K)]);
            As[a_tile_col][a_tile_row + m] = a4.x;
            As[a_tile_col + 1][a_tile_row + m] = a4.y;
            As[a_tile_col + 2][a_tile_row + m] = a4.z;
            As[a_tile_col + 3][a_tile_row + m] = a4.w;
        }
#pragma unroll
        for (int k = 0; k < BK; k += b_tile_stride) {
            FETCH_FLOAT4(Bs[b_tile_row + k][b_tile_col]) =
                    FETCH_FLOAT4(B[OFFSET(b_tile_row + k, b_tile_col, N)]);
        }
        __syncthreads();  // tiles fully written before any thread reads them
        A += BK;      // start of the next BM x BK tile of A
        B += BK * N;  // start of the next BK x BN tile of B

#pragma unroll
        for (int k = 0; k < BK; ++k) {
#pragma unroll
            for (int m = 0; m < TM; m += 4) {
                FETCH_FLOAT4(As_cache[m]) = FETCH_FLOAT4(As[k][threadIdx.y * TM + m]);
            }
#pragma unroll
            for (int n = 0; n < TN; n += 4) {
                FETCH_FLOAT4(Bs_cache[n]) = FETCH_FLOAT4(Bs[k][threadIdx.x * TN + n]);
            }
#pragma unroll
            for (int m = 0; m < TM; ++m) {
#pragma unroll
                for (int n = 0; n < TN; ++n) {
                    accum[m][n] += As_cache[m] * Bs_cache[n];
                }
            }
        }
        __syncthreads();  // all reads finish before the next iteration overwrites As / Bs
    }

    // Epilogue: read-modify-write C four columns at a time.
#pragma unroll
    for (int m = 0; m < TM; ++m) {
        const int C_row = threadIdx.y * TM + m;
#pragma unroll
        for (int n = 0; n < TN; n += 4) {
            const int C_col = threadIdx.x * TN + n;
            float4 c4 = FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]);
            c4.x = alpha * accum[m][n] + beta * c4.x;
            c4.y = alpha * accum[m][n + 1] + beta * c4.y;
            c4.z = alpha * accum[m][n + 2] + beta * c4.z;
            c4.w = alpha * accum[m][n + 3] + beta * c4.w;
            FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]) = c4;
        }
    }
}

// Launch tile_2d_float4_kernel with 128 x 128 x 8 block tiles and 8 x 8 thread tiles
// (16 x 16 threads per block). The kernel has no bounds checks, so this exits with a
// message instead of launching when M, N, K are not multiples of the block tile.
void test_tile_2d_float4_kernel(cublasHandle_t handle, int M, int N, int K,
                                float alpha, float *A, float *B, float beta, float *C) {
    constexpr int size = 16;
    constexpr int tile_size = 8;
    constexpr int BM = size * tile_size;
    constexpr int BN = size * tile_size;
    constexpr int BK = 8;
    constexpr int TM = 8;
    constexpr int TN = 8;
    static_assert(BN / TN == size && BM / TM == size, "blockDim must match the thread tiling");
    if (M % BM != 0 || N % BN != 0 || K % BK != 0) {
        std::fprintf(stderr,
                     "tile_2d_float4_kernel requires M %% %d == 0, N %% %d == 0, K %% %d == 0 "
                     "(got M=%d, N=%d, K=%d)\n", BM, BN, BK, M, N, K);
        std::exit(EXIT_FAILURE);
    }
    dim3 block(size, size);
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // x covers columns (N), y covers rows (M)
    tile_2d_float4_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}

--------------------------------------------------------------------------------
/src/sgemm/double_buffering.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 10/14/23.
//

#include <cstdio>
#include <cstdlib>
#include "sgemm.cuh"

// Variant of the float4 2-D tiled SGEMM that alternates between two shared-memory
// buffers (ping-pong): K-step i loads into buffer i % 2 and then consumes it. Because
// consecutive K-steps use different buffers, the trailing __syncthreads() of the
// single-buffer kernel is not needed (see the comment at the end of the main loop).
// Loading and computing are still serialized within an iteration; the next tile is not
// prefetched while the current one is multiplied.
//
// Same preconditions as the single-buffer float4 kernel: NO bounds checks, so M % BM == 0,
// N % BN == 0, K % BK == 0, A, B, C 16-byte aligned, blockDim = (BN / TN, BM / TM).
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_2d_float4_double_buffering_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    constexpr int block_row_thread = BN / TN;  // threads along x (columns of the C tile)
    constexpr int block_col_thread = BM / TM;  // threads along y (rows of the C tile)
    constexpr int thread_num = block_row_thread * block_col_thread;
    // Number of float4 loads each thread issues per K-step to fill As / Bs.
    constexpr int load_a_cache_time = (BK * BM) / thread_num / 4;
    constexpr int load_b_cache_time = (BK * BN) / thread_num / 4;
    static_assert(BK % 4 == 0 && BN % 4 == 0 && TM % 4 == 0 && TN % 4 == 0,
                  "float4 accesses need tile widths that are multiples of 4");
    static_assert(load_a_cache_time >= 1 && (BK * BM) % (thread_num * 4) == 0,
                  "the A tile must split evenly into float4 loads across the block");
    static_assert(load_b_cache_time >= 1 && (BK * BN) % (thread_num * 4) == 0,
                  "the B tile must split evenly into float4 loads across the block");

    const int num_shared_block = CEIL_DIV(K, BK);

    __shared__ __align__(16) float As[2][BK][BM];  // two transposed A tiles (As[buf][k][m])
    __shared__ __align__(16) float Bs[2][BK][BN];  // two B tiles

    float accum[TM][TN] = {0.f};
    // Register fragments are loaded and consumed within the same k-step, so one copy is
    // enough. Indexing a [2][TM] register array by the runtime write_idx can force it into
    // local memory.
    __align__(16) float As_cache[TM] = {0.f};
    __align__(16) float Bs_cache[TN] = {0.f};

    // Move A, B, C to the start of this block's tiles.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    const int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
    const int a_tile_row = thread_id / (BK / 4);
    const int a_tile_col = thread_id % (BK / 4) * 4;
    const int a_tile_stride = BM / load_a_cache_time;

    const int b_tile_row = thread_id / (BN / 4);
    const int b_tile_col = thread_id % (BN / 4) * 4;
    const int b_tile_stride = BK / load_b_cache_time;

    int write_idx = 0;  // shared-memory buffer used by the current K-step

#pragma unroll
    for (int i = 0; i < num_shared_block; ++i) {
#pragma unroll
        for (int m = 0; m < BM; m += a_tile_stride) {
            const float4 a4 = FETCH_FLOAT4(A[OFFSET(a_tile_row + m, a_tile_col, K)]);
            // Stores into the transposed tile: consecutive As rows are BM (= 128 here) floats
            // apart, so these stores incur shared-memory bank conflicts.
            As[write_idx][a_tile_col][a_tile_row + m] = a4.x;
            As[write_idx][a_tile_col + 1][a_tile_row + m] = a4.y;
            As[write_idx][a_tile_col + 2][a_tile_row + m] = a4.z;
            As[write_idx][a_tile_col + 3][a_tile_row + m] = a4.w;
        }
#pragma unroll
        for (int k = 0; k < BK; k += b_tile_stride) {
            FETCH_FLOAT4(Bs[write_idx][b_tile_row + k][b_tile_col]) =
                    FETCH_FLOAT4(B[OFFSET(b_tile_row + k, b_tile_col, N)]);
        }
        __syncthreads();  // this K-step's tiles are complete before anyone reads them
        A += BK;      // start of the next BM x BK tile of A
        B += BK * N;  // start of the next BK x BN tile of B

#pragma unroll
        for (int k = 0; k < BK; ++k) {
#pragma unroll
            for (int m = 0; m < TM; m += 4) {
                FETCH_FLOAT4(As_cache[m]) = FETCH_FLOAT4(As[write_idx][k][threadIdx.y * TM + m]);
            }
#pragma unroll
            for (int n = 0; n < TN; n += 4) {
                FETCH_FLOAT4(Bs_cache[n]) = FETCH_FLOAT4(Bs[write_idx][k][threadIdx.x * TN + n]);
            }
#pragma unroll
            for (int m = 0; m < TM; ++m) {
#pragma unroll
                for (int n = 0; n < TN; ++n) {
                    accum[m][n] += As_cache[m] * Bs_cache[n];
                }
            }
        }
        // No barrier here. The next iteration writes the other buffer, and a thread can only
        // come back to this buffer (two iterations later) after passing the next iteration's
        // __syncthreads(), which every thread reaches only after finishing this compute loop.
        write_idx ^= 1;
    }

    // Epilogue: read-modify-write C four columns at a time.
#pragma unroll
    for (int m = 0; m < TM; ++m) {
        const int C_row = threadIdx.y * TM + m;
#pragma unroll
        for (int n = 0; n < TN; n += 4) {
            const int C_col = threadIdx.x * TN + n;
            float4 c4 = FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]);
            c4.x = alpha * accum[m][n] + beta * c4.x;
            c4.y = alpha * accum[m][n + 1] + beta * c4.y;
            c4.z = alpha * accum[m][n + 2] + beta * c4.z;
            c4.w = alpha * accum[m][n + 3] + beta * c4.w;
            FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]) = c4;
        }
    }
}

// Launch the ping-pong buffered float4 kernel with 128 x 128 x 8 block tiles and 8 x 8
// thread tiles (16 x 16 threads per block). The kernel has no bounds checks, so this exits
// with a message instead of launching when M, N, K are not multiples of the block tile.
void test_tile_2d_float4_double_buffering_kernel(cublasHandle_t handle, int M, int N, int K,
                                                 float alpha, float *A, float *B, float beta, float *C) {
    constexpr int size = 16;
    constexpr int tile_size = 8;
    constexpr int BM = size * tile_size;
    constexpr int BN = size * tile_size;
    constexpr int BK = 8;
    constexpr int TM = 8;
    constexpr int TN = 8;
    static_assert(BN / TN == size && BM / TM == size, "blockDim must match the thread tiling");
    if (M % BM != 0 || N % BN != 0 || K % BK != 0) {
        std::fprintf(stderr,
                     "tile_2d_float4_double_buffering_kernel requires M %% %d == 0, N %% %d == 0, "
                     "K %% %d == 0 (got M=%d, N=%d, K=%d)\n", BM, BN, BK, M, N, K);
        std::exit(EXIT_FAILURE);
    }
    dim3 block(size, size);
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // x covers columns (N), y covers rows (M)
    tile_2d_float4_double_buffering_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}

--------------------------------------------------------------------------------
/main.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 9/29/23.
//

#include <cstdio>
#include <cstdlib>
#include "src/utils.cuh"
#include "src/sgemm/sgemm.cuh"

// SGEMM benchmark driver.
//   ./sgemm                   print information about the current CUDA device
//   ./sgemm KERNEL SIZE       run KERNEL with M = N = K = SIZE
//   ./sgemm KERNEL M N K      run KERNEL on (M, K) x (K, N)
// KERNEL 1-10 selects a test_* launcher; any other value uses cuBLAS. The selected kernel
// is first verified against cuBLAS and then timed. Prints "<seconds> <GFLOPS> <M>".
int main(int argc, const char* argv[]) {
    if (argc == 1) {
        CudaDeviceInfo();
        return 0;
    }

    if (argc != 3 && argc != 5) {
        printf("Please select a kernel and corresponding matrix size.\n");
        printf("Max kernel size is 6400.\n");
        printf("Kernel 0 is for cuBLAS kernel implemented by NVIDIA.\n");
        std::exit(EXIT_FAILURE);
    }

    int kernel = atoi(argv[1]);

    int m = 0, n = 0, k = 0;
    if (argc == 3) {
        m = n = k = atoi(argv[2]);
    } else {
        m = atoi(argv[2]);
        n = atoi(argv[3]);
        k = atoi(argv[4]);
    }
    // atoi yields 0 for non-numeric input; non-positive sizes would otherwise reach new[],
    // cudaMalloc and empty kernel grids.
    if (m <= 0 || n <= 0 || k <= 0) {
        printf("Matrix sizes must be positive integers (got m=%d, n=%d, k=%d).\n", m, n, k);
        std::exit(EXIT_FAILURE);
    }

    using func_ptr = void (*)(cublasHandle_t handle, int M, int N, int K,
                              float alpha, float *A, float *B, float beta, float *C);
    func_ptr test_func = nullptr;
    switch (kernel) {
        case 1:
            test_func = test_naive_kernel;
            break;
        case 2:
            test_func = test_sm_kernel;
            break;
        case 3:
            test_func = test_tile_1d_kernel;
            break;
        case 4:
            test_func = test_tile_2d_kernel;
            break;
        case 5:
            test_func = test_tile_2d_reg_cache_kernel;
            break;
        case 6:
            test_func = test_tile_2d_float4_kernel;
            break;
        case 7:
            test_func = test_tile_2d_float4_double_buffering_kernel;
            break;
        case 8:
            test_func = test_no_share_conflict_kernel;
            break;
        case 9:
            test_func = test_tile_1d_split_kernel;
            break;
        case 10:
            test_func = test_tile_2d_split_kernel;
            break;
        default:
            test_func = test_cublas;
            break;
    }

    // Create the cuBLAS handle; cublasCreate returns CUBLAS_STATUS_SUCCESS (0) on success.
    cublasHandle_t handle = nullptr;
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        printf("Create cublas handle error.\n");
        std::exit(EXIT_FAILURE);
    }

    float elapsed_time = 0.f;
    cudaEvent_t beg = nullptr, end = nullptr;
    CHECK_CUDA_ERROR(cudaEventCreate(&beg));
    CHECK_CUDA_ERROR(cudaEventCreate(&end));

    float alpha = 1.0f, beta = 0.f;  // C = alpha * A * B + beta * C
    float *A = nullptr, *B = nullptr, *C = nullptr, *C_ref = nullptr;      // host matrices
    float *dA = nullptr, *dB = nullptr, *dC = nullptr, *dC_ref = nullptr;  // device matrices

    A = new float[m * k];
    B = new float[k * n];
    C = new float[m * n];
    C_ref = new float[m * n];

    CHECK_CUDA_ERROR(cudaMalloc((void **) &dA, sizeof(float) * m * k));
    CHECK_CUDA_ERROR(cudaMalloc((void **) &dB, sizeof(float) * k * n));
    CHECK_CUDA_ERROR(cudaMalloc((void **) &dC, sizeof(float) * m * n));
    CHECK_CUDA_ERROR(cudaMalloc((void **) &dC_ref, sizeof(float) * m * n));

    randomize_matrix(A, m * k);
    randomize_matrix(B, k * n);
    randomize_matrix(C, m * n);
    copy_matrix(C, C_ref, m * n);

    CHECK_CUDA_ERROR(cudaMemcpy(dA, A, sizeof(float) * m * k, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(dB, B, sizeof(float) * k * n, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(dC, C, sizeof(float) * m * n, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(dC_ref, C_ref, sizeof(float) * m * n, cudaMemcpyHostToDevice));

    // Correctness check: run the selected kernel and the cuBLAS reference on the same inputs.
    test_func(handle, m, n, k, alpha, dA, dB, beta, dC);
    test_cublas(handle, m, n, k, alpha, dA, dB, beta, dC_ref);
    CHECK_LAST_CUDA_ERROR();                    // launch / configuration errors
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());  // errors raised while the kernels ran
    CHECK_CUDA_ERROR(cudaMemcpy(C, dC, sizeof(float) * m * n, cudaMemcpyDeviceToHost));
    CHECK_CUDA_ERROR(cudaMemcpy(C_ref, dC_ref, sizeof(float) * m * n, cudaMemcpyDeviceToHost));

    if (!verify_matrix(C_ref, C, m * n)) {
        printf("Failed to pass the correctness verification against NVIDIA cuBLAS. Exited.\n");
        // Dumping e.g. a 6400 x 6400 result would print tens of millions of values.
        if (static_cast<long long>(m) * n <= 1024) {
            print_matrix(C, m, n);
            print_matrix(C_ref, m, n);
        } else {
            printf("(matrices too large to print)\n");
        }
        std::exit(EXIT_FAILURE);
    }

    int repeat_times = 1;
    CHECK_CUDA_ERROR(cudaEventRecord(beg));
    for (int j = 0; j < repeat_times; j++) {
        test_func(handle, m, n, k, alpha, dA, dB, beta, dC);
    }
    CHECK_CUDA_ERROR(cudaEventRecord(end));
    CHECK_CUDA_ERROR(cudaEventSynchronize(beg));
    CHECK_CUDA_ERROR(cudaEventSynchronize(end));
    CHECK_LAST_CUDA_ERROR();  // launch errors from the timed runs
    CHECK_CUDA_ERROR(cudaEventElapsedTime(&elapsed_time, beg, end));
    elapsed_time /= 1000.f;  // milliseconds -> seconds

//    printf("Average elapsed time: (%f) second, performance: (%f) GFLOPS. size: (%d).\n",
//           elapsed_time / repeat_times, 2. * 1e-9 * repeat_times * m * n * k / elapsed_time, m);
    printf("%f %f %d\n",
           elapsed_time / repeat_times, 2. * 1e-9 * repeat_times * m * n * k / elapsed_time, m);

    // Release host memory, device memory, events and the cuBLAS handle.
    delete[] A;
    delete[] B;
    delete[] C;
    delete[] C_ref;
    CHECK_CUDA_ERROR(cudaFree(dA));
    CHECK_CUDA_ERROR(cudaFree(dB));
    CHECK_CUDA_ERROR(cudaFree(dC));
    CHECK_CUDA_ERROR(cudaFree(dC_ref));
    CHECK_CUDA_ERROR(cudaEventDestroy(beg));
    CHECK_CUDA_ERROR(cudaEventDestroy(end));
    if (cublasDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
        printf("Destroy cublas handle error.\n");
        return EXIT_FAILURE;
    }
    CHECK_LAST_CUDA_ERROR();
    return 0;
}
--------------------------------------------------------------------------------
/src/sgemm/tile.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 10/14/23.
//

#include "sgemm.cuh"

// Shared-memory SGEMM with a 1-D register tile: each thread accumulates a TM x 1 strip of
// C = alpha * A * B + beta * C, where A: (M, K), B: (K, N), C: (M, N) are row-major.
//
// The indexing assumes blockDim = (BK, BK) with BN == BK and BM == BK * TM (the setup used
// by test_tile_1d_kernel). Tile elements outside A or B are zero-filled and out-of-range C
// elements are skipped, so any M, N, K is handled.
template <const int BM, const int BN, const int BK, const int TM>
__global__ void tile_1d_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];
    float val[TM] = {};

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row0 = blockIdx.y * BM;  // first global row of this block's C tile
    const int col0 = blockIdx.x * BN;  // first global column of this block's C tile

    A += OFFSET(row0, 0, K);
    B += OFFSET(0, col0, N);
    C += OFFSET(row0, col0, N);

    const int tile_count = CEIL_DIV(K, BK);
    for (int t = 0; t < tile_count; ++t) {
        const int k0 = t * BK;  // global K offset of the current pair of tiles

        // Each thread stages TM elements of the A tile (one per row it owns) ...
        for (int r = 0; r < TM; ++r) {
            const int a_r = ty * TM + r;
            const bool a_ok = row0 + a_r < M && k0 + tx < K;
            As[a_r][tx] = a_ok ? A[OFFSET(a_r, tx, K)] : 0.f;
        }
        // ... and a single element of the B tile.
        const bool b_ok = k0 + ty < K && col0 + tx < N;
        Bs[ty][tx] = b_ok ? B[OFFSET(ty, tx, N)] : 0.f;
        __syncthreads();

        for (int kk = 0; kk < BK; ++kk) {
            const float b = Bs[kk][tx];
            for (int r = 0; r < TM; ++r) {
                val[r] += As[ty * TM + r][kk] * b;
            }
        }
        __syncthreads();

        A += BK;      // advance to the next K-tile of A
        B += BK * N;  // advance to the next K-tile of B
    }

    for (int r = 0; r < TM; ++r) {
        const int out_r = ty * TM + r;
        if (row0 + out_r < M && col0 + tx < N) {
            C[OFFSET(out_r, tx, N)] = alpha * val[r] + beta * C[OFFSET(out_r, tx, N)];
        }
    }
}



// Launch tile_1d_kernel with 16 x 16 threads, 128 x 16 block tiles of C and K-tiles of 16.
void test_tile_1d_kernel(cublasHandle_t handle, int M, int N, int K,
                         float alpha, float *A, float *B, float beta, float *C) {
    constexpr int size = 16;
    constexpr int tile_size = 8;
    constexpr int BM = size * tile_size;
    constexpr int BN = size;
    constexpr int BK = size;
    constexpr int TM = tile_size;
    const dim3 block(size, size);
    const dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // x covers columns (N), y covers rows (M)
    tile_1d_kernel<BM, BN, BK, TM><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}


// Shared-memory SGEMM with a 2-D register tile: each thread accumulates a TM x TN block of
// C = alpha * A * B + beta * C, where A: (M, K), B: (K, N), C: (M, N) are row-major.
//
// The indexing assumes blockDim = (BK, BK) with BM == BK * TM and BN == BK * TN. Tile
// elements outside A or B are zero-filled and out-of-range C elements are skipped, so any
// M, N, K is handled.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_2d_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];
    float val[TM][TN] = {};

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row0 = blockIdx.y * BM;  // first global row of this block's C tile
    const int col0 = blockIdx.x * BN;  // first global column of this block's C tile

    A += OFFSET(row0, 0, K);
    B += OFFSET(0, col0, N);
    C += OFFSET(row0, col0, N);

    const int tile_count = CEIL_DIV(K, BK);
    for (int t = 0; t < tile_count; ++t) {
        const int k0 = t * BK;  // global K offset of the current pair of tiles

        // Each thread stages TM elements of the A tile and TN elements of the B tile.
        for (int r = 0; r < TM; ++r) {
            const int a_r = ty * TM + r;
            const bool a_ok = row0 + a_r < M && k0 + tx < K;
            As[a_r][tx] = a_ok ? A[OFFSET(a_r, tx, K)] : 0.f;
        }
        for (int c = 0; c < TN; ++c) {
            const int b_c = tx * TN + c;
            const bool b_ok = k0 + ty < K && col0 + b_c < N;
            Bs[ty][b_c] = b_ok ? B[OFFSET(ty, b_c, N)] : 0.f;
        }
        __syncthreads();

        for (int kk = 0; kk < BK; ++kk) {
            for (int r = 0; r < TM; ++r) {
                const float a = As[ty * TM + r][kk];
                for (int c = 0; c < TN; ++c) {
                    val[r][c] += a * Bs[kk][tx * TN + c];
                }
            }
        }
        __syncthreads();

        A += BK;      // advance to the next K-tile of A
        B += BK * N;  // advance to the next K-tile of B
    }

    for (int r = 0; r < TM; ++r) {
        const int out_r = ty * TM + r;
        for (int c = 0; c < TN; ++c) {
            const int out_c = tx * TN + c;
            if (row0 + out_r < M && col0 + out_c < N) {
                C[OFFSET(out_r, out_c, N)] = alpha * val[r][c] + beta * C[OFFSET(out_r, out_c, N)];
            }
        }
    }
}

// Host launcher for tile_2d_kernel: 64 x 64 C tiles, BK = 16, a 16 x 16 thread
// block, 4 x 4 outputs per thread.
// `handle` is not used here (presumably kept so all test_* launchers share one signature).
void test_tile_2d_kernel(cublasHandle_t handle, int M, int N, int K,
                         float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;      // thread-block edge; also the K-slice depth
    const int tile_size = 4;  // per-thread micro-tile edge
    const int BM = size * tile_size;
    const int BN = size * tile_size;
    const int BK = size;
    const int TM = tile_size;
    const int TN = tile_size;
    dim3 block(BN / TN, BM / TM);                 // the layout the kernel's indexing assumes
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // grid.x spans N (columns), grid.y spans M (rows)
    tile_2d_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}


// tile_2d_reg_cache_kernel: same tiling as tile_2d_kernel, but for every k the
// thread first copies its TM values of As and TN values of Bs into registers
// and then forms the TM x TN outer product from those registers.
//
// A is M x K, B is K x N, C is M x N, addressed through OFFSET(row, col, ld)
// (assumed row-major -- defined outside this chunk).
// Expected launch: grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)), block = (BN / TN, BM / TM).
// As in tile_2d_kernel, the loaders require a BK x BK thread block
// (BM / TM == BN / TN == BK, checked at compile time).
// Out-of-range elements are zero-filled on load and skipped on store.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_2d_reg_cache_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    static_assert(BM % TM == 0 && BN % TN == 0, "TM / TN must divide BM / BN");
    static_assert(BM / TM == BK && BN / TN == BK,
                  "tile_2d_reg_cache_kernel loaders require a BK x BK thread block (BM / TM == BN / TN == BK)");

    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];
    float As_cache[TM] = {0.f};  // column k of this thread's TM rows of As
    float Bs_cache[TN] = {0.f};  // row k of this thread's TN columns of Bs
    float val[TM][TN] = {0.f};

    const int num_shared_block = CEIL_DIV(K, BK);
    // Re-base A, B and C at this block's tile so later indices are tile-local.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    for (int i = 0; i < num_shared_block; ++i) {
        // Stage the BM x BK slice of A (zero-padded outside M x K).
        for (int m = 0; m < TM; ++m) {
            const int A_row = threadIdx.y * TM + m;
            const int A_col = threadIdx.x;
            if ((blockIdx.y * BM + A_row) < M && (i * BK + A_col) < K) {
                As[A_row][A_col] = A[OFFSET(A_row, A_col, K)];
            } else {
                As[A_row][A_col] = 0.f;
            }
        }
        // Stage the BK x BN slice of B (zero-padded outside K x N).
        for (int n = 0; n < TN; ++n) {
            const int B_row = threadIdx.y;
            const int B_col = threadIdx.x * TN + n;
            if ((i * BK + B_row) < K && (blockIdx.x * BN + B_col) < N) {
                Bs[B_row][B_col] = B[OFFSET(B_row, B_col, N)];
            } else {
                Bs[B_row][B_col] = 0.f;
            }
        }
        __syncthreads();  // both tiles fully written before any thread reads them

        // Advance to the next K-slice.
        A += BK;
        B += BK * N;

        for (int k = 0; k < BK; ++k) {
            for (int m = 0; m < TM; ++m) {
                As_cache[m] = As[threadIdx.y * TM + m][k];
            }
            for (int n = 0; n < TN; ++n) {
                Bs_cache[n] = Bs[k][threadIdx.x * TN + n];
            }
            for (int m = 0; m < TM; ++m) {
                for (int n = 0; n < TN; ++n) {
                    val[m][n] += As_cache[m] * Bs_cache[n];
                }
            }
        }
        __syncthreads();  // all reads done before the next iteration overwrites As / Bs
    }

    // Epilogue: C = alpha * (A * B) + beta * C for the in-range part of the micro-tile.
    for (int m = 0; m < TM; ++m) {
        const int C_row = threadIdx.y * TM + m;
        for (int n = 0; n < TN; ++n) {
            const int C_col = threadIdx.x * TN + n;
            if ((blockIdx.y * BM + C_row) < M && (blockIdx.x * BN + C_col) < N) {
                C[OFFSET(C_row, C_col, N)] = alpha * val[m][n] + beta * C[OFFSET(C_row, C_col, N)];
            }
        }
    }
}


// Host launcher for tile_2d_reg_cache_kernel: 64 x 64 C tiles, BK = 16, a
// 16 x 16 thread block, 4 x 4 outputs per thread.
// `handle` is not used here (presumably kept so all test_* launchers share one signature).
void test_tile_2d_reg_cache_kernel(cublasHandle_t handle, int M, int N, int K,
                                   float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;      // thread-block edge; also the K-slice depth
    const int tile_size = 4;  // per-thread micro-tile edge
    const int BM = size * tile_size;
    const int BN = size * tile_size;
    const int BK = size;
    const int TM = tile_size;
    const int TN = tile_size;
    dim3 block(BN / TN, BM / TM);                 // the layout the kernel's indexing assumes
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // grid.x spans N (columns), grid.y spans M (rows)
    tile_2d_reg_cache_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}

--------------------------------------------------------------------------------
/src/sgemm/split_tile.cu:
--------------------------------------------------------------------------------
//
// Created by linn on 10/14/23.
//

#include "sgemm.cuh"

// Solve shared memory bank conflicts.
// https://blog.csdn.net/Bruce_0712/article/details/65447608
// https://blog.csdn.net/sunmc1204953974/article/details/51078818
// x: a warp executes in half-warp units; threads that belong to different
//    warps never conflict with each other. Execution and scheduling happen per
//    warp, memory accesses happen per half-warp.
// NOTE: the half-warp rule describes compute capability 1.x devices. On compute
//    capability 2.x and newer, shared memory has 32 banks and bank conflicts are
//    evaluated across the whole warp.

// no_share_conflict_kernel: C = alpha * (A * B) + beta * C with float4 global
// loads, A stored transposed in shared memory, double-buffered shared tiles and
// a "split" per-thread micro-tile.
//
// Expected launch (see test_no_share_conflict_kernel):
//   grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)), block = (BN / TN, BM / TM)
//
// Thread (tx, ty) owns TM x TN outputs arranged as (TM/4) x (TN/4) groups of
// 4 x 4. Group (gm, gn) starts at tile row gm * (BM / (TM/4)) + 4 * ty and tile
// column gn * (BN / (TN/4)) + 4 * tx. Threads with consecutive tx (ty) therefore
// read consecutive float4s of Bs (As).
//
// Preconditions (NOT checked at run time; there is no bounds handling):
//   M % BM == 0, N % BN == 0, K % BK == 0, and A, B, C aligned for the float4
//   accesses made through FETCH_FLOAT4 (assumed 16-byte vector accesses).
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void no_share_conflict_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    const int block_row_thread = BN / TN;  // threads along x (blockDim.x)
    const int block_col_thread = BM / TM;  // threads along y (blockDim.y)
    const int thread_num = block_row_thread * block_col_thread;

    // float4 loads each thread issues per K-step for the A and B tiles.
    const int load_a_cache_time = (BK * BM) / thread_num / 4;
    const int load_b_cache_time = (BK * BN) / thread_num / 4;

    static_assert(TM % 4 == 0 && TN % 4 == 0 && BK % 4 == 0 && BN % 4 == 0,
                  "float4 paths need TM, TN, BK and BN to be multiples of 4");
    static_assert(BM % TM == 0 && BN % TN == 0, "TM / TN must divide BM / BN");
    static_assert(load_a_cache_time > 0 && (BK * BM) % (thread_num * 4) == 0,
                  "A tile must split evenly into float4 loads across the block");
    static_assert(load_b_cache_time > 0 && (BK * BN) % (thread_num * 4) == 0,
                  "B tile must split evenly into float4 loads across the block");

    const int THREAD_TILE_M = TM / 4;                // 4-row groups per thread
    const int THREAD_TILE_N = TN / 4;                // 4-column groups per thread
    const int A_GROUP_STRIDE = BM / THREAD_TILE_M;   // tile rows between a thread's row groups
    const int B_GROUP_STRIDE = BN / THREAD_TILE_N;   // tile columns between a thread's column groups

    __shared__ float As[2][BK][BM];  // A tile stored transposed ([k][m]); two buffers for double buffering
    __shared__ float Bs[2][BK][BN];  // B tile; two buffers for double buffering

    float accum[TM][TN] = {0.f};
    float load_a_cache[4];  // one float4 of A in flight; reused to stage C in the epilogue

    // Re-base A, B and C at this block's tile.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    // Global -> shared copy mapping: each thread moves one float4 per pass.
    const int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
    const int a_tile_row = thread_id / (BK / 4);
    const int a_tile_col = thread_id % (BK / 4) * 4;
    const int a_tile_stride = BM / load_a_cache_time;  // A rows covered by the block per pass

    const int b_tile_row = thread_id / (BN / 4);
    const int b_tile_col = thread_id % (BN / 4) * 4;
    const int b_tile_stride = BK / load_b_cache_time;  // B rows covered by the block per pass

    // Per-k register fragments of As / Bs, indexed by write_idx like the shared buffers.
    float a_reg[2][TM] = {0.f};
    float b_reg[2][TN] = {0.f};

    int write_idx = 0;  // shared buffer filled and consumed in the current K-step

#pragma unroll
    for (int k = 0; k < K; k += BK) {
#pragma unroll
        for (int i = 0; i < BM; i += a_tile_stride) {
            // load_a_cache holds a single float4, so it is always indexed from 0.
            FETCH_FLOAT4(load_a_cache) = FETCH_FLOAT4(A[OFFSET(a_tile_row + i, a_tile_col, K)]);
            // Scatter transposed into As. Consecutive As rows are BM (= 128 in the
            // launcher) floats apart, so these scalar stores have shared memory bank conflicts.
            As[write_idx][a_tile_col][a_tile_row + i] = load_a_cache[0];
            As[write_idx][a_tile_col + 1][a_tile_row + i] = load_a_cache[1];
            As[write_idx][a_tile_col + 2][a_tile_row + i] = load_a_cache[2];
            As[write_idx][a_tile_col + 3][a_tile_row + i] = load_a_cache[3];
        }
#pragma unroll
        for (int i = 0; i < BK; i += b_tile_stride) {
            FETCH_FLOAT4(Bs[write_idx][b_tile_row + i][b_tile_col]) =
                    FETCH_FLOAT4(B[OFFSET(b_tile_row + i, b_tile_col, N)]);
        }
        // A single barrier per K-step is enough. The next step writes the other
        // buffer, and the step after that cannot write this buffer again until every
        // thread has passed the next step's barrier, i.e. has finished reading it here.
        __syncthreads();
        A += BK;      // start of the next K-slice of A
        B += BK * N;  // start of the next K-slice of B

#pragma unroll
        for (int i = 0; i < BK; ++i) {
#pragma unroll
            for (int t = 0; t < THREAD_TILE_M; ++t) {
                FETCH_FLOAT4(a_reg[write_idx][4 * t]) =
                        FETCH_FLOAT4(As[write_idx][i][threadIdx.y * 4 + t * A_GROUP_STRIDE]);
            }
#pragma unroll
            for (int t = 0; t < THREAD_TILE_N; ++t) {
                FETCH_FLOAT4(b_reg[write_idx][4 * t]) =
                        FETCH_FLOAT4(Bs[write_idx][i][threadIdx.x * 4 + t * B_GROUP_STRIDE]);
            }
#pragma unroll
            for (int m = 0; m < TM; ++m) {
#pragma unroll
                for (int n = 0; n < TN; ++n) {
                    accum[m][n] += a_reg[write_idx][m] * b_reg[write_idx][n];
                }
            }
        }
        write_idx ^= 1;
    }

    // Epilogue: float4 read-modify-write of each 4 x 4 group into C, using the
    // same group placement as the fragment reads above.
#pragma unroll
    for (int tm = 0; tm < THREAD_TILE_M; ++tm) {
#pragma unroll
        for (int r = 0; r < 4; ++r) {
            const int C_row = tm * A_GROUP_STRIDE + threadIdx.y * 4 + r;
            const int acc_row = tm * 4 + r;
#pragma unroll
            for (int tn = 0; tn < THREAD_TILE_N; ++tn) {
                const int C_col = tn * B_GROUP_STRIDE + threadIdx.x * 4;
                const int acc_col = tn * 4;
                FETCH_FLOAT4(load_a_cache) = FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]);
                load_a_cache[0] = alpha * accum[acc_row][acc_col] + beta * load_a_cache[0];
                load_a_cache[1] = alpha * accum[acc_row][acc_col + 1] + beta * load_a_cache[1];
                load_a_cache[2] = alpha * accum[acc_row][acc_col + 2] + beta * load_a_cache[2];
                load_a_cache[3] = alpha * accum[acc_row][acc_col + 3] + beta * load_a_cache[3];
                FETCH_FLOAT4(C[OFFSET(C_row, C_col, N)]) = FETCH_FLOAT4(load_a_cache);
            }
        }
    }
}


// Host launcher for no_share_conflict_kernel: 128 x 128 C tiles, BK = 8, a
// 16 x 16 thread block, 8 x 8 outputs per thread.
// The kernel has no bounds checks: M and N must be multiples of 128 and K a multiple of 8.
// `handle` is not used here (presumably kept so all test_* launchers share one signature).
void test_no_share_conflict_kernel(cublasHandle_t handle, int M, int N, int K,
                                   float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;
    const int tile_size = 8;
    const int BM = size * tile_size;
    const int BN = size * tile_size;
    const int BK = tile_size;
    const int TM = tile_size;
    const int TN = tile_size;
    dim3 block(BN / TN, BM / TM);                 // the layout the kernel's indexing assumes
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // grid.x spans N (columns), grid.y spans M (rows)
    no_share_conflict_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}


// tile_2d_split_kernel: same data flow as no_share_conflict_kernel (float4
// loads, transposed and double-buffered shared tiles). Both dimensions of each
// thread's TM x TN micro-tile are split into 4-wide groups spaced 4 * (threads
// along that dimension) apart.
//
// Expected launch: grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)), block = (BN / TN, BM / TM).
// Preconditions (NOT checked at run time): M % BM == 0, N % BN == 0,
// K % BK == 0, and FETCH_FLOAT4-compatible (assumed 16-byte) alignment of A, B, C.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_2d_split_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    const int block_row_thread = BN / TN;  // threads along x (blockDim.x)
    const int block_col_thread = BM / TM;  // threads along y (blockDim.y)
    const int thread_num = block_row_thread * block_col_thread;
    const int num_shared_block = CEIL_DIV(K, BK);

    // float4 loads each thread issues per K-step for the A and B tiles.
    const int load_a_cache_time = (BK * BM) / thread_num / 4;
    const int load_b_cache_time = (BK * BN) / thread_num / 4;

    static_assert(TM % 4 == 0 && TN % 4 == 0 && BK % 4 == 0 && BN % 4 == 0,
                  "float4 paths need TM, TN, BK and BN to be multiples of 4");
    static_assert(BM % TM == 0 && BN % TN == 0, "TM / TN must divide BM / BN");
    static_assert(load_a_cache_time > 0 && (BK * BM) % (thread_num * 4) == 0,
                  "A tile must split evenly into float4 loads across the block");
    static_assert(load_b_cache_time > 0 && (BK * BN) % (thread_num * 4) == 0,
                  "B tile must split evenly into float4 loads across the block");

    // Distance between a thread's consecutive 4-row (4-column) groups: four rows
    // (columns) for every thread along y (x).
    const int A_GROUP_STRIDE = block_col_thread * 4;
    const int B_GROUP_STRIDE = block_row_thread * 4;

    __shared__ float As[2][BK][BM];  // A tile stored transposed ([k][m]); two buffers for double buffering
    __shared__ float Bs[2][BK][BN];  // B tile; two buffers for double buffering

    float accum[TM][TN] = {0.f};

    // Staging for every float4 of A this thread copies per K-step; reused for C in the epilogue.
    float load_a_cache[4 * load_a_cache_time];

    // Re-base A, B and C at this block's tile.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    // Global -> shared copy mapping: each thread moves one float4 per pass.
    const int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
    const int a_tile_row = thread_id / (BK / 4);
    const int a_tile_col = thread_id % (BK / 4) * 4;
    const int a_tile_stride = BM / load_a_cache_time;  // A rows covered by the block per pass

    const int b_tile_row = thread_id / (BN / 4);
    const int b_tile_col = thread_id % (BN / 4) * 4;
    const int b_tile_stride = BK / load_b_cache_time;  // B rows covered by the block per pass

    // Per-k register fragments of As / Bs, indexed by write_idx like the shared buffers.
    float As_cache[2][TM] = {0.f};
    float Bs_cache[2][TN] = {0.f};

    int write_idx = 0;  // shared buffer filled and consumed in the current K-step

#pragma unroll
    for (int i = 0; i < num_shared_block; ++i) {
#pragma unroll
        for (int m = 0; m < BM; m += a_tile_stride) {
            const int cache_idx = m / a_tile_stride * 4;
            FETCH_FLOAT4(load_a_cache[cache_idx]) =
                    FETCH_FLOAT4(A[OFFSET(a_tile_row + m, a_tile_col, K)]);
            // Scatter transposed into As. Consecutive As rows are BM (= 128 in the
            // launcher) floats apart, so these scalar stores have shared memory bank conflicts.
            As[write_idx][a_tile_col][a_tile_row + m] = load_a_cache[cache_idx];
            As[write_idx][a_tile_col + 1][a_tile_row + m] = load_a_cache[cache_idx + 1];
            As[write_idx][a_tile_col + 2][a_tile_row + m] = load_a_cache[cache_idx + 2];
            As[write_idx][a_tile_col + 3][a_tile_row + m] = load_a_cache[cache_idx + 3];
        }
#pragma unroll
        for (int k = 0; k < BK; k += b_tile_stride) {
            FETCH_FLOAT4(Bs[write_idx][b_tile_row + k][b_tile_col]) =
                    FETCH_FLOAT4(B[OFFSET(b_tile_row + k, b_tile_col, N)]);
        }
        // One barrier per K-step suffices because consecutive steps use different buffers.
        __syncthreads();
        A += BK;      // start of the next K-slice of A
        B += BK * N;  // start of the next K-slice of B

#pragma unroll
        for (int k = 0; k < BK; ++k) {
#pragma unroll
            for (int mm = 0; mm < TM; mm += 4) {
                const int A_row = (mm / 4) * A_GROUP_STRIDE + threadIdx.y * 4;
                FETCH_FLOAT4(As_cache[write_idx][mm]) = FETCH_FLOAT4(As[write_idx][k][A_row]);
            }
#pragma unroll
            for (int nn = 0; nn < TN; nn += 4) {
                const int B_col = (nn / 4) * B_GROUP_STRIDE + threadIdx.x * 4;
                FETCH_FLOAT4(Bs_cache[write_idx][nn]) = FETCH_FLOAT4(Bs[write_idx][k][B_col]);
            }
#pragma unroll
            for (int m = 0; m < TM; ++m) {
#pragma unroll
                for (int n = 0; n < TN; ++n) {
                    accum[m][n] += As_cache[write_idx][m] * Bs_cache[write_idx][n];
                }
            }
        }
        write_idx ^= 1;
    }

    // Epilogue: float4 read-modify-write of each 4 x 4 group into C.
#pragma unroll
    for (int m = 0; m < TM; m += 4) {
        const int C_row = (m / 4) * A_GROUP_STRIDE + threadIdx.y * 4;
#pragma unroll
        for (int n = 0; n < TN; n += 4) {
            const int C_col = (n / 4) * B_GROUP_STRIDE + threadIdx.x * 4;
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                FETCH_FLOAT4(load_a_cache) = FETCH_FLOAT4(C[OFFSET(C_row + i, C_col, N)]);
                load_a_cache[0] = alpha * accum[m + i][n] + beta * load_a_cache[0];
                load_a_cache[1] = alpha * accum[m + i][n + 1] + beta * load_a_cache[1];
                load_a_cache[2] = alpha * accum[m + i][n + 2] + beta * load_a_cache[2];
                load_a_cache[3] = alpha * accum[m + i][n + 3] + beta * load_a_cache[3];
                FETCH_FLOAT4(C[OFFSET(C_row + i, C_col, N)]) = FETCH_FLOAT4(load_a_cache);
            }
        }
    }
}

// Host launcher for tile_2d_split_kernel: 128 x 128 C tiles, BK = 8, a 16 x 16
// thread block, 8 x 8 outputs per thread.
// The kernel has no bounds checks: M and N must be multiples of 128 and K a multiple of 8.
// `handle` is not used here (presumably kept so all test_* launchers share one signature).
void test_tile_2d_split_kernel(cublasHandle_t handle, int M, int N, int K,
                               float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;
    const int tile_size = 8;
    const int BM = size * tile_size;
    const int BN = size * tile_size;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;
    dim3 block(BN / TN, BM / TM);                 // the layout the kernel's indexing assumes
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // grid.x spans N (columns), grid.y spans M (rows)
    tile_2d_split_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}


// tile_1d_split_kernel: like tile_2d_split_kernel, but only the row (M)
// dimension of each thread's micro-tile is split into 4-row groups spaced
// 4 * blockDim.y apart. Each thread's TN columns stay contiguous, starting at
// tile column threadIdx.x * TN.
//
// Expected launch: grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)), block = (BN / TN, BM / TM).
// Preconditions (NOT checked at run time): M % BM == 0, N % BN == 0,
// K % BK == 0, and FETCH_FLOAT4-compatible (assumed 16-byte) alignment of A, B, C.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void tile_1d_split_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C) {
    const int block_row_thread = BN / TN;  // threads along x (blockDim.x)
    const int block_col_thread = BM / TM;  // threads along y (blockDim.y)
    const int thread_num = block_row_thread * block_col_thread;
    const int num_shared_block = CEIL_DIV(K, BK);

    // float4 loads each thread issues per K-step for the A and B tiles.
    const int load_a_cache_time = (BK * BM) / thread_num / 4;
    const int load_b_cache_time = (BK * BN) / thread_num / 4;

    static_assert(TM % 4 == 0 && TN % 4 == 0 && BK % 4 == 0 && BN % 4 == 0,
                  "float4 paths need TM, TN, BK and BN to be multiples of 4");
    static_assert(BM % TM == 0 && BN % TN == 0, "TM / TN must divide BM / BN");
    static_assert(load_a_cache_time > 0 && (BK * BM) % (thread_num * 4) == 0,
                  "A tile must split evenly into float4 loads across the block");
    static_assert(load_b_cache_time > 0 && (BK * BN) % (thread_num * 4) == 0,
                  "B tile must split evenly into float4 loads across the block");

    // Distance between a thread's consecutive 4-row groups: four rows for every thread along y.
    const int A_GROUP_STRIDE = block_col_thread * 4;

    __shared__ float As[2][BK][BM];  // A tile stored transposed ([k][m]); two buffers for double buffering
    __shared__ float Bs[2][BK][BN];  // B tile; two buffers for double buffering

    float accum[TM][TN] = {0.f};

    // Staging for every float4 of A this thread copies per K-step; reused for C in the epilogue.
    float load_a_cache[4 * load_a_cache_time];

    // Re-base A, B and C at this block's tile.
    A = &A[OFFSET(blockIdx.y * BM, 0, K)];
    B = &B[OFFSET(0, blockIdx.x * BN, N)];
    C = &C[OFFSET(blockIdx.y * BM, blockIdx.x * BN, N)];

    // Global -> shared copy mapping: each thread moves one float4 per pass.
    const int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
    const int a_tile_row = thread_id / (BK / 4);
    const int a_tile_col = thread_id % (BK / 4) * 4;
    const int a_tile_stride = BM / load_a_cache_time;  // A rows covered by the block per pass

    const int b_tile_row = thread_id / (BN / 4);
    const int b_tile_col = thread_id % (BN / 4) * 4;
    const int b_tile_stride = BK / load_b_cache_time;  // B rows covered by the block per pass

    // Per-k register fragments of As / Bs, indexed by write_idx like the shared buffers.
    float As_cache[2][TM] = {0.f};
    float Bs_cache[2][TN] = {0.f};

    int write_idx = 0;  // shared buffer filled and consumed in the current K-step

#pragma unroll
    for (int i = 0; i < num_shared_block; ++i) {
#pragma unroll
        for (int m = 0; m < BM; m += a_tile_stride) {
            const int cache_idx = m / a_tile_stride * 4;
            FETCH_FLOAT4(load_a_cache[cache_idx]) =
                    FETCH_FLOAT4(A[OFFSET(a_tile_row + m, a_tile_col, K)]);
            // Scatter transposed into As. Consecutive As rows are BM (= 128 in the
            // launcher) floats apart, so these scalar stores have shared memory bank conflicts.
            As[write_idx][a_tile_col][a_tile_row + m] = load_a_cache[cache_idx];
            As[write_idx][a_tile_col + 1][a_tile_row + m] = load_a_cache[cache_idx + 1];
            As[write_idx][a_tile_col + 2][a_tile_row + m] = load_a_cache[cache_idx + 2];
            As[write_idx][a_tile_col + 3][a_tile_row + m] = load_a_cache[cache_idx + 3];
        }
#pragma unroll
        for (int k = 0; k < BK; k += b_tile_stride) {
            FETCH_FLOAT4(Bs[write_idx][b_tile_row + k][b_tile_col]) =
                    FETCH_FLOAT4(B[OFFSET(b_tile_row + k, b_tile_col, N)]);
        }
        // One barrier per K-step suffices because consecutive steps use different buffers.
        __syncthreads();
        A += BK;      // start of the next K-slice of A
        B += BK * N;  // start of the next K-slice of B

#pragma unroll
        for (int k = 0; k < BK; ++k) {
#pragma unroll
            for (int mm = 0; mm < TM; mm += 4) {
                const int A_row = (mm / 4) * A_GROUP_STRIDE + threadIdx.y * 4;
                FETCH_FLOAT4(As_cache[write_idx][mm]) = FETCH_FLOAT4(As[write_idx][k][A_row]);
            }
#pragma unroll
            for (int n = 0; n < TN; n += 4) {
                const int B_col = threadIdx.x * TN + n;  // contiguous TN columns per thread
                FETCH_FLOAT4(Bs_cache[write_idx][n]) = FETCH_FLOAT4(Bs[write_idx][k][B_col]);
            }
#pragma unroll
            for (int m = 0; m < TM; ++m) {
#pragma unroll
                for (int n = 0; n < TN; ++n) {
                    accum[m][n] += As_cache[write_idx][m] * Bs_cache[write_idx][n];
                }
            }
        }
        write_idx ^= 1;
    }

    // Epilogue: float4 read-modify-write of each 4-row group into C.
#pragma unroll
    for (int m = 0; m < TM; m += 4) {
        const int C_row = (m / 4) * A_GROUP_STRIDE + threadIdx.y * 4;
#pragma unroll
        for (int n = 0; n < TN; n += 4) {
            const int C_col = threadIdx.x * TN + n;
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                FETCH_FLOAT4(load_a_cache) = FETCH_FLOAT4(C[OFFSET(C_row + i, C_col, N)]);
                load_a_cache[0] = alpha * accum[m + i][n] + beta * load_a_cache[0];
                load_a_cache[1] = alpha * accum[m + i][n + 1] + beta * load_a_cache[1];
                load_a_cache[2] = alpha * accum[m + i][n + 2] + beta * load_a_cache[2];
                load_a_cache[3] = alpha * accum[m + i][n + 3] + beta * load_a_cache[3];
                FETCH_FLOAT4(C[OFFSET(C_row + i, C_col, N)]) = FETCH_FLOAT4(load_a_cache);
            }
        }
    }
}

// Host launcher for tile_1d_split_kernel: 128 x 128 C tiles, BK = 8, a 16 x 16
// thread block, 8 x 8 outputs per thread.
// The kernel has no bounds checks: M and N must be multiples of 128 and K a multiple of 8.
// `handle` is not used here (presumably kept so all test_* launchers share one signature).
void test_tile_1d_split_kernel(cublasHandle_t handle, int M, int N, int K,
                               float alpha, float *A, float *B, float beta, float *C) {
    const int size = 16;
    const int tile_size = 8;
    const int BM = size * tile_size;
    const int BN = size * tile_size;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;
    dim3 block(BN / TN, BM / TM);                 // the layout the kernel's indexing assumes
    dim3 grid(CEIL_DIV(N, BN), CEIL_DIV(M, BM));  // grid.x spans N (columns), grid.y spans M (rows)
    tile_1d_split_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
}

--------------------------------------------------------------------------------