├── LICENSE ├── README.md ├── matrix_cpu └── matrix_cpu │ └── main.cpp ├── matrix_gpu └── matrix_gpu │ └── main.cu ├── matrix_wmma └── matrix_wmma │ └── main.cu └── project └── project └── kernel.cu /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zong-Sheng Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Matrix Multiply-Accumulate (MMA) on GPU 2 | ### Sample code for undergraduates in the Capstone Project course at Hallym University, autumn semester 2018. 3 | **Purpose:** To implement Matrix Multiply-Accumulate (D = A * B + C) on the CPU and on the GPU (with and without Tensor Cores), and to measure the performance of each.
4 | 5 | **Note** that this repository contains only **less performant** versions of the implementations. They are intended for demonstration purposes only, to show how your project should be structured. 6 | 7 | #### matrix_cpu 8 | Sample code for single-threaded MMA on the CPU. 9 | 10 | #### matrix_gpu 11 | Sample code for MMA on the GPU without Tensor Cores, using the CUDA runtime API. 12 | 13 | #### matrix_wmma 14 | Sample code for MMA on the GPU with Tensor Cores, using the WMMA API. 15 | 16 | #### project 17 | Shows how your project should organize the algorithm implementations, performance metrics, and result verification. 18 | 19 | --- 20 | 21 | ### Tips for compiling `*.cu` files 22 | `$ nvcc -o main main.cu -arch sm_75` (use `-arch sm_70` when targeting Volta) 23 | 24 | **Tensor Cores are only available on GPUs of CUDA compute capability 7.0 and above** 25 | 26 | 7.0 <=> Volta (Titan V / Quadro GV100) 27 | 28 | 7.5 <=> Turing (RTX 2080 / RTX 2080 Ti / Quadro RTX 6000) 29 | 30 | --- 31 | 32 | ### References 33 | - Programming Tensor Cores in CUDA 9 34 | - https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/ 35 | - How to Implement Performance Metrics in CUDA C/C++ 36 | - https://devblogs.nvidia.com/how-implement-performance-metrics-cuda-cc/ 37 | - NVIDIA Turing Architecture Whitepaper 38 | - https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf 39 | - NVIDIA Volta Architecture Whitepaper 40 | - http://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf 41 | - Tensorコアを使ってみた (Japanese: "I tried using Tensor Cores") 42 | - http://proc-cpuinfo.fixstars.com/2018/10/tensorcore/ 43 | - CUTLASS: Fast Linear Algebra in CUDA C++ 44 | - https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ 45 | 46 | -------------------------------------------------------------------------------- /matrix_cpu/matrix_cpu/main.cpp: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////// 2 | // Calculating Matrix
A*B+C (CPU Version) 3 | // Created by Wang Zong-Sheng 4 | // 2018/10/18 5 | 6 | #include 7 | #include 8 | 9 | #define A_ROW 3 10 | #define A_COL 2 11 | #define B_ROW 2 12 | #define B_COL 3 13 | #define C_ROW 2 14 | #define C_COL 2 15 | 16 | using namespace std; 17 | 18 | template 19 | void matrix_add(const T *A, const T *B, unsigned int row, unsigned int col, T *R) { 20 | for (unsigned int c = 0; c < col; c++) { 21 | for (unsigned int r = 0; r < row; r++) { 22 | unsigned int i = c*row + r; 23 | R[i] = A[i] + B[i]; 24 | } 25 | } 26 | } 27 | 28 | template 29 | void matrix_mul(const T *A, unsigned int a_row, unsigned int a_col, const T *B, unsigned int b_row, unsigned int b_col, T *R) { 30 | memset(R, 0, a_col*b_row*sizeof(T)); 31 | for (unsigned int c = 0; c < a_col; c++) { 32 | for (unsigned int r = 0; r < b_row; r++) { 33 | unsigned int index = c * b_row + r; 34 | for (unsigned int i = 0; i < a_row; i++) { 35 | R[index] += A[c*a_row + i] * B[i*b_row + r]; 36 | } 37 | } 38 | } 39 | } 40 | 41 | template 42 | void print_matrix(T *M, unsigned int row, unsigned int col) { 43 | for (unsigned int c = 0; c < col; c++) { 44 | for (unsigned int r = 0; r < row; r++) { 45 | cout << M[c*row + r] << ", "; 46 | } 47 | cout << endl; 48 | } 49 | } 50 | 51 | int main(void) { 52 | //clock_t start_timer = clock(); 53 | const int A[A_ROW*A_COL] = { 1, 0, -3, 54 | -2, 4, 1}; 55 | const int B[B_ROW*B_COL] = { 2, -1, 56 | 3, 0, 57 | -5, 2}; 58 | const int C[C_ROW*C_COL] = { 3, -1, 59 | -2, 2}; 60 | int AB[A_COL*B_ROW], R[C_ROW*C_COL]; 61 | 62 | matrix_mul(A, A_ROW, A_COL, B, B_ROW, B_COL, AB); 63 | matrix_add(AB, C, C_ROW, C_COL, R); 64 | 65 | //clock_t stop_timer = clock(); 66 | //double duration = double(stop_timer - start_timer); 67 | 68 | // for printing results 69 | cout << "A = " << endl; 70 | print_matrix(A, A_ROW, A_COL); 71 | 72 | cout << endl << "B = " << endl; 73 | print_matrix(B, B_ROW, B_COL); 74 | 75 | cout << endl << "C = " << endl; 76 | print_matrix(C, C_ROW, C_COL); 77 | 
78 | cout << endl << "Result:" << endl; 79 | cout << "A x B = " << endl; 80 | print_matrix(AB, B_ROW, A_COL); 81 | 82 | cout << endl << "A x B + C = " << endl; 83 | print_matrix(R, C_ROW, C_COL); 84 | 85 | //cout << "elapsed time = " << duration << "s" << endl; 86 | return 0; 87 | } -------------------------------------------------------------------------------- /matrix_gpu/matrix_gpu/main.cu: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////// 2 | // Calcuating Matrix A*B+C (CUDA Version) 3 | // Created by Wang Zong-Sheng 4 | // 2018/10/18 5 | #include 6 | using namespace std; 7 | 8 | #include "cuda_runtime.h" 9 | #include "device_launch_parameters.h" 10 | 11 | #define A_ROW 3 12 | #define A_COL 2 13 | #define B_ROW 2 14 | #define B_COL 3 15 | #define C_ROW 2 16 | #define C_COL 2 17 | #define MP_NUM 15 18 | #define CORES_PER_MP 192 19 | 20 | template 21 | __global__ void cuda_matrix_mul(const T *A, const T *B, T *R) 22 | { 23 | int bId = blockIdx.y * gridDim.x + blockIdx.x; 24 | T sum = A[blockIdx.y* A_ROW + threadIdx.x] * B[threadIdx.x * B_ROW + blockIdx.x]; 25 | __syncthreads(); 26 | //printf("Thread %d In block %d : R = %d, sum = %d\n", threadIdx.x, bId, temp, sum); 27 | atomicAdd(&R[bId], sum); 28 | 29 | } 30 | 31 | //template 32 | //__global__ void cuda_matrix_mul(const T *A, const T *B, T *R) 33 | //{ 34 | // int bId = blockIdx.y * gridDim.x + blockIdx.x; 35 | // int sum = 0; 36 | // for(int i=0; i 43 | __global__ void cuda_matrix_add(const T *A, const T *B, T *R) 44 | { 45 | int ix = blockIdx.x * blockDim.x + threadIdx.x; 46 | int iy = blockIdx.y * blockDim.y + threadIdx.y; 47 | int idx = iy * blockDim.x + ix; 48 | R[idx] = A[idx] + B[idx]; 49 | } 50 | 51 | 52 | // using CUDA to implement AxB+C 53 | template 54 | cudaError_t matrix_mul_add_cuda(const T *A, unsigned int a_row, unsigned int a_col, 55 | const T *B, unsigned int b_row, unsigned int b_col, 56 | const T *C, unsigned int 
c_row, unsigned int c_col, 57 | T *R, T *AB) 58 | { 59 | T *dev_a = 0; 60 | T *dev_b = 0; 61 | T *dev_c = 0; 62 | T *dev_ab = 0; 63 | T *dev_r = 0; 64 | cudaError_t cudaStatus; 65 | 66 | // Choose which GPU to run on, change this on a multi-GPU system. 67 | cudaStatus = cudaSetDevice(0); 68 | if (cudaStatus != cudaSuccess) { 69 | printf("cudaSetDevice failed! Do you have a CUDA-capable GPU installed?"); 70 | goto Error; 71 | } 72 | 73 | // Allocate GPU buffers for matrics 74 | cudaStatus = cudaMalloc((void**)&dev_a, a_row * a_col * sizeof(T)); 75 | if (cudaStatus != cudaSuccess) { 76 | printf("cudaMalloc failed!"); 77 | goto Error; 78 | } 79 | 80 | cudaStatus = cudaMalloc((void**)&dev_b, b_row * b_col * sizeof(T)); 81 | if (cudaStatus != cudaSuccess) { 82 | printf("cudaMalloc failed!"); 83 | goto Error; 84 | } 85 | 86 | cudaStatus = cudaMalloc((void**)&dev_c, c_row * c_col * sizeof(T)); 87 | if (cudaStatus != cudaSuccess) { 88 | printf("cudaMalloc failed!"); 89 | goto Error; 90 | } 91 | 92 | cudaStatus = cudaMalloc((void**)&dev_ab, b_row * a_col * sizeof(T)); 93 | if (cudaStatus != cudaSuccess) { 94 | printf("cudaMalloc failed!"); 95 | goto Error; 96 | } 97 | 98 | cudaStatus = cudaMalloc((void**)&dev_r, c_row * c_col * sizeof(T)); 99 | if (cudaStatus != cudaSuccess) { 100 | printf("cudaMalloc failed!"); 101 | goto Error; 102 | } 103 | 104 | 105 | // Copy input matrics from host memory to GPU buffers. 
106 | cudaStatus = cudaMemcpy(dev_a, A, a_row * a_col * sizeof(T), cudaMemcpyHostToDevice); 107 | if (cudaStatus != cudaSuccess) { 108 | printf("cudaMemcpy failed!"); 109 | goto Error; 110 | } 111 | 112 | cudaStatus = cudaMemcpy(dev_b, B, b_row * b_col * sizeof(T), cudaMemcpyHostToDevice); 113 | if (cudaStatus != cudaSuccess) { 114 | printf("cudaMemcpy failed!"); 115 | goto Error; 116 | } 117 | 118 | cudaStatus = cudaMemcpy(dev_c, C, c_row * c_col * sizeof(T), cudaMemcpyHostToDevice); 119 | if (cudaStatus != cudaSuccess) { 120 | printf("cudaMemcpy failed!"); 121 | goto Error; 122 | } 123 | 124 | // Launch a kernel on the GPU 125 | // In our case, K40c GPU has 15MP, and 192cores per MP 126 | dim3 grids(B_ROW, A_COL); 127 | //cuda_matrix_mul <<>> (dev_a, dev_b, dev_ab); 128 | cuda_matrix_mul <<>> (dev_a, dev_b, dev_ab); 129 | 130 | cudaDeviceSynchronize(); 131 | 132 | dim3 threads(C_ROW, C_COL); 133 | cuda_matrix_add <<<1, threads>>> (dev_ab, dev_c, dev_r); 134 | 135 | 136 | // Check for any errors launching the kernel 137 | cudaStatus = cudaGetLastError(); 138 | if (cudaStatus != cudaSuccess) { 139 | printf("addKernel launch failed: %s\n", cudaGetErrorString(cudaStatus)); 140 | goto Error; 141 | } 142 | 143 | // cudaDeviceSynchronize waits for the kernel to finish, and returns 144 | // any errors encountered during the launch. 145 | cudaStatus = cudaDeviceSynchronize(); 146 | if (cudaStatus != cudaSuccess) { 147 | printf("cudaDeviceSynchronize returned error code %d after launching addKernel!\n", cudaStatus); 148 | goto Error; 149 | } 150 | 151 | // Copy output vector from GPU buffer to host memory. 
152 | cudaStatus = cudaMemcpy(AB, dev_ab, b_row * a_col * sizeof(T), cudaMemcpyDeviceToHost); 153 | if (cudaStatus != cudaSuccess) { 154 | printf("cudaMemcpy failed!"); 155 | goto Error; 156 | } 157 | 158 | cudaStatus = cudaMemcpy(R, dev_r, c_row * c_col * sizeof(T), cudaMemcpyDeviceToHost); 159 | if (cudaStatus != cudaSuccess) { 160 | printf("cudaMemcpy failed!"); 161 | goto Error; 162 | } 163 | 164 | Error: 165 | cudaFree(dev_a); 166 | cudaFree(dev_a); 167 | cudaFree(dev_c); 168 | cudaFree(dev_ab); 169 | cudaFree(dev_r); 170 | 171 | return cudaStatus; 172 | } 173 | 174 | template 175 | void print_matrix(T *M, unsigned int row, unsigned int col) { 176 | for (unsigned int c = 0; c < col; c++) { 177 | for (unsigned int r = 0; r < row; r++) { 178 | cout << M[c*row + r] << ", "; 179 | } 180 | cout << endl; 181 | } 182 | } 183 | 184 | int main() 185 | { 186 | const int A[A_ROW*A_COL] = { 1, 0, -3, 187 | -2, 4, 1 }; 188 | const int B[B_ROW*B_COL] = { 2, -1, 189 | 3, 0, 190 | -5, 2 }; 191 | const int C[C_ROW*C_COL] = { 3, -1, 192 | -2, 2 }; 193 | int AB[A_COL*B_ROW]; 194 | int R[C_ROW*C_COL]; 195 | 196 | cudaError_t cudaStatus; 197 | cudaStatus = matrix_mul_add_cuda(A, A_ROW, A_COL, B, B_ROW, B_COL, C, C_ROW, C_COL, R, AB); 198 | 199 | 200 | cudaStatus = cudaDeviceReset(); 201 | if (cudaStatus != cudaSuccess) { 202 | //cout << "cudaDeviceReset failed!" 
<< endl; 203 | return 1; 204 | } 205 | 206 | // for printing results 207 | cout << "A = " << endl; 208 | print_matrix(A, A_ROW, A_COL); 209 | 210 | cout << endl << "B = " << endl; 211 | print_matrix(B, B_ROW, B_COL); 212 | 213 | cout << endl << "C = " << endl; 214 | print_matrix(C, C_ROW, C_COL); 215 | 216 | cout << endl << "Result:" << endl; 217 | cout << "A x B = " << endl; 218 | print_matrix(AB, B_ROW, A_COL); 219 | 220 | cout << endl << "A x B + C = " << endl; 221 | print_matrix(R, C_ROW, C_COL); 222 | 223 | 224 | return 0; 225 | } -------------------------------------------------------------------------------- /matrix_wmma/matrix_wmma/main.cu: -------------------------------------------------------------------------------- 1 | ////////////////////////////////////////////////////////////////////// 2 | // A simple example to show how CUDA WMMA API works with Tensor Cores 3 | // Created by Zong-Sheng Wang @ 2018/11/25 4 | // Performance Tips: 5 | // To minimize bank conflicts, you should try to shift row or 6 | // column of matrics in shared memory 7 | // cmd: 8 | // $ nvcc -o main main.cu -arch sm_75 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "cuda_runtime.h" 16 | #include "device_launch_parameters.h" 17 | 18 | #define WARP_SIZE 32 19 | 20 | // MMA matrix tile dimensions. 21 | #define M 16 22 | #define N 16 23 | #define K 16 24 | 25 | // GEMM configuration. 
26 | #define M_TILES 256 27 | #define N_TILES 256 28 | #define K_TILES 256 29 | 30 | #define M_TOTAL (M * M_TILES) 31 | #define N_TOTAL (N * N_TILES) 32 | #define K_TOTAL (K * K_TILES) 33 | 34 | 35 | //__global__ void WMMAINT8() 36 | using namespace nvcuda; 37 | 38 | __host__ void InitMatrix(half *A, half *B, float *C) 39 | { 40 | for (int i = 0; i < M_TOTAL*K_TOTAL; i++) 41 | A[i] = __float2half(rand() % 1000 / 1000.0f); 42 | for (int i = 0; i < K_TOTAL*N_TOTAL; i++) 43 | B[i] = __float2half(rand() % 1000 / 1000.0f); 44 | for (int i = 0; i < M_TOTAL*N_TOTAL; i++) 45 | C[i] = rand() % 1000 / 1000.0f; 46 | } 47 | 48 | 49 | 50 | __global__ void WMMAF16TensorCore(half *A, half *B, float *C, float *D) 51 | { 52 | int ix = (blockIdx.x * blockDim.x + threadIdx.x)/WARP_SIZE; 53 | int iy = (blockIdx.y * blockDim.y + threadIdx.y); 54 | 55 | wmma::fragment a_frag; 56 | wmma::fragment b_frag; 57 | wmma::fragment ab_frag; 58 | wmma::fragment c_frag; 59 | 60 | wmma::fill_fragment(ab_frag, 0.0f); 61 | 62 | // AB = A*B 63 | int a_col, a_row, b_col, b_row, c_col, c_row; 64 | a_row = ix * M; 65 | b_row = iy * N; 66 | for (int k=0; k>>(A, B, C, D); 112 | cuda_status = cudaDeviceSynchronize(); 113 | 114 | cudaEventRecord(stop); 115 | cudaEventSynchronize(stop); 116 | 117 | float milliseconds = 0; 118 | cudaEventElapsedTime(&milliseconds, start, stop); 119 | 120 | // for Performance Metrics 121 | printf("[+] GPU(with Tensor Cores) Elapsed Time: %f ms\n", milliseconds); 122 | // references from https://devblogs.nvidia.com/how-implement-performance-metrics-cuda-cc/ 123 | printf("[+] TFLOPS: %.2f\n", ((double)M_TOTAL * N_TOTAL* K_TOTAL * 2) / milliseconds / 1e9); 124 | cudaEventDestroy(start); 125 | cudaEventDestroy(stop); 126 | 127 | return cuda_status; 128 | } 129 | 130 | 131 | int main() 132 | { 133 | cudaError_t cuda_status; 134 | cuda_status = cudaSetDevice(0); 135 | if (cuda_status != cudaSuccess) { 136 | printf("cudaSetDevice failed! 
"); 137 | return 1; 138 | } 139 | 140 | 141 | // Matrix on device 142 | half *A; 143 | half *B; 144 | float *C; 145 | float *D; 146 | 147 | // CUDA Unified Memory 148 | cudaMallocManaged((void **)&A, sizeof(half) * M_TOTAL * K_TOTAL); 149 | cudaMallocManaged((void **)&B, sizeof(half) * K_TOTAL * N_TOTAL); 150 | cudaMallocManaged((void **)&C, sizeof(float) * M_TOTAL * N_TOTAL); 151 | cudaMallocManaged((void **)&D, sizeof(float) * M_TOTAL * N_TOTAL); 152 | 153 | // Init matrix A B C on host 154 | //InitHostMatrix(host_A, host_B, host_C); 155 | printf("[*] Initializing Matrix...\n"); 156 | InitMatrix(A, B, C); 157 | printf("[+] A: %d x %d\n", M_TOTAL, K_TOTAL); 158 | printf("[+] B: %d x %d\n", K_TOTAL, N_TOTAL); 159 | printf("[+] C: %d x %d\n", M_TOTAL, N_TOTAL); 160 | 161 | // computing gemm using tensor core 162 | printf("[*] Computing D = A * B +C with Tensor Cores...\n"); 163 | // D = A * B +C, D holds the result after ret 164 | cuda_status = CalcWMMA(A, B, C, D); 165 | 166 | cuda_status = cudaDeviceReset(); 167 | if (cuda_status != cudaSuccess) { 168 | printf("cudaDeviceReset failed! "); 169 | return 1; 170 | } 171 | // Todo: Add a function to verify the result by using the result of CPU version implementation. 172 | 173 | cudaFree(A); 174 | cudaFree(B); 175 | cudaFree(C); 176 | cudaFree(D); 177 | 178 | return 0; 179 | } 180 | -------------------------------------------------------------------------------- /project/project/kernel.cu: -------------------------------------------------------------------------------- 1 | ////////////////////////////////////////////////////////////////////// 2 | // Sample code to show how your project works 3 | // Created by Zong-Sheng Wang @ 2018/11/25 4 | 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "cuda_runtime.h" 14 | #include "device_launch_parameters.h" 15 | 16 | #define WARP_SIZE 32 17 | 18 | // MMA matrix tile dimensions. 
19 | #define M 16 20 | #define N 16 21 | #define K 16 22 | 23 | // GEMM configuration. 24 | #define M_TILES 256 25 | #define N_TILES 256 26 | #define K_TILES 256 27 | 28 | #define M_TOTAL (M * M_TILES) 29 | #define N_TOTAL (N * N_TILES) 30 | #define K_TOTAL (K * K_TILES) 31 | 32 | 33 | //__global__ void WMMAINT8() 34 | using namespace nvcuda; 35 | 36 | void InitMatrix(float * A, float *B, half *Ah, half *Bh, float *C) 37 | { 38 | for (int i = 0; i < M_TOTAL*K_TOTAL; i++) { 39 | A[i] = rand() % 1000 / 1000.0f; 40 | Ah[i] = __float2half(A[i]); 41 | } 42 | for (int i = 0; i < K_TOTAL*N_TOTAL; i++) { 43 | B[i] = rand() % 1000 / 1000.0f; 44 | Bh[i] = __float2half(B[i]); 45 | } 46 | for (int i = 0; i < M_TOTAL*N_TOTAL; i++) 47 | C[i] = rand() % 1000 / 1000.0f; 48 | } 49 | 50 | 51 | // Tensor core 52 | __global__ void WMMAF16TensorCore(half *A, half *B, float *C, float *D) 53 | { 54 | int ix = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE; 55 | int iy = (blockIdx.y * blockDim.y + threadIdx.y); 56 | 57 | wmma::fragment a_frag; 58 | wmma::fragment b_frag; 59 | wmma::fragment ab_frag; 60 | wmma::fragment c_frag; 61 | 62 | wmma::fill_fragment(ab_frag, 0.0f); 63 | 64 | // AB = A*B 65 | int a_col, a_row, b_col, b_row, c_col, c_row; 66 | a_row = ix * M; 67 | b_row = iy * N; 68 | for (int k = 0; k>>(A, B, C, D); 114 | cuda_status = cudaDeviceSynchronize(); 115 | 116 | cudaEventRecord(stop); 117 | cudaEventSynchronize(stop); 118 | 119 | float milliseconds = 0; 120 | cudaEventElapsedTime(&milliseconds, start, stop); 121 | 122 | // for Performance Metrics 123 | printf("[+] GPU(with Tensor Cores) Elapsed Time: %f ms\n", milliseconds); 124 | // references from https://devblogs.nvidia.com/how-implement-performance-metrics-cuda-cc/ 125 | printf("[+] TFLOPS: %.2f\n", ((double)M_TOTAL * N_TOTAL* K_TOTAL * 2) / milliseconds / 1e9); 126 | cudaEventDestroy(start); 127 | cudaEventDestroy(stop); 128 | 129 | return cuda_status; 130 | } 131 | 132 | // GPU version without Tensor Core 133 | 
__global__ void cuda_matrix_mul(const half *A, const half *B, float *R) 134 | { 135 | int bId = blockIdx.y * gridDim.x + blockIdx.x; 136 | float sum = __half2float(A[blockIdx.y* M_TOTAL + threadIdx.x] * B[threadIdx.x * N_TOTAL + blockIdx.x]); 137 | __syncthreads(); 138 | //printf("Thread %d In block %d : R = %d, sum = %d\n", threadIdx.x, bId, temp, sum); 139 | atomicAdd(&R[bId], sum); 140 | 141 | } 142 | 143 | __global__ void cuda_matrix_add(const float *A, const float *B, float *R) 144 | { 145 | int ix = blockIdx.x * blockDim.x + threadIdx.x; 146 | int iy = blockIdx.y * blockDim.y + threadIdx.y; 147 | int idx = iy * blockDim.x + ix; 148 | R[idx] = A[idx] + B[idx]; 149 | } 150 | 151 | 152 | cudaError_t CalcByCUDA(half *A, half *B, float *C, float *R) 153 | { 154 | cudaError_t cuda_status; 155 | dim3 gridDim, blockDim; 156 | 157 | blockDim.x = 4 * WARP_SIZE; 158 | blockDim.y = 4; 159 | 160 | gridDim.x = (M_TOTAL + (M * blockDim.x / WARP_SIZE - 1)) / (M * blockDim.x / WARP_SIZE); 161 | gridDim.y = (N_TOTAL + N * blockDim.y - 1) / (N * blockDim.y); 162 | 163 | // for Performance Metrics 164 | cudaEvent_t start, stop; 165 | cudaEventCreate(&start); 166 | cudaEventCreate(&stop); 167 | cudaEventRecord(start); 168 | 169 | 170 | cuda_matrix_mul <<> > (A, B, R); 171 | cuda_status = cudaDeviceSynchronize(); 172 | cuda_matrix_add<<> > (R, C, R); 173 | cuda_status = cudaDeviceSynchronize(); 174 | 175 | cudaEventRecord(stop); 176 | cudaEventSynchronize(stop); 177 | 178 | float milliseconds = 0; 179 | cudaEventElapsedTime(&milliseconds, start, stop); 180 | 181 | printf("[+] GPU(without Tensor Cores) Elapsed Time: %f ms\n", milliseconds); 182 | printf("[+] TFLOPS: %.2f\n", ((double)M_TOTAL * N_TOTAL* K_TOTAL * 2) / milliseconds / 1e9); 183 | cudaEventDestroy(start); 184 | cudaEventDestroy(stop); 185 | 186 | return cuda_status; 187 | } 188 | 189 | 190 | 191 | // CPU version 192 | void matrix_add(const float *A, const float *B, unsigned int row, unsigned int col, float *R) { 193 | 
for (unsigned int c = 0; c < col; c++) { 194 | for (unsigned int r = 0; r < row; r++) { 195 | unsigned int i = c*row + r; 196 | R[i] = A[i] + B[i]; 197 | } 198 | } 199 | } 200 | 201 | void matrix_mul(const float *A, unsigned int a_row, unsigned int a_col, const float *B, unsigned int b_row, unsigned int b_col, float *R) { 202 | memset(R, 0, a_col*b_row * sizeof(float)); 203 | for (unsigned int c = 0; c < a_col; c++) { 204 | for (unsigned int r = 0; r < b_row; r++) { 205 | unsigned int index = c * b_row + r; 206 | for (unsigned int i = 0; i < a_row; i++) { 207 | R[index] += A[c*a_row + i] * B[i*b_row + r]; 208 | } 209 | } 210 | } 211 | } 212 | 213 | void CalcByCPU(float *A, float *B, float *C, float *D) 214 | { 215 | matrix_mul(A, M_TOTAL, K_TOTAL, B, K_TOTAL, N_TOTAL, D); 216 | matrix_add(D, C, M_TOTAL, N_TOTAL, D); 217 | } 218 | 219 | 220 | int main() 221 | { 222 | cudaError_t cuda_status; 223 | cuda_status = cudaSetDevice(0); 224 | if (cuda_status != cudaSuccess) { 225 | printf("cudaSetDevice failed! 
"); 226 | return 1; 227 | } 228 | 229 | // float Matrix on host for cpu version 230 | float *hostA = (float*)malloc(sizeof(float) * M_TOTAL * K_TOTAL); 231 | float *hostB = (float*)malloc(sizeof(float) * K_TOTAL * N_TOTAL); 232 | float *hostD = (float*)malloc(sizeof(float) * M_TOTAL * N_TOTAL); 233 | 234 | // Matrix on device 235 | half *A; 236 | half *B; 237 | float *C; 238 | float *D; 239 | float *D2; 240 | 241 | // CUDA Unified Memory 242 | cudaMallocManaged((void **)&A, sizeof(half) * M_TOTAL * K_TOTAL); 243 | cudaMallocManaged((void **)&B, sizeof(half) * K_TOTAL * N_TOTAL); 244 | cudaMallocManaged((void **)&C, sizeof(float) * M_TOTAL * N_TOTAL); 245 | cudaMallocManaged((void **)&D, sizeof(float) * M_TOTAL * N_TOTAL); 246 | cudaMallocManaged((void **)&D2, sizeof(float) * M_TOTAL * N_TOTAL); 247 | 248 | // Init matrix A B C on host 249 | //InitHostMatrix(host_A, host_B, host_C); 250 | printf("[*] Initializing Matrix...\n"); 251 | InitMatrix(hostA, hostB, A, B, C); 252 | printf("[+] A: %d x %d\n", M_TOTAL, K_TOTAL); 253 | printf("[+] B: %d x %d\n", K_TOTAL, N_TOTAL); 254 | printf("[+] C: %d x %d\n", M_TOTAL, N_TOTAL); 255 | 256 | // computing with CUDA 257 | printf("[*] Computing D = A * B + C on GPU without Tensor Cores...\n"); 258 | cuda_status = CalcByCUDA(A, B, C, D2); 259 | 260 | // computing with tensor core 261 | printf("[*] Computing D = A * B + C on GPU with Tensor Cores...\n"); 262 | // D = A * B +C, D holds the result after ret 263 | cuda_status = CalcByWMMA(A, B, C, D); 264 | 265 | 266 | // computing with CPU 267 | printf("[*] Computing D = A * B + C on CPU..."); 268 | int begintime, endtime; 269 | begintime = clock(); 270 | CalcByCPU(hostA, hostB, C, hostD); 271 | endtime = clock(); 272 | printf("OK\n"); 273 | printf("[*] CPU Elapsed Time: %fs\n", (endtime-begintime)/1000.0f); 274 | 275 | // Verification 276 | printf("[*] Verifying result...\n"); 277 | for (int i = 0; i < M_TOTAL * N_TOTAL; i++) { 278 | if (fabs(D[i] - hostD[i]) > 0.1f) 279 | 
printf("[-] Mismatch index=%d TensorCore=%f HOST=%f\n", i, D[i], hostD[i]); 280 | } 281 | printf("[+] Verification End\n"); 282 | 283 | cuda_status = cudaDeviceReset(); 284 | if (cuda_status != cudaSuccess) { 285 | printf("cudaDeviceReset failed! "); 286 | return 1; 287 | } 288 | 289 | cudaFree(A); 290 | cudaFree(B); 291 | cudaFree(C); 292 | cudaFree(D); 293 | cudaFree(D2); 294 | 295 | free(hostA); 296 | free(hostB); 297 | free(hostD); 298 | 299 | return 0; 300 | } 301 | --------------------------------------------------------------------------------