├── .gitignore ├── gemm_gpu_mult_block_no_restrict.h ├── gemm_cpu_naive.h ├── gemm_cpu_simd.h ├── gemm_gpu_tiling.h ├── gemm_gpu_1thread.h ├── gemm_gpu_mult_block_no_restrict_reg.h ├── gemm_gpu_mult_block.h ├── gemm_gpu_mult_thread.h ├── README.md ├── Makefile ├── gemm_gpu_1thread.cu ├── gemm_gpu_mult_block.cu ├── gemm_gpu_mult_block_no_restrict.cu ├── gemm_cpu_simd.cc ├── gemm_gpu_mult_thread.cu ├── gemm_cpu_naive.cc ├── gemm_gpu_mult_block_no_restrict_reg.cu ├── gemm_gpu_tiling.cu ├── gemm_test.cc ├── LICENSE └── lecture.md /.gitignore: -------------------------------------------------------------------------------- 1 | /.vscode 2 | /gemm_test 3 | *.o 4 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block_no_restrict.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_mult_block_no_restrict( 4 | int* C, // [n, m] 5 | const int* A, // [n, k] 6 | const int* B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_cpu_naive.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_cpu_naive( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_cpu_simd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_cpu_simd( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_gpu_tiling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_tiling( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_gpu_1thread.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_1thread( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block_no_restrict_reg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_mult_block_no_restrict_reg( 4 | int* C, // [n, m] 5 | const int* A, // [n, k] 6 | const int* B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_mult_block( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /gemm_gpu_mult_thread.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | void gemm_gpu_mult_thread( 4 | int* __restrict__ C, // [n, m] 5 | const int* __restrict__ A, // [n, k] 6 | const int* __restrict__ B, // [k, m] 7 | const int n, 8 | const int m, 9 | const int k 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUDA - From Correctness to Performance 2 | 3 | This repo includes codes & examples for "CUDA - From Correctness to Performance". 4 | 5 | The lecture can be found at https://wiki.lcpu.dev/zh/hpc/from-scratch/cuda or [here](lecture.md) 6 | 7 | ## How to Build 8 | 9 | Make sure you have installed the CUDA toolkit, and a CUDA-compatible GPU is available. 10 | 11 | Run `make all` to build this repo. 12 | 13 | ## How to Use 14 | 15 | Usage: 16 | 17 | ```bash 18 | ./gemm_test [implementation] 19 | ``` 20 | 21 | If `implementation` is not specified, all implementations will be benchmarked. 22 | 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CC = nvcc 2 | CXXFLAGS = -O3 3 | 4 | DEPS = gemm_cpu_naive.h gemm_cpu_simd.h gemm_gpu_1thread.h gemm_gpu_mult_thread.h gemm_gpu_mult_block.h gemm_gpu_mult_block_no_restrict.h gemm_gpu_mult_block_no_restrict_reg.h gemm_gpu_tiling.h 5 | OBJS = gemm_cpu_naive.o gemm_cpu_simd.o gemm_test.o gemm_gpu_1thread.o gemm_gpu_mult_thread.o gemm_gpu_mult_block.o gemm_gpu_mult_block_no_restrict.o gemm_gpu_mult_block_no_restrict_reg.o gemm_gpu_tiling.o 6 | 7 | .PHONY: all clean 8 | 9 | all: gemm_test 10 | 11 | %.o: %.cc $(DEPS) # for .cc files 12 | $(CC) -c $(CXXFLAGS) $< -o $@ 13 | 14 | %.o: %.cu $(DEPS) # for .cu files 15 | $(CC) -c $(CXXFLAGS) $< -o $@ 16 | 17 | gemm_test: $(OBJS) 18 | $(CC) $(CXXFLAGS) $^ -o gemm_test 19 | 20 | clean: 21 | rm -f *.o gemm_test -------------------------------------------------------------------------------- /gemm_gpu_1thread.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_1thread.h" 2 | 3 | #include 4 | 5 | // gemm_gpu_1thread - GEMM on GPU, using only one thread 6 | __global__ 7 | void gemm_gpu_1thread_kernel( 8 | int* __restrict__ C, // [n, m], on gpu 9 | const int* __restrict__ A, // [n, k], on gpu 10 | const int* __restrict__ B, // [k, m], on gpu 11 | const int n, 12 | const int m, 13 | const int k 14 | ) { 15 | for (int i = 0; i < n; ++i) 16 | for (int j = 0; j < m; ++j) { 17 | int res = 0; 18 | for (int l = 0; l < k; ++l) { 19 | res += A[i * k + l] * B[l * m + j]; 20 | } 21 | C[i * m + j] = res; 22 | } 23 | } 24 | 25 | void gemm_gpu_1thread( 26 | int* __restrict__ C, // [n, m], on gpu 27 | const int* __restrict__ A, // [n, k], on gpu 28 | const int* __restrict__ B, // [k, m], on gpu 29 | const int n, 30 | const int m, 31 | const int k 32 | ) { 33 | gemm_gpu_1thread_kernel<<<1, 1>>>(C, A, B, n, m, k); 34 | } 35 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_mult_block.h" 2 | 3 | #include 4 | 5 | // gemm_gpu_mult_block - GEMM on GPU, using many blocks 6 | // The block size is N 7 | // The grid size is M 8 | __global__ 9 | void gemm_gpu_mult_block_kernel( 10 | int* __restrict__ C, // [n, m], on gpu 11 | const int* __restrict__ A, // [n, k], on gpu 12 | const int* __restrict__ B, // [k, m], on gpu 13 | const int n, 14 | const int m, 15 | const int k 16 | ) { 17 | const int i = threadIdx.x; 18 | const int j = blockIdx.x; 19 | for (int l = 0; l < k; ++l) { 20 | C[i * m + j] += A[i * k + l] * B[l * m + j]; 21 | } 22 | } 23 | 24 | void gemm_gpu_mult_block( 25 | int* __restrict__ C, // [n, m], on gpu 26 | const int* __restrict__ A, // [n, k], on gpu 27 | const int* __restrict__ B, // [k, m], on gpu 28 | const int n, 29 | const int m, 30 | const int k 31 | ) { 32 | gemm_gpu_mult_block_kernel<<>>(C, A, B, n, m, k); 33 | } 34 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block_no_restrict.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_mult_block.h" 2 | 3 | #include 4 | 5 | // gemm_gpu_mult_block_no_restrict - GEMM on GPU, using many blocks 6 | // and without the __restrict__ keyword 7 | // The block size is N 8 | // The grid size is M 9 | __global__ 10 | void gemm_gpu_mult_block_no_restrict_kernel( 11 | int* C, // [n, m], on gpu 12 | const int* A, // [n, k], on gpu 13 | const int* B, // [k, m], on gpu 14 | const int n, 15 | const int m, 16 | const int k 17 | ) { 18 | const int i = threadIdx.x; 19 | const int j = blockIdx.x; 20 | for (int l = 0; l < k; ++l) { 21 | C[i * m + j] += A[i * k + l] * B[l * m + j]; 22 | } 23 | } 24 | 25 | void gemm_gpu_mult_block_no_restrict( 26 | int* C, // [n, m], on gpu 27 | const int* A, // [n, k], on gpu 28 | const int* B, // [k, m], on gpu 29 | const int n, 30 | const int m, 31 | const int k 32 | ) { 33 | gemm_gpu_mult_block_no_restrict_kernel<<>>(C, A, B, n, m, k); 34 | } 35 | -------------------------------------------------------------------------------- /gemm_cpu_simd.cc: -------------------------------------------------------------------------------- 1 | #include "gemm_cpu_simd.h" 2 | 3 | // gemm_cpu_simd - A SIMD gemm (GEneral Matrix Multiply) implementation on CPU 4 | // Input: 5 | // - A: A n x k matrix 6 | // - B: A k x m matrix 7 | // - n: The number of rows of A 8 | // - m: The number of columns of B 9 | // - k: The number of columns of A and the number of rows of B 10 | // Output: 11 | // - C: A n x m matrix. The result of A * B 12 | // Requirements: 13 | // - Please make sure C is initialized to 0 before calling this function 14 | __attribute__((optimize("O3"))) // Enforce O3 optimization to utilize loop unrolling and SIMD 15 | void gemm_cpu_simd( 16 | int* __restrict__ C, // [n, m] 17 | const int* __restrict__ A, // [n, k] 18 | const int* __restrict__ B, // [k, m] 19 | const int n, 20 | const int m, 21 | const int k 22 | ) { 23 | for (int i = 0; i < n; ++i) 24 | for (int l = 0; l < k; ++l) 25 | for (int j = 0; j < m; ++j) { 26 | C[i * m + j] += A[i * k + l] * B[l * m + j]; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /gemm_gpu_mult_thread.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_mult_thread.h" 2 | 3 | #include 4 | 5 | // gemm_gpu_mult_thread - GEMM on GPU, using only one block 6 | // The block size is N 7 | __global__ 8 | void gemm_gpu_mult_thread_kernel( 9 | int* __restrict__ C, // [n, m], on gpu 10 | const int* __restrict__ A, // [n, k], on gpu 11 | const int* __restrict__ B, // [k, m], on gpu 12 | const int n, 13 | const int m, 14 | const int k 15 | ) { 16 | const int i = threadIdx.x; 17 | for (int j = 0; j < m; ++j) { 18 | int res = 0; 19 | for (int l = 0; l < k; ++l) { 20 | res += A[i * k + l] * B[l * m + j]; 21 | } 22 | C[i * m + j] = res; 23 | } 24 | } 25 | 26 | void gemm_gpu_mult_thread( 27 | int* __restrict__ C, // [n, m], on gpu 28 | const int* __restrict__ A, // [n, k], on gpu 29 | const int* __restrict__ B, // [k, m], on gpu 30 | const int n, 31 | const int m, 32 | const int k 33 | ) { 34 | gemm_gpu_mult_thread_kernel<<<1, n>>>(C, A, B, n, m, k); 35 | } 36 | -------------------------------------------------------------------------------- /gemm_cpu_naive.cc: -------------------------------------------------------------------------------- 1 | #include "gemm_cpu_naive.h" 2 | 3 | // gemm_cpu_naive - A naive gemm (GEneral Matrix Multiply) implementation on CPU 4 | // Input: 5 | // - A: A n x k matrix 6 | // - B: A k x m matrix 7 | // - n: The number of rows of A 8 | // - m: The number of columns of B 9 | // - k: The number of columns of A and the number of rows of B 10 | // Output: 11 | // - C: A n x m matrix. The result of A * B 12 | // Requirements: 13 | // - Please make sure C is initialized to 0 before calling this function 14 | __attribute__((optimize("O1"))) // Enforce O1 optimization to avoid loop unrolling and SIMD 15 | void gemm_cpu_naive( 16 | int* __restrict__ C, // [n, m] 17 | const int* __restrict__ A, // [n, k] 18 | const int* __restrict__ B, // [k, m] 19 | const int n, 20 | const int m, 21 | const int k 22 | ) { 23 | for (int i = 0; i < n; ++i) 24 | for (int l = 0; l < k; ++l) 25 | for (int j = 0; j < m; ++j) { 26 | C[i * m + j] += A[i * k + l] * B[l * m + j]; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /gemm_gpu_mult_block_no_restrict_reg.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_mult_block.h" 2 | 3 | #include 4 | 5 | // gemm_gpu_mult_block_no_restrict_reg - GEMM on GPU, using many blocks 6 | // and without the __restrict__ keyword 7 | // and stores the intermediate results in a register 8 | // The block size is N 9 | // The grid size is M 10 | __global__ 11 | void gemm_gpu_mult_block_no_restrict_reg_kernel( 12 | int* C, // [n, m], on gpu 13 | const int* A, // [n, k], on gpu 14 | const int* B, // [k, m], on gpu 15 | const int n, 16 | const int m, 17 | const int k 18 | ) { 19 | const int i = threadIdx.x; 20 | const int j = blockIdx.x; 21 | int res = 0; 22 | for (int l = 0; l < k; ++l) { 23 | res += A[i * k + l] * B[l * m + j]; 24 | } 25 | C[i * m + j] = res; 26 | } 27 | 28 | void gemm_gpu_mult_block_no_restrict_reg( 29 | int* C, // [n, m], on gpu 30 | const int* A, // [n, k], on gpu 31 | const int* B, // [k, m], on gpu 32 | const int n, 33 | const int m, 34 | const int k 35 | ) { 36 | gemm_gpu_mult_block_no_restrict_reg_kernel<<>>(C, A, B, n, m, k); 37 | } 38 | -------------------------------------------------------------------------------- /gemm_gpu_tiling.cu: -------------------------------------------------------------------------------- 1 | #include "gemm_gpu_tiling.h" 2 | 3 | #include 4 | 5 | #include 6 | 7 | constexpr int TILE_SIZE = 32; 8 | 9 | // gemm_gpu_tiling - GEMM on GPU, using tiling & shared memory to optimize 10 | // global memory accesses 11 | __global__ 12 | void gemm_gpu_tiling_kernel( 13 | int* __restrict__ C, // [n, m], on gpu 14 | const int* __restrict__ A, // [n, k], on gpu 15 | const int* __restrict__ B, // [k, m], on gpu 16 | const int n, 17 | const int m, 18 | const int k 19 | ) { 20 | // We copy the tile from a/b into shared memory, and then do the calculation 21 | __shared__ int a_tile[TILE_SIZE][TILE_SIZE]; 22 | __shared__ int b_tile[TILE_SIZE][TILE_SIZE]; 23 | int my_c_result = 0; 24 | for (int tile_index = 0; tile_index < k/TILE_SIZE; ++tile_index) { 25 | // Step 1. Load the tile from a/b into a/b_tile 26 | a_tile[threadIdx.y][threadIdx.x] = A[(blockIdx.x*TILE_SIZE + threadIdx.y)*k + (tile_index*TILE_SIZE + threadIdx.x)]; 27 | b_tile[threadIdx.y][threadIdx.x] = B[(tile_index*TILE_SIZE + threadIdx.y)*m + (blockIdx.y*TILE_SIZE + threadIdx.x)]; 28 | __syncthreads(); 29 | // Step 2. Calculate the contribution to my_c_result 30 | for (int i = 0; i < TILE_SIZE; ++i) { 31 | my_c_result += a_tile[threadIdx.y][i] * b_tile[i][threadIdx.x]; 32 | } 33 | __syncthreads(); 34 | } 35 | // Step 3. Store my_c_result 36 | C[(blockIdx.x*TILE_SIZE + threadIdx.y)*m + (blockIdx.y*TILE_SIZE + threadIdx.x)] = my_c_result; 37 | } 38 | 39 | void gemm_gpu_tiling( 40 | int* __restrict__ C, // [n, m], on gpu 41 | const int* __restrict__ A, // [n, k], on gpu 42 | const int* __restrict__ B, // [k, m], on gpu 43 | const int n, 44 | const int m, 45 | const int k 46 | ) { 47 | assert (n % TILE_SIZE == 0); 48 | assert (m % TILE_SIZE == 0); 49 | assert (k % TILE_SIZE == 0); 50 | dim3 grid_dim = dim3(n / TILE_SIZE, m / TILE_SIZE); 51 | dim3 block_dim = dim3(TILE_SIZE, TILE_SIZE); 52 | gemm_gpu_tiling_kernel<<>>(C, A, B, n, m, k); 53 | } 54 | -------------------------------------------------------------------------------- /gemm_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "gemm_cpu_naive.h" 11 | #include "gemm_cpu_simd.h" 12 | #include "gemm_gpu_1thread.h" 13 | #include "gemm_gpu_mult_thread.h" 14 | #include "gemm_gpu_mult_block.h" 15 | #include "gemm_gpu_mult_block_no_restrict.h" 16 | #include "gemm_gpu_mult_block_no_restrict_reg.h" 17 | #include "gemm_gpu_tiling.h" 18 | 19 | // gemm_impl_t - A function pointer type for gemm implementations 20 | typedef void (*gemm_impl_t)( 21 | int* __restrict__ C, 22 | const int* __restrict__ A, 23 | const int* __restrict__ B, 24 | const int n, 25 | const int m, 26 | const int k 27 | ); 28 | 29 | struct GemmImpl { 30 | std::string name; 31 | gemm_impl_t impl; 32 | bool is_gpu; 33 | }; 34 | 35 | // All GEMM implementations to benchmark 36 | std::vector gemm_impls = { 37 | { "cpu_naive", gemm_cpu_naive, false }, 38 | { "cpu_simd", gemm_cpu_simd, false }, 39 | { "gpu_1thread", gemm_gpu_1thread, true }, 40 | { "gpu_mult_thread", gemm_gpu_mult_thread, true }, 41 | { "gpu_mult_block", gemm_gpu_mult_block, true }, 42 | { "gpu_mult_block_no_restrict", gemm_gpu_mult_block_no_restrict, true }, 43 | { "gpu_mult_block_no_restrict_reg", gemm_gpu_mult_block_no_restrict_reg, true }, 44 | { "gpu_tiling", gemm_gpu_tiling, true } 45 | }; 46 | 47 | // cuda_sync_check_error - Sync with the CUDA device, check if there 48 | // is any error, and print the error message if there is any. 49 | void cuda_sync_check_error_helper(const char* filename, const int line) { 50 | cudaDeviceSynchronize(); 51 | cudaError_t error = cudaGetLastError(); 52 | if (error != cudaSuccess) { 53 | printf("CUDA error at %s:%d: %s\n", filename, line, cudaGetErrorString(error)); 54 | exit(1); 55 | } 56 | } 57 | #define cuda_sync_check_error() cuda_sync_check_error_helper(__FILE__, __LINE__) 58 | 59 | constexpr int BENCHMARK_ROUNDS = 8; 60 | // benchmark_gemm_impl - Benchmark a gemm implementation, return the 61 | // avg time usage on BENCKMARK_ROUNDS rounds 62 | double benchmark_gemm_impl( 63 | GemmImpl gemm_impl, 64 | const int n, 65 | const int m, 66 | const int k 67 | ) { 68 | // Prepare test data 69 | 70 | // Allocate A, B, C, and C_ref on CPU 71 | int* A = new int[n * k]; 72 | int* B = new int[k * m]; 73 | int* C = new int[n * m]; 74 | int* C_ref = new int[n * m]; 75 | 76 | // Allocate A_gpu, B_gpu, and C_gpu on GPU 77 | int* A_gpu; 78 | int* B_gpu; 79 | int* C_gpu; 80 | cudaMalloc(&A_gpu, sizeof(int) * n * k); 81 | cudaMalloc(&B_gpu, sizeof(int) * k * m); 82 | cudaMalloc(&C_gpu, sizeof(int) * n * m); 83 | 84 | // Initialize A and B and copy them to GPU 85 | for (int i = 0; i < n * k; ++i) A[i] = rand() % 1000; 86 | for (int i = 0; i < k * m; ++i) B[i] = rand() % 1000; 87 | cudaMemcpy(A_gpu, A, sizeof(int) * n * k, cudaMemcpyHostToDevice); 88 | cudaMemcpy(B_gpu, B, sizeof(int) * k * m, cudaMemcpyHostToDevice); 89 | 90 | // Initialize C_ref 91 | memset(C_ref, 0, sizeof(int) * n * m); 92 | gemm_cpu_naive(C_ref, A, B, n, m, k); 93 | 94 | // run_once: Run the gemm_impl once, and return its time usage (in microseconds (us)) 95 | std::function run_once = [&]() -> long { 96 | if (gemm_impl.is_gpu) { 97 | cudaMemset(C_gpu, 0, sizeof(int) * n * m); 98 | auto start = std::chrono::high_resolution_clock::now(); 99 | gemm_impl.impl(C_gpu, A_gpu, B_gpu, n, m, k); 100 | cuda_sync_check_error(); 101 | auto end = std::chrono::high_resolution_clock::now(); 102 | return std::chrono::duration_cast(end - start).count(); 103 | } else { 104 | memset(C, 0, sizeof(int) * n * m); 105 | auto start = std::chrono::high_resolution_clock::now(); 106 | gemm_impl.impl(C, A, B, n, m, k); 107 | auto end = std::chrono::high_resolution_clock::now(); 108 | return std::chrono::duration_cast(end - start).count(); 109 | } 110 | }; 111 | 112 | // Warm up 113 | printf("Warming up...\n"); 114 | run_once(); 115 | 116 | // Verift its correctness 117 | if (gemm_impl.is_gpu) { 118 | cudaMemcpy(C, C_gpu, sizeof(int) * n * m, cudaMemcpyDeviceToHost); 119 | } 120 | for (int i = 0; i < n * m; ++i) { 121 | if (C[i] != C_ref[i]) { 122 | printf("Verification failed!\n"); 123 | printf("C[%d, %d] = %d, C_ref[%d, %d] = %d\n", i / m, i % m, C[i], i / m, i % m, C_ref[i]); 124 | return -1; 125 | } 126 | } 127 | std::cout << "Verification passed!" << std::endl; 128 | 129 | // Warm up again since correct verification may corrupt cache 130 | printf("Warming up (again)...\n"); 131 | run_once(); 132 | 133 | // Benchmark 134 | long total_time_usage = 0; 135 | for (int round = 0; round < BENCHMARK_ROUNDS; ++round) { 136 | long time_usage = run_once(); 137 | printf("Round %d: %ld us\n", round, time_usage); 138 | total_time_usage += time_usage; 139 | } 140 | double avg_time_usage = total_time_usage / (double)BENCHMARK_ROUNDS; 141 | 142 | // Free memory 143 | delete[] A; 144 | delete[] B; 145 | delete[] C; 146 | delete[] C_ref; 147 | cudaFree(A_gpu); 148 | cudaFree(B_gpu); 149 | cudaFree(C_gpu); 150 | 151 | return avg_time_usage; 152 | } 153 | 154 | int main(int argc, char* argv[]) { 155 | if (argc != 4 && argc != 5) { 156 | printf("Usage: %s [implementation]\n", argv[0]); 157 | exit(1); 158 | } 159 | 160 | // Parse command line arguments 161 | int n = atoi(argv[1]); 162 | int m = atoi(argv[2]); 163 | int k = atoi(argv[3]); 164 | std::string impl = argc == 5 ? argv[4] : "*"; 165 | assert (n > 0 && m > 0 && k > 0); 166 | 167 | // We allocate a small block of memory on GPU here to initialize the 168 | // CUDA context. If we does not do this and the user hasn't enabled 169 | // GPU's persistent mode, then the first CUDA call will take a long 170 | // time to finish 171 | int* dummy; 172 | cudaMalloc(&dummy, sizeof(int)); 173 | 174 | std::vector> results; 175 | for (auto gemm_impl : gemm_impls) { 176 | if (impl == "*" || gemm_impl.name == impl) { 177 | printf("----------------\n"); 178 | printf("Benchmarking %s...\n", gemm_impl.name.c_str()); 179 | double avg_time_usage = benchmark_gemm_impl(gemm_impl, n, m, k); 180 | printf("Average time usage: %lf us\n", avg_time_usage); 181 | results.push_back({ gemm_impl.name, avg_time_usage }); 182 | } 183 | } 184 | 185 | // Print results 186 | printf("----------------\n"); 187 | printf("Results:\n"); 188 | for (auto result : results) { 189 | printf("%16s %16.2lf us\n", result.first.c_str(), result.second); 190 | } 191 | 192 | return 0; 193 | } 194 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /lecture.md: -------------------------------------------------------------------------------- 1 | # Introduction to CUDA Programming: From Correctness to Performance 2 | 3 | # Overview 4 | 5 | 本文是 HPC From Scratch 的第六节课。本文将从 GPU 的结构与 CUDA 的基本概念讲起,带领大家写出自己的第一个正确的 CUDA 程序,并教大家一些基本的优化技巧,带领大家优化自己的 CUDA 程序(正如标题所示,From Correctness to Performamce)。 6 | 7 | 本文分为三部分:Part 0 简要介绍了为什么 GPU 能在许多任务上取得千倍的加速比;Part 1 介绍了 GPU 编程的基本概念,以及如何写出第一个 CUDA 程序;Part 2 则是一些基本的优化技巧。如果你是纯新手,那么建议从头开始阅读;如果您已经有了一些 CUDA 编程经验,那么建议您只看 Part 2。 8 | 9 | > Aside | 拓展内容:本文中标注了 "Aside" 的内容为拓展内容,我认为他们比较有趣,但其与课程主线关联不大。 10 | {.is-info} 11 | 12 | 13 | 答疑:有问题的同学可以在 Linux 社团或 HPC From Scratch 的微信群中提问。如果您觉得讲义中的内容有问题,请联系 shengyu.liu@stu.pku.edu.cn。 14 | 15 | 示例代码、课后作业所在的 git 仓库位于 https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code 16 | 17 | # Part 0. Introduction 18 | 19 | ## Why GPU? 20 | 21 | 首先,我们来思考一个问题:既然 CPU 已经可以执行绝大部分计算了,我们为什么还需要使用 GPU 来执行计算呢? 22 | 23 | 答案很简单:因为效率。某些在 CPU 上需要花很长时间才能执行完成的操作,在 GPU 上可能可以获得十倍、百倍甚至千倍的效率提升。举个例子,使用过 PyTorch 的同学应该体会过,同一个神经网络,在 CPU 上运行所需要的时间是 GPU 的数十倍甚至数百倍。 24 | 25 | 而且,正如每个人都有自己的长处与短处一样,GPU 也有自己所擅长的计算任务。GPU 底层的硬件设计导致它特别擅长于执行逻辑简单但是并行程度极高的计算任务。这样的任务包括: 26 | 27 | - 向量加法 28 | - 矩阵乘法 29 | - 有限元解偏微分方程 30 | - 求出一个向量中所有元素的和 31 | - ... 32 | 33 | ## GPU 的哲学 34 | 35 | GPU 的哲学是: 36 | 37 | 1. 通过 SIMD-Style 地缩减控制单元(多个 CUDA Core 共用一个控制单元),把芯片面积让给大量计算单元,实现高吞吐量与 Massive parallel 38 | 39 | 2. 通过大规模、高带宽的吞吐,隐藏访存延迟。GPU 中有大量的寄存器,上下文切换很快,可以在多个 kernel 之间迅速切换。举个例子,假如某个 kernel 计算 1ms,访存10ms。在等待访存的这10ms内,可以运行其他 10 个 kernel。这样虽然这个 kernel 得到结果的速度没有加快,但是运算量就有十倍的增加。 40 | 41 | (这一段看不懂没关系,可以以后回过头来想) 42 | 43 | ## GPU 为什么擅长此类计算任务? 44 | 45 | 那么,GPU 为什么能在此类任务上取得如此高的性能提升呢? 46 | 47 | 我们打个比方:如果我们现在想证明一个复杂的数学命题,那么是一位陈景润解得快,还是 100 名大一学生解得快?大概率是前者,因为陈景润先生有着深厚的功底与丰富的经验,并且“证明命题”这一过程很难并行。 48 | 49 | 但如果我们现在想要计算 10000 道 100 以内的乘除法呢?那么大概率是 100 名大一学生算得快。因为,虽然一名大一学生计算一道 100 以内的乘除法的速度比不上陈景润先生计算一道 100 以内的乘除法的速度,但100 名大一学生一起上,速度一定会比一位陈景润先生要快。**大量简单低效的计算资源的堆砌 + 高并行度 = 快。** 50 | 51 | CPU 与 GPU 的区别就好像上文中的一位陈景润先生与 100 名大一学生的区别。CPU 把大量的晶体管耗费在了分支预测、乱序执行等控制单元上,分配给运算单元(ALU 等)的晶体管较少,其目标是尽量提高串行程序的运行速度;而 GPU 则是使用大量的晶体管来堆砌大量的运算单元,同时让许多(不严谨地说,是 32 个)运算单元共享同一个控制单元,以节约晶体管。 52 | 53 | ![CPU 与 GPU 的结构区别](/images/hpc/cpu-gpu-arch-diff.png) 54 | 55 | *Image source: https://developer.nvidia.com/blog/cuda-refresher-reviewing-the-origins-of-gpu-computing/* 56 | 57 | > Aside | Intel Xeon Phi: CPU 是”少量大核心“,适合执行串行任务;GPU 是”大量小核心“,适合执行逻辑简单、并行度大的任务;那么”中量中核心“会有怎样的效果呢?感兴趣的同学可以搜索一下 Intel Xeon Phi。 58 | {.is-info} 59 | 60 | 也就是说,CPU 的执行是单指令单数据(Single Instruction Single Data, SISD)的,其每一条指令仅处理一个数据;而 GPU 的执行则是单指令多数据(Single Instruction Multiple Data, SIMD)的,其一条指令会被同时作用于多个数据。 61 | 62 | > Aside | CPU SIMD:CPU 也有 SIMD 指令集(如 AVX512),可以实现单指令处理多数据,但其并行度还是远远落后于 GPU。 63 | {.is-info} 64 | 65 | 以向量加法为例,它逻辑很简单(只需要把两个向量的对应位置加起来即可),且并行度极高(可以同时计算输出向量每个位置上的结果)。如果使用 CPU,那么我需要依次计算输出向量每个位置的结果;但如果使用 GPU,我可以同时计算输出向量每个位置的结果,进而大大提高了速度。 66 | 67 | 这里有一个形象的解释 CPU 与 GPU 工作原理区别的视频:[Link](https://www.bilibili.com/video/BV1ry4y1y7KZ)。 68 | 69 | ## 为什么 GPU 会这样设计 70 | 71 | 可是... GPU,Graphic Processing Unit,“显卡”,原本不是用来处理图像的吗?为什么它会这么设计? 72 | 73 | 我们先来考虑 3D 游戏画面的渲染管线。我们可以将这个管线分为三个部分:Vertex Mapping, Fragment Generation 与 Fragment Merging。 74 | 75 | 在很久之前,每一家 GPU(当时还叫“图形加速卡”)的厂商的做法都是:让一部分电路专门负责 Vertex Mapping,一部分电路专门负责 Fragment Generation,一部分电路专门负责 Fragment Merging。但是,这样做有一个问题:每一款游戏对不同的处理步骤的负载是不同的,比如游戏 A 可能给 Vertex Mapping 单元的负载较高,导致其成为瓶颈,同时其他两部分电路有空闲;游戏 B 可能给 Fragment Generation 的负载较高,导致其成为瓶颈。 76 | 77 | 在 2006 年,Nvidia 推翻了传统而设计,发布了一个革命性的 GPU 架构 - Tesla。在 Tesla 架构中,没有了专门负责处理某一个步骤的硬件单元,取而代之的则是 Stream Multiprocessor (SM) 。每个 SM 都像一个小型 CPU 一样,可以执行其支持的指令集中的任何程序。这也就代表着,每一个 SM 都有能力执行渲染管线中的每个部分。这种设计避免了某一个步骤成为瓶颈而其他步骤的运算单元闲置的情况:我只需要根据不同的游戏负载,为每个步骤分配一定数量的 SM 即可。 78 | 79 | 同时,Nvidia 发现,Tesla 架构不仅可以执行图像渲染方面的计算,其在通用计算方面也很有潜力(毕竟,SM 可以执行任何指令集支持的指令)。所以,Nvidia 也在同时发布了一套可以在 GPU 上执行通用计算的工具链 —— Compute Unified Device Architecture (CUDA)。从此,GPU 成为了 General-Purpose Graphics Processing Unit (GPGPU)。 80 | 81 | > Aside: 想要深入了解 Nvidia 的 GPU 发展史的同学可以看[这篇文章](https://fabiensanglard.net/cuda/)与 Bilibili UP 主“极客湾”的“显卡发展史系列”:[1](https://www.bilibili.com/video/BV1Hb41177JB) [2](https://www.bilibili.com/video/BV1C4411J7cR) [3](https://www.bilibili.com/video/BV1YJ411h7aY)。 82 | {.is-info} 83 | 84 | *注:我(intlsy)并不是很懂图像渲染。如果有同学发现这部分内容有误,请联系 shengyu.liu@stu.pku.edu.cn。* 85 | 86 | ## Takeaway 87 | 88 | - CPU 的每个核心都很“大”,但核心数较少;GPU 每个核心都很“小”,但是核心数非常多。 89 | - CPU 和 GPU 的结构决定了:CPU 适合执行串行程序,GPU 适合执行并行度极高的程序。 90 | - GPU 之所以这样设计,与 GPU 的发展历史息息相关 91 | 92 | # Part 1. Correctness 93 | 94 | 那么,掌握了 GPU 的基础知识,我们接下来就要开始写代码啦! 95 | 96 | 接下来我将以矩阵乘法为例,带领大家逐渐编写一个正确性无误且性能勉强过关的矩阵乘法。 97 | 98 | > Aside | cuBLAS:作为一个十分常用的操作,市面上已经有很多成熟高效的矩阵乘法库,比如实现了线性代数中大部分计算的 [Nvidia cuBLAS](https://developer.nvidia.com/cublas)。 99 | {.is-info} 100 | 101 | > Aside | Tensor Core:因为矩阵乘法这个操作太普遍了,所以除了我们即将使用的 CUDA Core 外,Nvidia 还在显卡中设计了另一种计算单元 - [Tensor Core](https://www.nvidia.com/en-us/data-center/tensor-cores)。它可以以比 CUDA Core 高数十倍的速率计算矩阵乘法,且支持 AI 中常用的混合精度计算。 102 | {.is-info} 103 | 104 | 首先,请确认你正在使用的机器上面有 Nvidia 的显卡,且你能调用 CUDA 编译器(一般来说,直接在命令行中输入 `nvcc` 即可)。 105 | 106 | > Aside | CUDA C++:CUDA 使用的是经过 Nvidia 魔改的 C++,其包含一些原生 C++ 不支持的语法,故不能使用 `g++` 等常规编译器。 107 | {.is-info} 108 | 109 | ## Step 0. 部署示例代码 110 | 111 | 首先,我们先部署好的示例代码。示例代码中包含一个用来对照的 CPU 上的 GEMM 实现、用来做性能测试的代码、以及 GEMM 在 GPU 上的若干种实现。 112 | 113 | *注:GEMM 代表 GEneral Matrix Multiplication,矩阵乘法。* 114 | 115 | 先 clone [这个仓库](https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code)。随后使用 `make all && ./gemm_test 64 64 64` 来运行它。理想情况下它应该输出类似于这样的东西: 116 | 117 | ```plain 118 | Benchmarking gpu_mult_block... 119 | Warming up... 120 | Verification passed! 121 | Warming up (again)... 122 | Round 0: 12 us 123 | Round 1: 12 us 124 | Round 2: 12 us 125 | Round 3: 12 us 126 | Round 4: 12 us 127 | Round 5: 12 us 128 | Round 6: 12 us 129 | Round 7: 12 us 130 | Average time usage: 12.000000 us 131 | ``` 132 | 133 | 这个仓库中,`gemm_test.cc` 是主程序,其包含 `main` 函数以及与性能测试(benchmark)相关的逻辑。其他的 `gemm_XXX` 中包含了各种各样的 GEMM 的实现,比如 CPU 上的简单实现 `gemm_cpu_naive`、GPU 上的多 thread 单 block 实现 `gemm_gpu_mult_thread` 等。 134 | 135 | 你可以在编译后使用 `./gemm_test [implementation]` 来 benchmark 所有的或特定的 GEMM 实现。其中,`n`, `m`, `k` 分别代表矩阵的三个维度(假设我们要计算 $C = A \times B$,那么 $A$ 矩阵的大小为 $n \times k$,$B$ 矩阵的大小为 $k \times m$,$C$ 矩阵的大小为 $n \times m$),`implementation` 代表你要 benchmark 的 GEMM 实现的名字(留空以 benchmark 所有的 GEMM 实现)。你可以在 `gemm_test.cc` 的开头位置找到所有的 GEMM 实现及其名字。 136 | 137 | *注:如果你每次开始运行程序的时候,程序都要卡 $1 \sim 2$ 秒才有输出,那么可能是因为你没有开启 GPU 的 Persistent Mode。* 138 | 139 | ## Step 1. Your First CUDA Kernel 140 | 141 | 首先,我们需要掌握一个基本概念,CUDA Kernel: 142 | 143 | CUDA 的设计思想大致是:向显卡提交一个又一个任务,每一个任务都形如“给定一个函数,与调用它的参数,请在显卡上运行这个函数”。我们一般**称这种“在显卡上运行的函数”叫做 CUDA Kernel**。仔细想想,这种设计很合理嘛!毕竟现在 GPU 是“加速器”,其仅负责加速程序中的某一些部分,其他的控制流程与计算还是要由 CPU 来做的。 144 | 145 | 所以,现在的问题就是: 146 | 147 | - 如何定义(创建)一个 CUDA Kernel? 148 | - 如何调用这个 CUDA Kernel? 149 | 150 | 首先是如何定义 CUDA Kernel 的问题。CUDA C++ 中有三类函数: 151 | 152 | - `__host__`: 这类函数与正常的函数没有区别。其只能被 host 上执行的函数(`__host__`)调用,并在 host 上执行。 153 | - `__global__`: 这类函数可以被任何函数调用,并在 device 上执行。 154 | - `__device__`: 这类函数只能被 device 上执行的函数(`__device__` 或 `__global__`)调用,并在 device 上执行。 155 | 156 | 不难发现,CUDA Kernel 不就正属于 `__global__` 类嘛! 157 | 158 | 在 CUDA 中,我们可以把一个函数的类别放在函数的返回值类型的前面,以告知编译器这个函数属于哪一类。没有定义属于哪一类的函数默认为 `__host__`。如下例: 159 | 160 | ```cpp 161 | // 下面这句话定义了一个 __global__ 类型的函数。该函数将在 GPU 上运行,并可以被任意函数调用。 162 | __global__ void gemm_gpu_1thread_kernel(int* C, const int* A, const int* B, int n, int m, int k) { 163 | } 164 | // 下面这句话定义了一个 __device__ 类型的函数。 165 | __device__ int mult_helper(int a, int b) { 166 | } 167 | // 下面这句话定义了一个 __host__ 类型的函数。 168 | __host__ void prepare_input() { 169 | } 170 | // 不加任何修饰的函数默认为 __host__ 类型。 171 | void func() { 172 | } 173 | ``` 174 | 175 | *注:在 CUDA 的编程模型中,一般称 CPU 为 Host,GPU 为 Device。* 176 | 177 | 因此,我们只要在一个函数的定义的最前面加上 `__global__`,它就是一个 Kernel 啦! 178 | 179 | 那么,如何调用 CUDA Kernel 呢?与 C++ 中调用函数的方式大同小异,不过要在函数名与参数列表的中间加上一个 `<<>>`(现阶段,先认为 `GRID_DIM` 与 `BLOCK_DIM` 均为 1)。举个例子: 180 | 181 | ```cpp 182 | // 下面这句话调用了一个名为 gemm_gpu_1thread_kernel 的 kernel 183 | gemm_gpu_1thread_kernel<<<1, 1>>>(C, A, B, n, m, k); 184 | ``` 185 | 186 | 现在,请打开示例代码中的 [gemm_gpu_1thread.cu](https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code/blob/master/gemm_gpu_1thread.cu),并阅读这份代码。试着找找:这份代码定义了哪个 CUDA Kernel?哪一行代码调用了这个 CUDA Kernel? 187 | 188 | 那么你现在信心满满地写了一个 CUDA Kernel!它的功能是传入两个数组 `A` 和 `B`,将 `A + B`(点对点地加)的结果输出到数组 `C` 中。你写道: 189 | 190 | ```cpp 191 | #include 192 | #include 193 | #include 194 | 195 | __global__ void pointwise_add_kernel(int* C, const int* A, const int* B, int n) { 196 | for (int i = 0; i < n; ++i) 197 | C[i] = A[i] + B[i]; 198 | } 199 | 200 | int main() { 201 | const int n = 128; 202 | int* C = new int[n]; 203 | int* A = new int[n]; 204 | int* B = new int[n]; 205 | for (int i = 0; i < n; ++i) { 206 | A[i] = i; 207 | B[i] = i*i; 208 | } 209 | pointwise_add_kernel<<<1, 1>>>(C, A, B, n); 210 | cudaDeviceSynchronize(); // 见下方 Aside 211 | cudaError_t error = cudaGetLastError(); // 检查当前 CUDA 驱动是否返回了任何异常。调用这句话之前记得调用 cudaDeviceSynchronize() 212 | if (error != cudaSuccess) { 213 | printf("CUDA error: %s\n", cudaGetErrorString(error)); 214 | exit(1); 215 | } 216 | for (int i = 0; i < n; ++i) { 217 | assert(C[i] == A[i] + B[i]); 218 | } 219 | return 0; 220 | } 221 | ``` 222 | 223 | *注:记得将文件保存为 `.cu` 类型的,否则 NVCC 不会认为这是一份 CUDA C++ 代码从而导致编译错误,同时记得使用 nvcc 而不是 g++ 来编译。* 224 | 225 | > Aside | 异步执行:CUDA Kernel 是异步执行的,也就是说,所谓的“调用” Kernel 只不过是 CPU 向 GPU 的任务队列里提交了一个任务,随后 CPU 就会继续执行接下来的指令,并不会等待 GPU 将这个 Kernel 执行完。这样设计的目的是:可以同时让 CPU 与 GPU 有活干,同时发掘出二者的潜力。如果想让 CPU 等待 GPU 上的所有 Kernel 均执行完(即,让两个设备同步),请调用 `cudaDeviceSynchronize()`。 226 | {.is-info} 227 | 228 | 可惜的是,程序输出了:`CUDA error: an illegal memory access was encountered`。这是为什么呢?请见下一章:内存管理。 229 | 230 | ## Step 2. Memory Management 231 | 232 | 要想理解为什么上面的程序无法执行,我们需要先学习一下 CUDA 的内存模型。 233 | 234 | 在 CUDA 中,每一个设备都只能访问自己的那一块内存。可以理解为(这个理解并不严谨):整个系统的“内存空间”被分为了两个部分:“内存”与“显存”。CPU 只能访问内存而不能访问显存,GPU 只能访问显存而不能访问内存。上面的例子中,我们就让 GPU 试图访问处于 CPU 内存上的数组 `A`, `B`, `C` 从而导致了 `Illegal memory access` 错误。 235 | 236 | 那么怎么办呢?我们需要先认识几个函数: 237 | 238 | - `cudaMalloc()`: 在显存上申请一块存储空间,类似于 `malloc()`。 239 | - `cudaFree()`:释放一块之前使用 `cudaMalloc()` 申请的存储空间,类似于 `free()`。 240 | - `cudaMemcpy()`:在内存与显存之间拷贝数据,类似于 `memcpy()`。 241 | 242 | 这里有一个更加形象的理解方式:假设我们有两个仓库:A 和 B,分别代表 CPU 内存与 GPU 显存。CPU 只能访问 A 仓库中的数据,GPU 只能访问 B 仓库中的数据。常规的 `malloc()`、`free()` 与 `memcpy()` 都只会影响到 A 仓库, `cudaMalloc()`、`cudaFree()` 的操作对象则是 B 仓库,而 `cudaMemcpy()` 则是在 A 仓库与 B 仓库之间迁移数据。 243 | 244 | > Aside | 对显存进行操作的函数:还有很多函数也可以对显存进行操作,比如 `cudaMemset()`, `cudaMalloc2D()` 等。 245 | {.is-info} 246 | 247 | 所以,我们现在的思路就很清晰了: 248 | 249 | ```cpp 250 | #include 251 | #include 252 | #include 253 | 254 | __global__ void pointwise_add_kernel(int* C, const int* A, const int* B, int n) { 255 | for (int i = 0; i < n; ++i) 256 | C[i] = A[i] + B[i]; 257 | } 258 | 259 | int main() { 260 | const int n = 128; 261 | int* C = new int[n]; 262 | int* A = new int[n]; 263 | int* B = new int[n]; 264 | for (int i = 0; i < n; ++i) { 265 | A[i] = i; 266 | B[i] = i*i; 267 | } 268 | // Create 3 arrays on GPU 269 | int* A_gpu, *B_gpu, *C_gpu; 270 | cudaMalloc(&A_gpu, n * sizeof(int)); 271 | cudaMalloc(&B_gpu, n * sizeof(int)); 272 | cudaMalloc(&C_gpu, n * sizeof(int)); 273 | // Copy the content of A and B to A_gpu and B_gpu, respectively 274 | cudaMemcpy(A_gpu, A, n * sizeof(int), cudaMemcpyHostToDevice); 275 | cudaMemcpy(B_gpu, B, n * sizeof(int), cudaMemcpyHostToDevice); 276 | pointwise_add_kernel<<<1, 1>>>(C_gpu, A_gpu, B_gpu, n); 277 | cudaDeviceSynchronize(); // 见下方 Aside 278 | cudaError_t error = cudaGetLastError(); // 检查当前 CUDA 驱动是否返回了任何异常。调用这句话之前记得调用 cudaDeviceSynchronize() 279 | if (error != cudaSuccess) { 280 | printf("CUDA error: %s\n", cudaGetErrorString(error)); 281 | exit(1); 282 | } 283 | // Copy the result from C_gpu to C 284 | cudaMemcpy(C, C_gpu, n * sizeof(int), cudaMemcpyDeviceToHost); 285 | for (int i = 0; i < n; ++i) { 286 | assert(C[i] == A[i] + B[i]); 287 | } 288 | return 0; 289 | } 290 | ``` 291 | 292 | 这份代码可以正常运行的 293 | 294 | > Aside | CUDA Unified Memory: 有一种东西叫做 Unified Memory,其借助类似操作系统中的 Page Fault 的方式实现了在 CPU 与 GPU 之间无感地共享同一块内存(可以理解为,我有一个指针 `p`,CPU 与 GPU 均能访问 `*p`)。感兴趣的同学可以查看[这篇教程](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/)。 295 | {.is-info} 296 | 297 | ## Step 3. Threads and Blocks 298 | 299 | 如果你运行一下 `./gemm_test 1024 1024 1024`,你会发现,我们刚刚的 [gemm_gpu_1thread.cu](https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code/blob/master/gemm_gpu_1thread.cu) 怎么比 CPU 的版本慢了好几个数量级!说好的“巨大幅度性能提升”呢? 300 | 301 | 而且,之前不是说,GPU 是“大量小核心”么?这也没体现出来呀!我就是写了一个串行版本的函数,怎么就能在“大量小核心“上面,顶多使用一个小核心吧!那性能肯定不行呀! 302 | 303 | 那么,如何用好 GPU 内部的大量小核心呢?这就要涉及到 GPU 内部的三个概念了:Thread, Block 以及 Grid: 304 | 305 | - Thread 是最基本的执行单位,**每一个 Thread 都会把你写的 CUDA Kernel 从头到尾完整地执行一遍**。 306 | - 每一个 Block 中包含若干个 Thread,每一个 Thread 都会有一个 `threadIdx`,代表这个 Thread 在它所在的 Block 中的 id。可以使用 `blockDim` 来获取 Block 中有多少个 Thread。 307 | - 每一个 Grid 包含若干个 Block,每一个 Thread 也有一个 `blockIdx`,代表这个 Thread 所在的 Block 在 Grid 中的 id。可以使用 `gridDim` 来获取 Grid 中有多少个 Block。每一次启动 CUDA Kernel 时都会生成一个 Grid(某种意义上可以理解为一个“执行上下文”。 308 | 309 | 三者的关系看上去大概就是这样的: 310 | 311 | ![kernel-execution-on-gpu-1.png](/images/hpc/kernel-execution-on-gpu-1.png) 312 | 313 | *Image source: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/* 314 | 315 | 在启动 CUDA Kernel 时,`<<<>>>` 中的第一个数字是每一个 Grid 中的 Block 数量,第二个数字是每一个 Block 中的 Thread 数量(每一个 Block 中含有的 Thread 数量是相等的)。比如: 316 | 317 | ```cpp 318 | #include 319 | #include 320 | 321 | __global__ void print_grid_block_info_kernel() { 322 | printf("Block id: %d. Number of blocks in one grid: %d. " 323 | "Thread id: %d. Number of threads in one block: %d\n", 324 | blockIdx.x, gridDim.x, threadIdx.x, blockDim.x); 325 | } 326 | 327 | int main() { 328 | const int GRID_SIZE = 4; 329 | const int BLOCK_SIZE = 3; 330 | print_grid_block_info_kernel<<>>(); 331 | cudaDeviceSynchronize(); 332 | cudaError_t error = cudaGetLastError(); // 检查当前 CUDA 驱动是否返回了任何异常。调用这句话之前记得调用 cudaDeviceSynchronize() 333 | if (error != cudaSuccess) { 334 | printf("CUDA error: %s\n", cudaGetErrorString(error)); 335 | exit(1); 336 | } 337 | return 0; 338 | } 339 | ``` 340 | 341 | 由于每一个 CUDA Kernel 都会被每个 Thread 完整地执行一遍,所以我们称 CUDA 是 SPMD (Single Program Multiple Data) Style 的。 342 | 343 | 让我们再完善一下我们的向量点对点加法程序: 344 | 345 | ```cpp 346 | #include 347 | #include 348 | #include 349 | 350 | __global__ void pointwise_add_kernel(int* C, const int* A, const int* B, int n) { 351 | // 别忘了,每一个 Thread 把整个 CUDA Kernel 都完整地执行一遍 352 | for (int i = blockIdx.x*blockDim.x+threadIdx.x; i < n; i += blockDim.x*gridDim.x) 353 | C[i] = A[i] + B[i]; 354 | } 355 | 356 | int main() { 357 | const int n = 128; 358 | const int BLOCK_DIM = 4; 359 | const int GRID_DIM = 3; 360 | int* C = new int[n]; 361 | int* A = new int[n]; 362 | int* B = new int[n]; 363 | for (int i = 0; i < n; ++i) { 364 | A[i] = i; 365 | B[i] = i*i; 366 | } 367 | // Create 3 arrays on GPU 368 | int* A_gpu, *B_gpu, *C_gpu; 369 | cudaMalloc(&A_gpu, n * sizeof(int)); 370 | cudaMalloc(&B_gpu, n * sizeof(int)); 371 | cudaMalloc(&C_gpu, n * sizeof(int)); 372 | // Copy the content of A and B to A_gpu and B_gpu, respectively 373 | cudaMemcpy(A_gpu, A, n * sizeof(int), cudaMemcpyHostToDevice); 374 | cudaMemcpy(B_gpu, B, n * sizeof(int), cudaMemcpyHostToDevice); 375 | pointwise_add_kernel<<>>(C_gpu, A_gpu, B_gpu, n); 376 | cudaDeviceSynchronize(); // 见下方 Aside 377 | cudaError_t error = cudaGetLastError(); // 检查当前 CUDA 驱动是否返回了任何异常。调用这句话之前记得调用 cudaDeviceSynchronize() 378 | if (error != cudaSuccess) { 379 | printf("CUDA error: %s\n", cudaGetErrorString(error)); 380 | exit(1); 381 | } 382 | // Copy the result from C_gpu to C 383 | cudaMemcpy(C, C_gpu, n * sizeof(int), cudaMemcpyDeviceToHost); 384 | for (int i = 0; i < n; ++i) { 385 | assert(C[i] == A[i] + B[i]); 386 | } 387 | return 0; 388 | } 389 | ``` 390 | 391 | 接下来,请阅读并理解 [gemm_gpu_mult_thread.cu](https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code/blob/master/gemm_gpu_mult_thread.cu) 与 [gemm_gpu_mult_block.cu](https://github.com/interestingLSY/CUDA-From-Correctness-To-Performance-Code/blob/master/gemm_gpu_mult_block.cu)。请暂时忽略代码中的 `__restrict__`。 392 | 393 | 在 Ryzen 7700X CPU + RTX 4090 GPU 上,各个 GEMM 的时限的耗时如下: 394 | 395 | ```plain 396 | Results: 397 | cpu_naive 348434.62 us 398 | cpu_simd 133023.88 us 399 | gpu_1thread 20276114.00 us 400 | gpu_mult_thread 410566.25 us 401 | gpu_mult_block 3988.38 us 402 | ``` 403 | 404 | 可以看到,在加入了多 Thread 与多 Block 之后,性能取得了明显的提升。 405 | 406 | > Aside | 高维 Grid & Block:在上面的例子中,Grid 与 Block 都是一维的,但其实 CUDA 支持高达三维的 Grid & Block 形状。 407 | {.is-info} 408 | 409 | ## Takeaway 410 | 411 | 那么,你现在已经可以写出性能基本过关的 CUDA 程序啦!我们再来回顾一下基本知识: 412 | 413 | - CUDA 的设计思想大致是:向显卡提交一个又一个任务,每一个任务都形如“给定一个函数,与调用它的参数,请在显卡上运行这个函数”。我们一般称这种“在显卡上运行的函数”为 CUDA Kernel。 414 | - 可以使用 `__global__`, `__host__` 和 `__device__` 来修饰函数。如果想写一个 CUDA Kernel(能被 CPU 上运行的其他函数,并在 GPU 上执行的函数),那么应当使用 `__global__`。 415 | - 启动 CUDA Kernel 时,请在函数名和参数列表之间加上 `<<<每个 Grid 中有多少 Block, 每个 Block 中有多少 Thread>>>`。 416 | - 启动 CUDA Kernel 的时候会创建一个 Grid。这个 Grid 里包含若干 Block,每个 Block 里包含若干 Thread。 417 | 418 | # Part 2. Performance 419 | 420 | 在上一章中,我们学习了如何写出正确的 CUDA Kernel,那么现在我们来学学如何利用好 GPU 的底层架构,优化 CUDA Kernel 的性能。 421 | 422 | ## 0. 算存比 423 | 424 | 想要优化 GPU 的性能,我们首先要知道“算存比”的概念。 425 | 426 | 在经典的冯诺依曼架构下,ALU (Arithmetic Logic Unit,计算逻辑单元,可以简单理解为加法器、乘法器等) 要从内存中取操作数,进行对应的计算(如乘法),并写回内存。所以,计算速度会受到两个因素的限制:ALU 进行计算的速度,与内存的存取速度。如果一个程序的运行速度瓶颈在于前者,那么称其为 Compute-bound 的;如果瓶颈在于后者,那么称其为 Memory-bound 的。 427 | 428 | 由于 CPU 中运算单元较少,且 CPU 具有多级缓存,所以空间连续性、时间连续性较好的程序在 CPU 上一般是 Compute-bound 的。而 GPU 则恰恰相反:GPU 的核心的规模一般很大,比如 RTX 4090 可以在一秒内做 82.58T 次 float16 运算(暂不考虑 Tensor core),但其内存带宽只有 1TB/s,每秒只能传输 0.5T 个 float16。这便导致 GPU 上的操作更可能会受到内存带宽的限制,成为 Memory-bound。 429 | 430 | 如何估测一个 CUDA Kernel 是 Compute-bound 还是 Memory-bound 呢?我们可以计算它的“算存比”,也即,$计算次数/访存次数$,并将其与 GPU 的 $每秒能做的运算次数/每秒能做的访存次数$ 做比较(这里其实不太严谨,仅能用来做粗略估计,严谨的计算还要考虑到 FMA、缓存、显存带宽利用率等等因素)。 431 | 432 | 比如,对于上面的 `pointwise_add_kernel`,其需要访问 $3N$ 次内存,同时做 $N$ 次加法,所以其存算比为 $N/3N = 1/3$,其远小于 $82.58T/0.5T = 165.16$,所以其为 Memory-bound。 433 | 434 | 我们的优化思路大体是:如果一个 Kernel 是 Memory-bound 的,那么就优化它的访存次数(哪怕这样可能需要多进行一些计算),反之则要减少其计算次数。一般来说,Compute-bound 的 Kernel 不太常见(毕竟算存比得过百才能达到 Compute-bound)(常见的 Compute-bound 的 Kernel 可能只有矩阵乘法与卷积核比较大的卷积),所以下面我们主要关注如何优化访存。 435 | 436 | > Aside | Fused Multiply-Add (FMA):现在的 Nvidia GPU 可以在 1 个时钟周期内计算 `a*b+c`, (`a`, `b` 和 `c` 均为浮点数),而不是先花一个周期计算加法,再花一个周期计算乘法。这称为 Fused multiply-add (FMA)。 437 | 438 | ## 1. `__restrict__` 439 | 440 | 大家还记得什么是 Pointer aliasing 嘛?简单来说,下面两段代码并不是等价的: 441 | 442 | ```cpp 443 | void f1(int* x, int* y) { 444 | *x += *y; 445 | *x += *y; 446 | } 447 | ``` 448 | 449 | ```cpp 450 | void f2(int* x, int* y) { 451 | *x += 2*(*y); 452 | } 453 | ``` 454 | 455 | 这是因为,`x` 和 `y` 两个指针可能指向相同的内存。考虑 `f(x, x)`,第一段代码将把 `*x` 变为 `4(*x)`,而第二段代码则会把 `*x` 变为 `3(*x)`。 456 | 457 | Pointer aliasing 可能会抑制编译器做出某些优化。比如在上面的代码中,`f1()` 需要 5 次访存而 `f2()` 仅需三次,后者更优。但由于编译器并不能假设 `x` 和 `y` ,它不敢做这个优化。 458 | 459 | 所以,我们需要“显式地”告诉编译器,两个指针不会指向相同的内存地址(准确来说,应该是“改变一个指针指向的地址的数据,不会影响到通过其他指针读取的数据”),从而让编译器“放心地”做出优化。`nvcc` 支持一个关键字,叫做 `__restrict__`,加上它,编译器就可以放心地把指针指向的值存在寄存器里,而不是一次又一次地访存,进而提高了性能。 460 | 461 | 我们可以对比一下示例代码中的 `gemm_gpu_mult_block_no_restrict.cu` 与 `gemm_gpu_mult_block.cu` 的性能。在 4090 上,前者平均耗时 40420.75,后者平均耗时 3988.38。可以看出,性能提升幅度不容小觑。 462 | 463 | 为了验证性能下降确实是由于没有了 `__restrict__` 关键字后的额外访存带来的,我们可以对比 `gemm_gpu_mult_block.cu` 与 `gemm_gpu_mult_block_no_restrict_reg.cu` 的性能。后者虽然没有使用 `__restrict__` 关键字,但它把中间的累加结果存在了变量中,而不是每一次都写回 C 数组。在 4090 上,二者的性能非常相似。这说明,在缺少 `__restrict__` 关键字的时候,代码需要进行许多不必要的访存,进而拖慢了速度。 464 | 465 | ## 2. Memory Coalescing 466 | 467 | 在学习 Memory coalescing 之前,我们需要先了解一下 GPU 内部的调度方式。 468 | 469 | 之前我们说过,Grid 里包含若干 Thread block,每个 Thread block 则又包含若干 Thread,那么这些 Thread 是如何被调度的呢?它们被按编号分成了若干组,每一组中有 32 个 Thread(即,线程 0 ~ 31 为第一组,32 ~ 63 为第二组,依次类推),这样的“组”便被叫做 Warp。 470 | 471 | GPU 的调度是以 Warp 作为基本单位的。每个时钟周期内,同一个 Warp 中的所有线程都会执行相同的指令。 472 | 473 | 那么访存呢?难道 Warp 中的 32 个 Thread 同时访存的话,GPU 核心会向显存发起 32 次请求嘛?显然不会。GPU 会把这些请求打包成尽可能少的 Transaction(可以把每一个 Transaction 理解为 GPU 核心向显存发起的一次访存操作),这个过程就叫做 Memory coalescing。Transaction 需要满足: 474 | 475 | - 长度为 32 个 Byte 476 | - 开始地址是 32 的倍数 477 | 478 | 也即:如果一个 Warp 中的第 i 个 Thread 要访问地址为 $4i \sim 4i+3$ 的内存,那么一共需要 4 个 Transaction 才能读完所有的数据;如果一个 Warp 中的第 i 个 Thread 要访问地址为 $4i+1 \sim 4i+4$ 的内存,那么需要 5 个 Transaction 才能读取所有的数据;如果第 i 个 Warp 要访问地址为 $32i \sim 32i+3$ 的内存,那么就需要 32 次 Transaction 才能完成读取了。 479 | 480 | 然而,内存带宽是有上限的,且每一个 Transaction 的大小都是 32 Byte,这注定了每一秒 GPU 核心可以发起的 Transaction 数量是有上限的。对于上述的最后一种情况,由于每一个 Transaction 中的 32 Byte 只有 4 Byte 是有用的,此时内存带宽的利用率仅有 $1/8$。 481 | 482 | 接下来请阅读 [CUDA Best Practices](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#coalesced-access-to-global-memory),了解 Memory coalescing 在一个具体的例子中的优化效果。我实在是写不动了,妈的事情太多了压根干不完,救命。 483 | 484 | 总之,我们需要尽量保证同一个 Warp 中每一个 Thread 的访存是 coalesced 的,以充分利用内存带宽。 485 | 486 | ## 3. Shared Memory 487 | 488 | 在学习 Shared memory 之前,我们需要先了解一下 CUDA 的内存模型: 489 | 490 | CUDA 中大致有这几种内存: 491 | 492 | - Global Memory:俗称显存,位于 GPU 核心外部,很大(比如 A100 有 80GB),但是带宽很有限 493 | - L2 Cache:位于 GPU 核心内部,是显存的缓存,程序不能直接使用 494 | - Register:寄存器,位于 GPU 核心内部,Thread 可以直接调用 495 | - Shared memory:位于 GPU 核心内部,每个 Thread block 中的所有 Thread 共用同一块 Shared memory(因此,Shared memory 可以用来在同一个 Thread block 的不同 Thread 之间共享数据),并且带宽极高(因此,Shared memory 可以用来优化性能)。 496 | 497 | 正如上文所说,Share memory 既可以用来在同一个 Thread block 的不同 Thread 之间共享数据(最常见的用法是 Reduction),也可以用来优化访存性能。我们现在主要关注后者。 498 | 499 | 我们还是以矩阵乘法为例。在上面的 `gemm_gpu_mult_block.cu` 中,为了计算大小分别为 $n \times k$ 与 $k \times m$ 的两个矩阵乘法,我们一共访问了大约 $2nmk$ 次内存。这十分不合算,因为三个矩阵加起来也就只有 $nk + km + nm$ 个元素。 500 | 501 | 我们尝试使用 Shared memory 来优化矩阵乘法。具体的,我们使用一种叫做 Tiling 的技术。接下来请阅读[这篇文章](https://penny-xu.github.io/blog/tiled-matrix-multiplication)(里面有很多好看又形象的动图)。 502 | 503 | 在阅读上面那篇文章之后,请阅读示例代码中的 `gemm_gpu_tiling.cu`,看看我如何实现 Tiling 版本的矩阵乘法。在 4090 上,`gemm_gpu_mult_block` 耗时 3988.38 us,`gemm_gpu_tiling` 耗时 311.38 us,性能提升约 10 倍。 504 | 505 | > Aside | CUDA Memory Hierarchy: ![memory-hierarchy-in-gpus-1.png](/images/hpc/memory-hierarchy-in-gpus-1.png) 506 | > *Image source: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/* 507 | {.is-info} 508 | 509 | > Aside | Reduction: 对 Reduction 操作的优化比较感兴趣的同学可以阅读 [这篇文章](https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf)。 510 | {.is-info} 511 | 512 | ## 4. Profiling Tools 513 | 514 | 在优化 CUDA Kernel 的时候,除了依照经验与惯用套路(比如 Memory coalescing),我们也可以使用专业的 Profiling 工具来测试一个 Kernel 或者一个程序的性能瓶颈。常用的 Profiling 工具包括: 515 | 516 | - Nvidia Nsight System: 它可以对整个应用程序进行 Profile,可以得到各个 Kernel 的耗时,以及究竟是 CPU 还是 GPU 拖慢了整体的执行速度。 517 | - Nvidia Nsight Compute: 它可以对单个 CUDA Kernel 进行 Profiling,进而得到该 CUDA Kernel 的瓶颈所在。它会提供许多的详细信息(比如,内存带宽占用率、CUDA Core 活跃时间比、活跃的 SM 比例等等)来帮助你更加细致地优化 CUDA Kernel。 518 | 519 | ## Takeaway 520 | 521 | 总结一下,我们的优化技巧包括: 522 | 523 | - 使用 `__restrict__` 让编译器放心地优化指针访存 524 | - 想办法让同一个 Warp 中的线程的访存 Pattern 尽可能连续,以利用 Memory coalescing 525 | - 使用 Shared memory 526 | - 使用专业的 Profiling Tool 527 | --------------------------------------------------------------------------------