├── 1.basic_function
│   ├── check_result.py
│   ├── readme.md
│   ├── sparse_mmad.cu
│   └── sparse_mmad_data.py
├── 2.microbenchmark
│   ├── check_result.py
│   ├── sparse_mmad16832_data.py
│   ├── test_sp_mmad_16832_flops.cu
│   ├── test_sp_mmad_16832_function.cu
│   └── test_sp_mmad_16832_latency.cu
└── LICENSE

/1.basic_function/check_result.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | if __name__ == '__main__':
4 |     golden = np.fromfile('d.bin', dtype='float16')
5 |     gpu_result = np.fromfile('d_gpu.bin', dtype='float16')
6 |     print(golden.reshape(16, 8))
7 |     print(gpu_result.reshape(16, 8))
8 |     diff = np.abs(golden - gpu_result).mean()
9 |     print('diff: {}'.format(diff))
--------------------------------------------------------------------------------
/1.basic_function/readme.md:
--------------------------------------------------------------------------------
1 | # Demo of Ampere GPU's sparse matmul
2 | 
3 | ## Run the code
4 | 1. Generate the test data
5 | ```
6 | python3 sparse_mmad_data.py
7 | ```
8 | 2. Compile the test code (`sm_80` targets the A100; please check your GPU's architecture)
9 | ```
10 | nvcc -arch sm_80 sparse_mmad.cu
11 | ```
12 | 3. Run the test program
13 | ```
14 | ./a.out
15 | ```
16 | 4. Check the result
17 | ```
18 | python3 check_result.py
19 | ```
--------------------------------------------------------------------------------
/1.basic_function/sparse_mmad.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_fp16.h> // note: header names reconstructed from usage; they were lost in the dump
2 | #include <cstdint>
3 | #include <cstdio>
4 | #include <cstdlib>
5 | #include <fstream>
6 | 
7 | // code from
8 | // https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
9 | #define gpuErrchk(ans) \
10 |   { gpuAssert((ans), __FILE__, __LINE__); }
11 | inline void gpuAssert(cudaError_t code, const char *file, int line,
12 |                       bool abort = true) {
13 |   if (code != cudaSuccess) {
14 |     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
15 |             line);
16 |     if (abort)
17 |       exit(code);
18 |   }
19 | }
20 | 
21 | static const int M = 16;
22 | static const int N = 8;
23 | static const int K = 16;
24 | 
25 | __global__ void sparse_mmad(__half *d, __half *a, __half *b, __half *c,
26 |                             uint32_t *metadata_p) {
27 |   uint32_t tid = threadIdx.x;
28 |   uint32_t metadata = metadata_p[tid / 4];
29 |   __half *a_ptr = a + (tid % 4) * 2 + (tid / 4) * 8;
30 |   __half *b_ptr = b + (tid % 4) * 2 * 8 + tid / 4;
31 |   __half *c_ptr = c + (tid % 4) * 2 + (tid / 4) * 8;
32 |   __half *d_ptr = d + (tid % 4) * 2 + (tid / 4) * 8;
33 |   asm volatile("{\n\t"
34 |                ".reg .f16 %Ra_single<4>, %Rb_single<4>;\n\t"
35 |                ".reg .f16x2 %Ra<2>, %Rb<2>, %Rc<2>, %Rd<2>;\n\t"
36 |                "ld.global.ca.b32 %Ra0, [%1];\n\t"
37 |                "ld.global.ca.b32 %Ra1, [%1 + 128];\n\t"
38 |                "ld.global.ca.b16 %Rb_single0, [%2];\n\t"
39 |                "ld.global.ca.b16 %Rb_single1, [%2 + 16];\n\t"
40 |                "ld.global.ca.b16 %Rb_single2, [%2 + 128];\n\t"
41 |                "ld.global.ca.b16 %Rb_single3, [%2 + 144];\n\t"
42 |                "ld.global.ca.b32 %Rc0, [%3];\n\t"
43 |                "ld.global.ca.b32 %Rc1, [%3 + 128];\n\t"
44 |                "mov.b32 %Rb0, {%Rb_single0, %Rb_single1};\n\t"
45 |                "mov.b32 %Rb1, {%Rb_single2, %Rb_single3};\n\t"
46 |                "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\n\t"
47 |                "{%Rd0, %Rd1},\n\t"
48 |                "{%Ra0, %Ra1},\n\t"
49 |                "{%Rb0, %Rb1},\n\t"
50 |                "{%Rc0, %Rc1}, %4, 0x0;\n\t"
51 |                "st.global.wb.b32 [%0], %Rd0;\n\t"
52 |                "st.global.wb.b32 [%0 + 128], %Rd1;\n\t"
53 |                "}\n\t"
54 |                :
55 |                : "l"(d_ptr), "l"(a_ptr), "l"(b_ptr), "l"(c_ptr), "r"(metadata)
56 |                : "memory");
57 | }
58 | 
59 | 
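// Reader's note (added): a sketch of the per-thread layout the kernel above assumes,
// derived from its pointer arithmetic rather than from any official documentation.
// Thread tid of the warp handles rows (tid / 4) and (tid / 4) + 8 of the 16x8
// compressed A (the "+ 128" byte offsets in the PTX are 8 rows x 8 halves x 2 bytes),
// columns (tid % 4) * 2 and (tid % 4) * 2 + 1; C and D use the same split over the
// 16x8 output. metadata_p[tid / 4] packs the 2-bit non-zero indices for those two
// rows, matching the layout written by sparse_mmad_data.py, and the trailing 0x0
// operand of mma.sp is the sparsity selector choosing which threads supply metadata.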
int main(int argc, char **argv) { 60 | size_t mat_a_size = M * K / 2; 61 | size_t mat_b_size = N * K; 62 | size_t mat_c_d_size = M * N; 63 | size_t metadata_size_bytes = M * 2; // 16 bit per row 64 | 65 | __half *mat_a_host = new __half[mat_a_size]; 66 | __half *mat_b_host = new __half[mat_b_size]; 67 | __half *mat_c_host = new __half[mat_c_d_size]; 68 | __half *mat_d_host = new __half[mat_c_d_size]; 69 | uint32_t *metadata_host = 70 | new uint32_t[metadata_size_bytes / sizeof(uint32_t)]; 71 | std::ifstream a_fs("a.bin", std::ios_base::binary); 72 | a_fs.read((char *)mat_a_host, mat_a_size * sizeof(__half)); 73 | std::ifstream b_fs("b.bin", std::ios_base::binary); 74 | b_fs.read((char *)mat_b_host, mat_b_size * sizeof(__half)); 75 | std::ifstream c_fs("c.bin", std::ios_base::binary); 76 | c_fs.read((char *)mat_c_host, mat_c_d_size * sizeof(__half)); 77 | std::ifstream metadata_fs("metadata.bin", std::ios_base::binary); 78 | metadata_fs.read((char *)metadata_host, metadata_size_bytes); 79 | 80 | __half *mat_a_dev; 81 | __half *mat_b_dev; 82 | __half *mat_c_dev; 83 | __half *mat_d_dev; 84 | uint32_t *metadata_dev; 85 | 86 | gpuErrchk(cudaMalloc(&mat_a_dev, mat_a_size * sizeof(__half))); 87 | gpuErrchk(cudaMalloc(&mat_b_dev, mat_b_size * sizeof(__half))); 88 | gpuErrchk(cudaMalloc(&mat_c_dev, mat_c_d_size * sizeof(__half))); 89 | gpuErrchk(cudaMalloc(&mat_d_dev, mat_c_d_size * sizeof(__half))); 90 | gpuErrchk(cudaMalloc(&metadata_dev, metadata_size_bytes)); 91 | 92 | gpuErrchk(cudaMemcpy(mat_a_dev, mat_a_host, mat_a_size * sizeof(__half), 93 | cudaMemcpyHostToDevice)); 94 | gpuErrchk(cudaMemcpy(mat_b_dev, mat_b_host, mat_b_size * sizeof(__half), 95 | cudaMemcpyHostToDevice)); 96 | gpuErrchk(cudaMemcpy(mat_c_dev, mat_c_host, mat_c_d_size * sizeof(__half), 97 | cudaMemcpyHostToDevice)); 98 | gpuErrchk(cudaMemcpy(metadata_dev, metadata_host, metadata_size_bytes, 99 | cudaMemcpyHostToDevice)); 100 | 101 | sparse_mmad<<<1, 32>>>(mat_d_dev, mat_a_dev, mat_b_dev, mat_c_dev, 102 | metadata_dev); 103 | gpuErrchk(cudaDeviceSynchronize()); 104 | 105 | gpuErrchk(cudaMemcpy(mat_d_host, mat_d_dev, mat_c_d_size * sizeof(__half), 106 | cudaMemcpyDeviceToHost)); 107 | std::ofstream d_fs("d_gpu.bin", std::ios_base::binary); 108 | d_fs.write((char *)mat_d_host, mat_c_d_size * sizeof(__half)); 109 | 110 | gpuErrchk(cudaFree(mat_a_dev)); 111 | gpuErrchk(cudaFree(mat_b_dev)); 112 | gpuErrchk(cudaFree(mat_c_dev)); 113 | gpuErrchk(cudaFree(mat_d_dev)); 114 | gpuErrchk(cudaFree(metadata_dev)); 115 | 116 | delete[] mat_a_host; 117 | delete[] mat_b_host; 118 | delete[] mat_c_host; 119 | delete[] mat_d_host; 120 | delete[] metadata_host; 121 | 122 | return 0; 123 | } -------------------------------------------------------------------------------- /1.basic_function/sparse_mmad_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.fromnumeric import nonzero 3 | 4 | M = 16 5 | N = 8 6 | K = 16 7 | 8 | dtype = 'float16' 9 | 10 | 11 | def gen_4in2_metadata(): 12 | metadata = np.random.permutation(np.arange(4))[:2].astype('uint8') 13 | metadata.sort() 14 | return metadata 15 | 16 | 17 | def make_sparse_metadata(row, col): 18 | metadata_col = col // 4 19 | return [[gen_4in2_metadata() for _ in range(metadata_col)] 20 | for _ in range(row)] 21 | 22 | 23 | def set_zero_by_metadata(mat, metadata): 24 | row, col = mat.shape 25 | non_zero_mat = np.zeros((row, col), dtype=dtype) 26 | for row_index in range(row): 27 | metadata_row = metadata[row_index] 28 
| for i, metadata_block in enumerate(metadata_row): 29 | offset = i * 4 30 | for idx in metadata_block: 31 | non_zero_mat[row_index, offset + 32 | idx] = mat[row_index, offset + idx] 33 | return non_zero_mat 34 | 35 | 36 | def make_sparse_mat(row, col): 37 | mat = np.random.rand(row, col).astype(dtype) 38 | metadata = make_sparse_metadata(row, col) 39 | mat = set_zero_by_metadata(mat, metadata) 40 | return mat, metadata 41 | 42 | 43 | def compress_sparse_mat(mat, metadata): 44 | row, col = mat.shape 45 | non_zero_mat = np.zeros((row, col//2), dtype=dtype) 46 | for row_index in range(row): 47 | metadata_row = metadata[row_index] 48 | for i, metadata_block in enumerate(metadata_row): 49 | offset = i * 4 50 | for ii, idx in enumerate(metadata_block): 51 | non_zero_mat[row_index, i * 2 + 52 | ii] = mat[row_index, offset + idx] 53 | return non_zero_mat 54 | 55 | 56 | def metadata_to_binary(metadata): 57 | row = len(metadata) 58 | col = len(metadata[0]) * 2 59 | size = row * col // 16 60 | half_row = row // 2 61 | bin_meta = np.zeros((size,), dtype='uint32') 62 | for row_id in range(half_row): 63 | bit_offset = 0 64 | first_half_row = np.concatenate(metadata[row_id]) 65 | second_half_row = np.concatenate(metadata[row_id + half_row]) 66 | whole_row = np.concatenate([first_half_row, second_half_row]) 67 | for idx in whole_row: 68 | bin_meta[row_id] |= idx << bit_offset 69 | bit_offset += 2 70 | return bin_meta 71 | 72 | 73 | def make_dense_mat(row, col): 74 | return np.random.rand(row, col).astype(dtype) 75 | 76 | 77 | if __name__ == '__main__': 78 | mat_a, metadata = make_sparse_mat(M, K) 79 | compressed_mat_a = compress_sparse_mat(mat_a, metadata) 80 | bin_meta = metadata_to_binary(metadata) 81 | mat_b = make_dense_mat(K, N) 82 | mat_c = make_dense_mat(M, N) 83 | mat_d = np.matmul(mat_a, mat_b) + mat_c 84 | 85 | compressed_mat_a.tofile('a.bin') 86 | bin_meta.tofile('metadata.bin') 87 | mat_b.tofile('b.bin') 88 | mat_c.tofile('c.bin') 89 | mat_d.tofile('d.bin') 90 | -------------------------------------------------------------------------------- /2.microbenchmark/check_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | if __name__ == '__main__': 4 | golden = np.fromfile('d.bin', dtype='float16') 5 | gpu_result = np.fromfile('d_gpu.bin', dtype='float16') 6 | print(golden.reshape(16, 8)) 7 | print(gpu_result.reshape(16, 8)) 8 | diff = np.abs(golden - gpu_result).mean() 9 | print('diff: {}'.format(diff)) -------------------------------------------------------------------------------- /2.microbenchmark/sparse_mmad16832_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.fromnumeric import nonzero 3 | import argparse 4 | 5 | 6 | def gen_4in2_metadata(): 7 | metadata = np.random.permutation(np.arange(4))[:2].astype('uint8') 8 | metadata.sort() 9 | return metadata 10 | 11 | 12 | def make_sparse_metadata(row, col): 13 | metadata_col = col // 4 14 | return [[gen_4in2_metadata() for _ in range(metadata_col)] 15 | for _ in range(row)] 16 | 17 | 18 | def set_zero_by_metadata(mat, metadata, dtype): 19 | row, col = mat.shape 20 | non_zero_mat = np.zeros((row, col), dtype=dtype) 21 | for row_index in range(row): 22 | metadata_row = metadata[row_index] 23 | for i, metadata_block in enumerate(metadata_row): 24 | offset = i * 4 25 | for idx in metadata_block: 26 | non_zero_mat[row_index, offset + 27 | idx] = mat[row_index, offset + idx] 28 | return non_zero_mat 29 
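# Worked example (added for clarity, not part of the original script): for one
# 4-element group, gen_4in2_metadata() might return array([1, 3], dtype=uint8),
# meaning elements 1 and 3 of that group are kept and elements 0 and 2 are zeroed
# (2:4 structured sparsity). metadata_to_binary() below packs each pair as two 2-bit
# indices, so [1, 3] contributes the bits 0b1101 (index 1 in the low two bits) to the
# 32-bit metadata words consumed by mma.sp.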
| 30 | 
31 | def make_sparse_mat(row, col, dtype):
32 |     mat = np.random.rand(row, col).astype(dtype)
33 |     metadata = make_sparse_metadata(row, col)
34 |     mat = set_zero_by_metadata(mat, metadata, dtype)
35 |     return mat, metadata
36 | 
37 | 
38 | def compress_sparse_mat(mat, metadata, dtype):
39 |     row, col = mat.shape
40 |     non_zero_mat = np.zeros((row, col//2), dtype=dtype)
41 |     for row_index in range(row):
42 |         metadata_row = metadata[row_index]
43 |         for i, metadata_block in enumerate(metadata_row):
44 |             offset = i * 4
45 |             for ii, idx in enumerate(metadata_block):
46 |                 non_zero_mat[row_index, i * 2 +
47 |                              ii] = mat[row_index, offset + idx]
48 |     return non_zero_mat
49 | 
50 | 
51 | def metadata_to_binary(metadata):
52 |     row = len(metadata)
53 |     col = len(metadata[0]) * 2
54 |     size = row * col // 16
55 |     half_row_num = row // 2
56 |     half_col_num = col // 2 // 2
57 |     bin_meta = np.zeros((size,), dtype='uint32')
58 |     for row_id in range(half_row_num):
59 |         for sub_col in range(2):
60 |             bit_offset = 0
61 |             first_half_row = np.concatenate(metadata[row_id][sub_col * half_col_num: (sub_col + 1) * half_col_num])
62 |             second_half_row = np.concatenate(metadata[row_id + half_row_num][sub_col * half_col_num: (sub_col + 1) * half_col_num])
63 |             whole_row = np.concatenate([first_half_row, second_half_row])
64 |             for idx in whole_row:
65 |                 bin_meta[row_id * 2 + sub_col] |= idx << bit_offset
66 |                 bit_offset += 2
67 |     return bin_meta
68 | 
69 | 
70 | def make_dense_mat(row, col, dtype):
71 |     return np.random.rand(row, col).astype(dtype)
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     parser = argparse.ArgumentParser()
76 |     parser.add_argument('-n', type=int, required=True)
77 |     parser.add_argument('-m', type=int, required=True)
78 |     parser.add_argument('-k', type=int, required=True)
79 |     parser.add_argument('--dtype', default='float16', required=False)
80 |     args = parser.parse_args()
81 |     dtype = args.dtype
82 |     N = args.n
83 |     M = args.m
84 |     K = args.k
85 |     mat_a, metadata = make_sparse_mat(M, K, dtype)
86 |     compressed_mat_a = compress_sparse_mat(mat_a, metadata, dtype)
87 |     bin_meta = metadata_to_binary(metadata)
88 |     mat_b = make_dense_mat(K, N, dtype)
89 |     #mat_c = make_dense_mat(M, N, dtype)
90 |     mat_c = np.zeros((M, N), dtype)
91 |     mat_d = np.matmul(mat_a, mat_b) + mat_c
92 | 
93 |     compressed_mat_a.tofile('a.bin')
94 |     bin_meta.tofile('metadata.bin')
95 |     print(bin_meta)
96 |     mat_b.tofile('b.bin')
97 |     mat_c.tofile('c.bin')
98 |     mat_d.tofile('d.bin')
99 | 
--------------------------------------------------------------------------------
/2.microbenchmark/test_sp_mmad_16832_flops.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_fp16.h> // note: header names reconstructed from usage; they were lost in the dump
2 | #include <chrono>
3 | #include <cstdint>
4 | #include <cstdio>
5 | #include <fstream>
6 | #include <iostream>
7 | 
8 | // code from
9 | // https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
10 | #define gpuErrchk(ans) \
11 |   { gpuAssert((ans), __FILE__, __LINE__); }
12 | inline void gpuAssert(cudaError_t code, const char *file, int line,
13 |                       bool abort = true) {
14 |   if (code != cudaSuccess) {
15 |     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
16 |             line);
17 |     if (abort)
18 |       exit(code);
19 |   }
20 | }
21 | 
22 | static const int M = 16;
23 | static const int N = 8;
24 | static const int K = 32;
25 | static const uint64_t total_repeat_time = 409600000;
26 | static const uint32_t total_thread_num = 128;
27 | static const uint32_t warp_size = 32;
28 | static const uint32_t sm_count = 108;
29 | 
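// Note (added): sm_count = 108 appears to assume an A100 (108 SMs); adjust it for
// other GPUs. With total_thread_num = 128 each block runs warp_num = 4 warps, and
// block_num below launches one block per SM.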
static const uint32_t block_num = sm_count * 1; 30 | static const uint32_t warp_num = total_thread_num / warp_size; 31 | static const uint32_t max_unroll_times = 8; 32 | 33 | // I didn't check result since it is just a throughput test 34 | template 35 | __global__ void sp_mmad_16832_throughput_test(__half *d, __half *a, __half *b, 36 | __half *c, uint32_t *metadata_p, 37 | uint64_t repeat_time) { 38 | uint32_t tid = threadIdx.x % warp_size; 39 | 40 | size_t mat_a_stride = M * K / 2; 41 | size_t mat_b_stride = N * K; 42 | size_t mat_c_stride = M * N; 43 | size_t mat_d_stride = M * N; 44 | size_t metadata_stride = M * 2 / sizeof(uint32_t); // 16 bit per row 45 | 46 | uint32_t metadata[max_unroll_times]; 47 | uint32_t a01[max_unroll_times], a23[max_unroll_times], a45[max_unroll_times], 48 | a67[max_unroll_times]; 49 | uint32_t b01[max_unroll_times], b23[max_unroll_times], b45[max_unroll_times], 50 | b67[max_unroll_times]; 51 | uint32_t d01[max_unroll_times], d23[max_unroll_times]; 52 | 53 | size_t mat_a_row = K / 2; 54 | size_t mat_b_row = N; 55 | 56 | __half *a_base_ptr = a + (tid % 4) * 2 + (tid / 4) * mat_a_row; 57 | __half *b_base_ptr = b + (tid % 4) * 2 * mat_b_row + tid / 4; 58 | __half *c_base_ptr = c + (tid % 4) * 2 + (tid / 4) * 8; 59 | __half *d_base_ptr = d + (tid % 4) * 2 + (tid / 4) * 8; 60 | 61 | for (int i = 0; i < unroll_times; ++i) { 62 | uint32_t *metadata_ptr = metadata_p + i * metadata_stride; 63 | __half *a_ptr = a_base_ptr + i * mat_a_stride; 64 | __half *b_ptr = b_base_ptr + i * mat_b_stride; 65 | __half *c_ptr = c_base_ptr + i * mat_c_stride; 66 | 67 | metadata[i] = metadata_ptr[tid / 2]; 68 | 69 | a01[i] = *((uint32_t *)a_ptr); 70 | a23[i] = *((uint32_t *)(a_ptr + 128)); 71 | a45[i] = *((uint32_t *)(a_ptr + 8)); 72 | a67[i] = *((uint32_t *)(a_ptr + 128 + 8)); 73 | 74 | uint16_t b0, b1, b2, b3, b4, b5, b6, b7; 75 | 76 | b0 = *((uint16_t *)(b_ptr)); 77 | b1 = *((uint16_t *)(b_ptr + 8)); 78 | b2 = *((uint16_t *)(b_ptr + 64)); 79 | b3 = *((uint16_t *)(b_ptr + 64 + 8)); 80 | b4 = *((uint16_t *)(b_ptr + 64 * 2)); 81 | b5 = *((uint16_t *)(b_ptr + 64 * 2 + 8)); 82 | b6 = *((uint16_t *)(b_ptr + 64 * 3)); 83 | b7 = *((uint16_t *)(b_ptr + 64 * 3 + 8)); 84 | 85 | asm volatile("mov.b32 %0, {%4, %5};\n\t" 86 | "mov.b32 %1, {%6, %7};\n\t" 87 | "mov.b32 %2, {%8, %9};\n\t" 88 | "mov.b32 %3, {%10, %11};\n\t" 89 | : "=r"(b01[i]), "=r"(b23[i]), "=r"(b45[i]), "=r"(b67[i]) 90 | : "h"(b0), "h"(b1), "h"(b2), "h"(b3), "h"(b4), "h"(b5), 91 | "h"(b6), "h"(b7) 92 | :); 93 | 94 | d01[i] = *((uint32_t *)c_ptr); 95 | d23[i] = *((uint32_t *)(c_ptr + 64)); 96 | } 97 | 98 | asm volatile("bar.sync 0;"); 99 | 100 | for (uint64_t repeat_i = 0; repeat_i < repeat_time; ++repeat_i) { 101 | for (int i = 0; i < unroll_times; ++i) { 102 | asm volatile("{\n\t" 103 | "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16\n\t" 104 | "{%0, %1},\n\t" 105 | "{%2, %3, %4, %5},\n\t" 106 | "{%6, %7, %8, %9},\n\t" 107 | "{%10, %11}, %12, 0x0;\n\t" 108 | "}\n\t" 109 | : "=r"(d01[i]), "=r"(d23[i]) 110 | : "r"(a01[i]), "r"(a23[i]), "r"(a45[i]), "r"(a67[i]), 111 | "r"(b01[i]), "r"(b23[i]), "r"(b45[i]), "r"(b67[i]), 112 | "r"(d01[i]), "r"(d23[i]), "r"(metadata[i]) 113 | :); 114 | } 115 | } 116 | 117 | asm volatile("bar.sync 0;"); 118 | 119 | for (int i = 0; i < unroll_times; ++i) { 120 | __half *d_ptr = d_base_ptr + i * mat_d_stride; 121 | *((uint32_t *)d_ptr) = d01[i]; 122 | *((uint32_t *)(d_ptr + 64)) = d23[i]; 123 | } 124 | } 125 | 126 | template void test_sp_mmad_16832_flops() { 127 | size_t mat_a_size = M * K / 2; 128 | 
size_t mat_b_size = N * K; 129 | size_t mat_c_size = M * N; 130 | size_t mat_d_size = M * N; 131 | size_t metadata_size_bytes = M * K / 8; // 32 bit per row 132 | 133 | size_t mat_a_unroll_size = mat_a_size * unroll_times; 134 | size_t mat_b_unroll_size = mat_b_size * unroll_times; 135 | size_t mat_c_unroll_size = mat_c_size * unroll_times; 136 | size_t mat_d_unroll_size = mat_d_size * unroll_times; 137 | size_t metadata_unroll_size_bytes = metadata_size_bytes * unroll_times; 138 | 139 | uint32_t *duration_host = new uint32_t[warp_num]; 140 | 141 | __half *mat_a_host = new __half[mat_a_size]; 142 | __half *mat_b_host = new __half[mat_b_size]; 143 | __half *mat_c_host = new __half[mat_c_size]; 144 | __half *mat_d_host = new __half[mat_d_size]; 145 | uint32_t *metadata_host = 146 | new uint32_t[metadata_size_bytes / sizeof(uint32_t)]; 147 | 148 | std::ifstream a_fs("a.bin", std::ios_base::binary); 149 | a_fs.read((char *)mat_a_host, mat_a_size * sizeof(__half)); 150 | std::ifstream b_fs("b.bin", std::ios_base::binary); 151 | b_fs.read((char *)mat_b_host, mat_b_size * sizeof(__half)); 152 | std::ifstream c_fs("c.bin", std::ios_base::binary); 153 | c_fs.read((char *)mat_c_host, mat_c_size * sizeof(__half)); 154 | std::ifstream metadata_fs("metadata.bin", std::ios_base::binary); 155 | metadata_fs.read((char *)metadata_host, metadata_size_bytes); 156 | 157 | __half *mat_a_dev; 158 | __half *mat_b_dev; 159 | __half *mat_c_dev; 160 | __half *mat_d_dev; 161 | uint32_t *metadata_dev; 162 | 163 | gpuErrchk(cudaMalloc(&mat_a_dev, mat_a_unroll_size * sizeof(__half))); 164 | gpuErrchk(cudaMalloc(&mat_b_dev, mat_b_unroll_size * sizeof(__half))); 165 | gpuErrchk(cudaMalloc(&mat_c_dev, mat_c_unroll_size * sizeof(__half))); 166 | gpuErrchk(cudaMalloc(&mat_d_dev, mat_d_unroll_size * sizeof(__half))); 167 | gpuErrchk(cudaMalloc(&metadata_dev, metadata_unroll_size_bytes)); 168 | 169 | for (uint32_t i = 0; i < unroll_times; ++i) { 170 | // uncomment to use random data, but perfomance may decrease due to power 171 | // limit 172 | /* 173 | gpuErrchk(cudaMemcpy(mat_a_dev + i * mat_a_size, mat_a_host, 174 | mat_a_size * sizeof(__half), cudaMemcpyHostToDevice)); 175 | gpuErrchk(cudaMemcpy(mat_b_dev + i * mat_b_size, mat_b_host, 176 | mat_b_size * sizeof(__half), cudaMemcpyHostToDevice)); 177 | gpuErrchk(cudaMemcpy(mat_c_dev + i * mat_c_size, mat_c_host, 178 | mat_c_size * sizeof(__half), cudaMemcpyHostToDevice)); 179 | */ 180 | gpuErrchk( 181 | cudaMemcpy(metadata_dev + i * metadata_size_bytes / sizeof(uint32_t), 182 | metadata_host, metadata_size_bytes, cudaMemcpyHostToDevice)); 183 | } 184 | 185 | uint64_t repeat_time = total_repeat_time / unroll_times; 186 | 187 | auto t_start = std::chrono::high_resolution_clock::now(); 188 | 189 | sp_mmad_16832_throughput_test<<>>( 190 | mat_d_dev, mat_a_dev, mat_b_dev, mat_c_dev, metadata_dev, repeat_time); 191 | gpuErrchk(cudaDeviceSynchronize()); 192 | auto t_end = std::chrono::high_resolution_clock::now(); 193 | double gpu_ns = (t_end - t_start).count(); 194 | std::cout << "kernel duration: " << gpu_ns << " ns" << std::endl; 195 | 196 | double flop_per_repeat = 197 | unroll_times * warp_num * M * K * N * 2; // x2 because fma is 2ops 198 | double total_flop = flop_per_repeat * repeat_time * block_num; 199 | std::cout << "unroll: " << unroll_times 200 | << " flops(whole GPU): " << total_flop / (gpu_ns / 1E9) / 1E12 201 | << " TFLOPS" << std::endl; 202 | 203 | gpuErrchk(cudaFree(mat_a_dev)); 204 | gpuErrchk(cudaFree(mat_b_dev)); 205 | gpuErrchk(cudaFree(mat_c_dev)); 206 | 
gpuErrchk(cudaFree(mat_d_dev)); 207 | gpuErrchk(cudaFree(metadata_dev)); 208 | 209 | delete[] mat_a_host; 210 | delete[] mat_b_host; 211 | delete[] mat_c_host; 212 | delete[] mat_d_host; 213 | delete[] metadata_host; 214 | } 215 | 216 | int main(int argc, char **argv) { 217 | test_sp_mmad_16832_flops<1>(); 218 | test_sp_mmad_16832_flops<2>(); 219 | test_sp_mmad_16832_flops<3>(); 220 | test_sp_mmad_16832_flops<4>(); 221 | test_sp_mmad_16832_flops<5>(); 222 | test_sp_mmad_16832_flops<6>(); 223 | test_sp_mmad_16832_flops<7>(); 224 | test_sp_mmad_16832_flops<8>(); 225 | return 0; 226 | } -------------------------------------------------------------------------------- /2.microbenchmark/test_sp_mmad_16832_function.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // code from 9 | // https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api 10 | #define gpuErrchk(ans) \ 11 | { gpuAssert((ans), __FILE__, __LINE__); } 12 | inline void gpuAssert(cudaError_t code, const char *file, int line, 13 | bool abort = true) { 14 | if (code != cudaSuccess) { 15 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, 16 | line); 17 | if (abort) 18 | exit(code); 19 | } 20 | } 21 | 22 | static const int M = 16; 23 | static const int N = 8; 24 | static const int K = 32; 25 | static const uint32_t total_thread_num = 32; 26 | static const uint32_t warp_size = 32; 27 | static const uint32_t block_num = 1; 28 | static const uint32_t warp_num = total_thread_num / warp_size; 29 | 30 | // TODO: use shared memory 31 | 32 | __global__ void sp_mmad_16832_test(__half *d, __half *a, __half *b, __half *c, 33 | uint32_t *metadata_p) { 34 | uint32_t tid = threadIdx.x % warp_size; 35 | 36 | uint32_t metadata; 37 | uint32_t a01, a23, a45, a67; 38 | uint32_t b01, b23, b45, b67; 39 | uint32_t d01, d23; 40 | 41 | size_t mat_a_row = K / 2; 42 | size_t mat_b_row = N; 43 | 44 | __half *a_base_ptr = a + (tid % 4) * 2 + (tid / 4) * mat_a_row; 45 | __half *b_base_ptr = b + (tid % 4) * 2 * mat_b_row + tid / 4; 46 | __half *c_base_ptr = c + (tid % 4) * 2 + (tid / 4) * 8; 47 | __half *d_base_ptr = d + (tid % 4) * 2 + (tid / 4) * 8; 48 | 49 | uint32_t *metadata_ptr = metadata_p; 50 | __half *a_ptr = a_base_ptr; 51 | __half *b_ptr = b_base_ptr; 52 | __half *c_ptr = c_base_ptr; 53 | 54 | metadata = metadata_ptr[tid / 4 * 2 + tid % 2]; 55 | 56 | a01 = *((uint32_t *)a_ptr); 57 | a23 = *((uint32_t *)(a_ptr + 128)); 58 | a45 = *((uint32_t *)(a_ptr + 8)); 59 | a67 = *((uint32_t *)(a_ptr + 128 + 8)); 60 | 61 | uint16_t b0, b1, b2, b3, b4, b5, b6, b7; 62 | 63 | b0 = *((uint16_t *)(b_ptr)); 64 | b1 = *((uint16_t *)(b_ptr + 8)); 65 | b2 = *((uint16_t *)(b_ptr + 64)); 66 | b3 = *((uint16_t *)(b_ptr + 64 + 8)); 67 | b4 = *((uint16_t *)(b_ptr + 64 * 2)); 68 | b5 = *((uint16_t *)(b_ptr + 64 * 2 + 8)); 69 | b6 = *((uint16_t *)(b_ptr + 64 * 3)); 70 | b7 = *((uint16_t *)(b_ptr + 64 * 3 + 8)); 71 | 72 | asm volatile("mov.b32 %0, {%4, %5};\n\t" 73 | "mov.b32 %1, {%6, %7};\n\t" 74 | "mov.b32 %2, {%8, %9};\n\t" 75 | "mov.b32 %3, {%10, %11};\n\t" 76 | : "=r"(b01), "=r"(b23), "=r"(b45), "=r"(b67) 77 | : "h"(b0), "h"(b1), "h"(b2), "h"(b3), "h"(b4), "h"(b5), "h"(b6), 78 | "h"(b7) 79 | :); 80 | 81 | d01 = *((uint32_t *)c_ptr); 82 | d23 = *((uint32_t *)(c_ptr + 64)); 83 | 84 | asm volatile("{\n\t" 85 | "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16\n\t" 86 | "{%0, 
%1},\n\t" 87 | "{%2, %3, %4, %5},\n\t" 88 | "{%6, %7, %8, %9},\n\t" 89 | "{%10, %11}, %12, 0x0;\n\t" 90 | "}\n\t" 91 | : "=r"(d01), "=r"(d23) 92 | : "r"(a01), "r"(a23), "r"(a45), "r"(a67), "r"(b01), "r"(b23), 93 | "r"(b45), "r"(b67), "r"(d01), "r"(d23), "r"(metadata) 94 | :); 95 | 96 | __half *d_ptr = d_base_ptr; 97 | *((uint32_t *)d_ptr) = d01; 98 | *((uint32_t *)(d_ptr + 64)) = d23; 99 | } 100 | 101 | int main(int argc, char **argv) { 102 | size_t mat_a_size = M * K / 2; 103 | size_t mat_b_size = N * K; 104 | size_t mat_c_size = M * N; 105 | size_t mat_d_size = M * N; 106 | size_t metadata_size_bytes = M * K / 8; // 32 bit per row 107 | 108 | __half *mat_a_host = new __half[mat_a_size]; 109 | __half *mat_b_host = new __half[mat_b_size]; 110 | __half *mat_c_host = new __half[mat_c_size]; 111 | __half *mat_d_host = new __half[mat_d_size]; 112 | uint32_t *metadata_host = 113 | new uint32_t[metadata_size_bytes / sizeof(uint32_t)]; 114 | 115 | std::ifstream a_fs("a.bin", std::ios_base::binary); 116 | a_fs.read((char *)mat_a_host, mat_a_size * sizeof(__half)); 117 | std::ifstream b_fs("b.bin", std::ios_base::binary); 118 | b_fs.read((char *)mat_b_host, mat_b_size * sizeof(__half)); 119 | std::ifstream c_fs("c.bin", std::ios_base::binary); 120 | c_fs.read((char *)mat_c_host, mat_c_size * sizeof(__half)); 121 | std::ifstream metadata_fs("metadata.bin", std::ios_base::binary); 122 | metadata_fs.read((char *)metadata_host, metadata_size_bytes); 123 | 124 | __half *mat_a_dev; 125 | __half *mat_b_dev; 126 | __half *mat_c_dev; 127 | __half *mat_d_dev; 128 | uint32_t *metadata_dev; 129 | 130 | gpuErrchk(cudaMalloc(&mat_a_dev, mat_a_size * sizeof(__half))); 131 | gpuErrchk(cudaMalloc(&mat_b_dev, mat_b_size * sizeof(__half))); 132 | gpuErrchk(cudaMalloc(&mat_c_dev, mat_c_size * sizeof(__half))); 133 | gpuErrchk(cudaMalloc(&mat_d_dev, mat_d_size * sizeof(__half))); 134 | gpuErrchk(cudaMalloc(&metadata_dev, metadata_size_bytes)); 135 | 136 | gpuErrchk(cudaMemcpy(mat_a_dev, mat_a_host, mat_a_size * sizeof(__half), 137 | cudaMemcpyHostToDevice)); 138 | gpuErrchk(cudaMemcpy(mat_b_dev, mat_b_host, mat_b_size * sizeof(__half), 139 | cudaMemcpyHostToDevice)); 140 | gpuErrchk(cudaMemcpy(mat_c_dev, mat_c_host, mat_c_size * sizeof(__half), 141 | cudaMemcpyHostToDevice)); 142 | 143 | gpuErrchk(cudaMemcpy(metadata_dev, metadata_host, metadata_size_bytes, 144 | cudaMemcpyHostToDevice)); 145 | 146 | sp_mmad_16832_test<<>>( 147 | mat_d_dev, mat_a_dev, mat_b_dev, mat_c_dev, metadata_dev); 148 | gpuErrchk(cudaDeviceSynchronize()); 149 | gpuErrchk(cudaMemcpy(mat_d_host, mat_d_dev, mat_d_size * sizeof(__half), 150 | cudaMemcpyDeviceToHost)); 151 | std::ofstream d_fs("d_gpu.bin", std::ios_base::binary); 152 | d_fs.write((char *)mat_d_host, mat_d_size * sizeof(__half)); 153 | 154 | gpuErrchk(cudaFree(mat_a_dev)); 155 | gpuErrchk(cudaFree(mat_b_dev)); 156 | gpuErrchk(cudaFree(mat_c_dev)); 157 | gpuErrchk(cudaFree(mat_d_dev)); 158 | gpuErrchk(cudaFree(metadata_dev)); 159 | 160 | delete[] mat_a_host; 161 | delete[] mat_b_host; 162 | delete[] mat_c_host; 163 | delete[] mat_d_host; 164 | delete[] metadata_host; 165 | 166 | return 0; 167 | } -------------------------------------------------------------------------------- /2.microbenchmark/test_sp_mmad_16832_latency.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // code from 9 | // 
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api 10 | #define gpuErrchk(ans) \ 11 | { gpuAssert((ans), __FILE__, __LINE__); } 12 | inline void gpuAssert(cudaError_t code, const char *file, int line, 13 | bool abort = true) { 14 | if (code != cudaSuccess) { 15 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, 16 | line); 17 | if (abort) 18 | exit(code); 19 | } 20 | } 21 | 22 | static const int M = 16; 23 | static const int N = 8; 24 | static const int K = 32; 25 | static const uint64_t repeat_time = 409600; 26 | static const uint32_t total_thread_num = 32; 27 | static const uint32_t warp_size = 32; 28 | static const uint32_t block_num = 1; 29 | static const uint32_t warp_num = total_thread_num / warp_size; 30 | 31 | // TODO: use shared memory 32 | 33 | __global__ void sp_mmad_16832_latency_test(__half *d, __half *a, __half *b, 34 | __half *c, uint32_t *metadata_p, 35 | uint32_t *duration) { 36 | uint32_t tid = threadIdx.x % warp_size; 37 | 38 | uint32_t metadata; 39 | uint32_t a01, a23, a45, a67; 40 | uint32_t b01, b23, b45, b67; 41 | uint32_t d01, d23; 42 | 43 | size_t mat_a_row = K / 2; 44 | size_t mat_b_row = N; 45 | 46 | __half *a_base_ptr = a + (tid % 4) * 2 + (tid / 4) * mat_a_row; 47 | __half *b_base_ptr = b + (tid % 4) * 2 * mat_b_row + tid / 4; 48 | __half *c_base_ptr = c + (tid % 4) * 2 + (tid / 4) * 8; 49 | __half *d_base_ptr = d + (tid % 4) * 2 + (tid / 4) * 8; 50 | 51 | uint32_t *metadata_ptr = metadata_p; 52 | __half *a_ptr = a_base_ptr; 53 | __half *b_ptr = b_base_ptr; 54 | __half *c_ptr = c_base_ptr; 55 | 56 | metadata = metadata_ptr[tid / 2]; 57 | 58 | a01 = *((uint32_t *)a_ptr); 59 | a23 = *((uint32_t *)(a_ptr + 8)); 60 | a45 = *((uint32_t *)(a_ptr + 128)); 61 | a67 = *((uint32_t *)(a_ptr + 128 + 8)); 62 | 63 | uint16_t b0, b1, b2, b3, b4, b5, b6, b7; 64 | 65 | b0 = *((uint16_t *)(b_ptr)); 66 | b1 = *((uint16_t *)(b_ptr + 8)); 67 | b2 = *((uint16_t *)(b_ptr + 64)); 68 | b3 = *((uint16_t *)(b_ptr + 64 + 8)); 69 | b4 = *((uint16_t *)(b_ptr + 64 * 2)); 70 | b5 = *((uint16_t *)(b_ptr + 64 * 2 + 8)); 71 | b6 = *((uint16_t *)(b_ptr + 64 * 3)); 72 | b7 = *((uint16_t *)(b_ptr + 64 * 3 + 8)); 73 | 74 | asm volatile("mov.b32 %0, {%4, %5};\n\t" 75 | "mov.b32 %1, {%6, %7};\n\t" 76 | "mov.b32 %2, {%8, %9};\n\t" 77 | "mov.b32 %3, {%10, %11};\n\t" 78 | : "=r"(b01), "=r"(b23), "=r"(b45), "=r"(b67) 79 | : "h"(b0), "h"(b1), "h"(b2), "h"(b3), "h"(b4), "h"(b5), "h"(b6), 80 | "h"(b7) 81 | :); 82 | 83 | d01 = *((uint32_t *)c_ptr); 84 | d23 = *((uint32_t *)(c_ptr + 64)); 85 | 86 | asm volatile("bar.sync 0;"); 87 | 88 | uint32_t start = 0; 89 | asm volatile("mov.u32 %0, %%clock;" : "=r"(start)::"memory"); 90 | 91 | #pragma unroll 1 92 | for (uint64_t repeat_i = 0; repeat_i < repeat_time; ++repeat_i) { 93 | asm volatile("{\n\t" 94 | "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16\n\t" 95 | "{%0, %1},\n\t" 96 | "{%2, %3, %4, %5},\n\t" 97 | "{%6, %7, %8, %9},\n\t" 98 | "{%10, %11}, %12, 0x0;\n\t" 99 | "}\n\t" 100 | : "=r"(d01), "=r"(d23) 101 | : "r"(a01), "r"(a23), "r"(a45), "r"(a67), "r"(b01), "r"(b23), 102 | "r"(b45), "r"(b67), "r"(d01), "r"(d23), "r"(metadata) 103 | :); 104 | } 105 | 106 | asm volatile("bar.sync 0;"); 107 | 108 | uint32_t end = 0; 109 | asm volatile("mov.u32 %0, %%clock;" : "=r"(end)::"memory"); 110 | 111 | __half *d_ptr = d_base_ptr; 112 | *((uint32_t *)d_ptr) = d01; 113 | *((uint32_t *)(d_ptr + 64)) = d23; 114 | 115 | uint32_t u32_duration = end - start; 116 | if (tid == 0) { 117 
| duration[0] = u32_duration; 118 | } 119 | } 120 | 121 | int main(int argc, char **argv) { 122 | size_t mat_a_size = M * K / 2; 123 | size_t mat_b_size = N * K; 124 | size_t mat_c_size = M * N; 125 | size_t mat_d_size = M * N; 126 | size_t metadata_size_bytes = M * 2; // 16 bit per row 127 | 128 | __half *mat_a_host = new __half[mat_a_size]; 129 | __half *mat_b_host = new __half[mat_b_size]; 130 | __half *mat_c_host = new __half[mat_c_size]; 131 | __half *mat_d_host = new __half[mat_d_size]; 132 | uint32_t *metadata_host = 133 | new uint32_t[metadata_size_bytes / sizeof(uint32_t)]; 134 | uint32_t *duration_host = new uint32_t[warp_num]; 135 | 136 | std::ifstream a_fs("a.bin", std::ios_base::binary); 137 | a_fs.read((char *)mat_a_host, mat_a_size * sizeof(__half)); 138 | std::ifstream b_fs("b.bin", std::ios_base::binary); 139 | b_fs.read((char *)mat_b_host, mat_b_size * sizeof(__half)); 140 | std::ifstream c_fs("c.bin", std::ios_base::binary); 141 | c_fs.read((char *)mat_c_host, mat_c_size * sizeof(__half)); 142 | std::ifstream metadata_fs("metadata.bin", std::ios_base::binary); 143 | metadata_fs.read((char *)metadata_host, metadata_size_bytes); 144 | 145 | __half *mat_a_dev; 146 | __half *mat_b_dev; 147 | __half *mat_c_dev; 148 | __half *mat_d_dev; 149 | uint32_t *metadata_dev; 150 | uint32_t *duration_dev; 151 | 152 | gpuErrchk(cudaMalloc(&mat_a_dev, mat_a_size * sizeof(__half))); 153 | gpuErrchk(cudaMalloc(&mat_b_dev, mat_b_size * sizeof(__half))); 154 | gpuErrchk(cudaMalloc(&mat_c_dev, mat_c_size * sizeof(__half))); 155 | gpuErrchk(cudaMalloc(&mat_d_dev, mat_d_size * sizeof(__half))); 156 | gpuErrchk(cudaMalloc(&metadata_dev, metadata_size_bytes)); 157 | gpuErrchk(cudaMalloc(&duration_dev, warp_num * sizeof(uint32_t))); 158 | 159 | // uncomment to use random data, but perfomance may decrease due to power 160 | // limit 161 | /* 162 | gpuErrchk(cudaMemcpy(mat_a_dev, mat_a_host, 163 | mat_a_size * sizeof(__half), cudaMemcpyHostToDevice)); 164 | gpuErrchk(cudaMemcpy(mat_b_dev, mat_b_host, 165 | mat_b_size * sizeof(__half), cudaMemcpyHostToDevice)); 166 | gpuErrchk(cudaMemcpy(mat_c_dev, mat_c_host, 167 | mat_c_size * sizeof(__half), cudaMemcpyHostToDevice)); 168 | */ 169 | gpuErrchk(cudaMemcpy(metadata_dev, 170 | metadata_host, metadata_size_bytes, 171 | cudaMemcpyHostToDevice)); 172 | 173 | sp_mmad_16832_latency_test<<>>( 174 | mat_d_dev, mat_a_dev, mat_b_dev, mat_c_dev, metadata_dev, duration_dev); 175 | gpuErrchk(cudaDeviceSynchronize()); 176 | gpuErrchk(cudaMemcpy(duration_host, duration_dev, warp_num * sizeof(uint32_t), 177 | cudaMemcpyDeviceToHost)); 178 | 179 | double total_cycle = duration_host[0]; 180 | 181 | std::cout << "mma.sp.sync.aligned.m16n8k32 latency: " 182 | << total_cycle / repeat_time << " cycle" << std::endl; 183 | 184 | gpuErrchk(cudaFree(mat_a_dev)); 185 | gpuErrchk(cudaFree(mat_b_dev)); 186 | gpuErrchk(cudaFree(mat_c_dev)); 187 | gpuErrchk(cudaFree(mat_d_dev)); 188 | gpuErrchk(cudaFree(metadata_dev)); 189 | gpuErrchk(cudaFree(duration_dev)); 190 | 191 | delete[] mat_a_host; 192 | delete[] mat_b_host; 193 | delete[] mat_c_host; 194 | delete[] mat_d_host; 195 | delete[] metadata_host; 196 | delete[] duration_host; 197 | 198 | return 0; 199 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 李睿昕 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and 
associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------