├── .gitignore
├── README.md
├── blis.cpp
├── build.sh
├── clean.sh
├── gemm_perf.py
├── references.md
├── results
│   ├── results_float32.png
│   └── results_float64.png
├── tvm_autoscheduler_tune.py
├── tvm_autotvm_tune.py
└── tvm_without_tune.py

/.gitignore:
--------------------------------------------------------------------------------
*.x
*.log
*.json
*.tmp
results_*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GEMM performance comparison of TVM and BLIS on CPU

This project runs the 1024x1024x1024 matrix multiplication C = A x B in both single precision (32-bit) and double precision (64-bit) using TVM and BLIS, and benchmarks it with 1, 4, 8, 12, and 16 threads.

## Results
Single-precision GEMM performance, TVM vs. BLIS:
![Results32](./results/results_float32.png)

Double-precision GEMM performance, TVM vs. BLIS:
![Results64](./results/results_float64.png)

## Test machine
The tests were run on a MECHREVO laptop with the following configuration:
* CPU Model: AMD Ryzen 7 4800H with Radeon Graphics
* CPUs: 16
* Clock speed: max 2900 MHz, min 1400 MHz; CPU frequency scaling was not disabled
* L1d cache: 256 KiB
* L2 cache: 4 MiB
* L3 cache: 8 MiB
* OS: Ubuntu 20.04.1, 64-bit

*Notes:*
* *tvm_without_tune.py uses the schedule from TVM's "How to optimize GEMM on CPU" tutorial and runs it with fixed parameters.*
* *tvm_autotvm_tune.py uses TVM's autotvm to tune the parameters of the tvm_without_tune schedule.*
* *tvm_autoscheduler_tune.py uses TVM's auto_scheduler to search for the schedule and the best parameters automatically.*
* *blis calls a locally built BLIS library, compiled with --enable-cblas -t openmp CC=clang CXX=clang++.*
* *numpy runs the same matrix multiplication with the MKL-backed numpy binary shipped with conda.*

## Dependencies
The project depends on:
- TVM https://github.com/apache/tvm
- BLIS https://github.com/flame/blis

BLIS is built with the following options: --enable-cblas -t openmp CC=clang CXX=clang++

## Quick Start
* Install TVM following https://tvm.apache.org/docs/install/index.html
* Install BLIS following https://github.com/flame/blis/blob/master/docs/BuildSystem.md (configure with --enable-cblas -t openmp)
* git clone https://github.com/billmuch/matmul_perf_test
* Tune the single-precision TVM GEMM with autotvm
  python tvm_autotvm_tune.py float32 tune
* Tune the double-precision TVM GEMM with autotvm
  python tvm_autotvm_tune.py float64 tune
* Tune the single-precision TVM GEMM with the auto-scheduler
  python tvm_autoscheduler_tune.py float32 tune
* Tune the double-precision TVM GEMM with the auto-scheduler
  python tvm_autoscheduler_tune.py float64 tune
* Build blis.cpp
  ./build.sh
* Benchmark single-precision GEMM with TVM and BLIS
  python gemm_perf.py float32
* Benchmark double-precision GEMM with TVM and BLIS
  python gemm_perf.py float64

## Tuning and running the TVM GEMM code
- Run tvm_without_tune.py
  python tvm_without_tune.py float32
  or
  python tvm_without_tune.py float64

- Tune tvm_autotvm_tune.py
  python tvm_autotvm_tune.py float32 tune
  or
  python tvm_autotvm_tune.py float64 tune

- Run tvm_autotvm_tune.py (must be tuned first)
  python tvm_autotvm_tune.py float32
  or
  python tvm_autotvm_tune.py float64

- Tune tvm_autoscheduler_tune.py
  python tvm_autoscheduler_tune.py float32 tune
  or
  python tvm_autoscheduler_tune.py float64 tune

- Run tvm_autoscheduler_tune.py (must be tuned first)
  python tvm_autoscheduler_tune.py float32
  or
  python tvm_autoscheduler_tune.py float64


## Building and running the BLIS GEMM code
- Build
  ./build.sh

- Run
  ./blis_float32.x
  or
  ./blis_float64.x

## Running the benchmark script
- python gemm_perf.py float32
  or
  python gemm_perf.py float64

--------------------------------------------------------------------------------
/blis.cpp:
--------------------------------------------------------------------------------
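// Benchmark of the locally built BLIS library through its CBLAS interface:
// runs the 1024x1024x1024 single- or double-precision GEMM 500 times,
// prints the average time per call, and verifies the result against a
// naive matrix multiplication.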
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <iostream>

#include <cblas.h>

#define MATRIX_FORMAT CblasRowMajor

#define MDIM 1024
#define NDIM 1024
#define KDIM 1024

#ifdef FLOAT32
#define BLAS_GEMM cblas_sgemm
#define DTYPE float
#else
#define BLAS_GEMM cblas_dgemm
#define DTYPE double
#endif

void init_matrix(DTYPE *a, int nrows, int ncols) {
    for (int j = 0; j < ncols; j++) {
        for (int i = 0; i < nrows; i++) {
            a[i + j * nrows] = ((DTYPE) rand() / (DTYPE) RAND_MAX);
        }
    }
}

// reference implementation used for the correctness check
void naive_matmul(const DTYPE *a, const DTYPE *b, DTYPE *c, size_t m, size_t k, size_t n) {
    for (size_t i = 0; i < m; i++) {
        for (size_t j = 0; j < n; j++) {
            size_t ci = i*n + j;
            c[ci] = 0.0f;
            for (size_t p = 0; p < k; p++) {
                c[ci] += a[i*k + p] * b[p*n + j];
            }
        }
    }
}

static void BenchmarkFunction() {
    DTYPE *A, *B, *C;

    A = (DTYPE *) malloc(MDIM * KDIM * sizeof(DTYPE));
    B = (DTYPE *) malloc(KDIM * NDIM * sizeof(DTYPE));
    C = (DTYPE *) malloc(MDIM * NDIM * sizeof(DTYPE));

    init_matrix(A, MDIM, KDIM);
    init_matrix(B, KDIM, NDIM);
    init_matrix(C, MDIM, NDIM);

    int LDA = KDIM;
    int LDB = NDIM;
    int LDC = NDIM;
    DTYPE alpha = 1.0;
    DTYPE beta = 0.0;

    // warm-up call before timing
    BLAS_GEMM(MATRIX_FORMAT, CblasNoTrans, CblasNoTrans, MDIM, NDIM, KDIM, alpha,
              A, LDA, B, LDB, beta, C, LDC);

    struct timespec time_start = {0, 0}, time_end = {0, 0};
    clock_gettime(CLOCK_REALTIME, &time_start);
    for (int i = 0; i < 500; i++) {
        BLAS_GEMM(MATRIX_FORMAT, CblasNoTrans, CblasNoTrans, MDIM, NDIM, KDIM, alpha,
                  A, LDA, B, LDB, beta, C, LDC);
    }
    clock_gettime(CLOCK_REALTIME, &time_end);
    std::cout << "BLIS duration: " << ((time_end.tv_sec-time_start.tv_sec) * 1000000000.0 + time_end.tv_nsec-time_start.tv_nsec)/500.0/1000000000.0 << std::endl;

    DTYPE *C2 = (DTYPE *) malloc(MDIM * NDIM * sizeof(DTYPE));
    size_t errors = 0;
    naive_matmul(A, B, C2, MDIM, KDIM, NDIM);
    for (size_t i = 0; i < MDIM; i++) {
        for (size_t j = 0; j < NDIM; j++) {
            size_t ci = i*NDIM + j;
            if (std::abs(C[ci] - C2[ci]) > 0.01f) {
                fprintf(stderr, "Incorrect result at index %zu,%zu: C=%0.2f C2=%0.2f\n", i, j, C[ci], C2[ci]);
                errors++;
            }
        }
    }
    printf("Detected %zu errors.\n", errors);

    free(A);
    free(B);
    free(C);
    free(C2);
}

int main(int argc, char **argv) {
    BenchmarkFunction();
    return 0;
}
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
#!/bin/bash

clang++ blis.cpp -DFLOAT32 -fopenmp -O2 -o blis_float32.x -I/home/neal/local/include/blis -L/home/neal/local/lib -lblis
clang++ blis.cpp -DFLOAT64 -fopenmp -O2 -o blis_float64.x -I/home/neal/local/include/blis -L/home/neal/local/lib -lblis
--------------------------------------------------------------------------------
/clean.sh:
--------------------------------------------------------------------------------
#!/bin/bash

rm -f *.o *.x *.tmp
--------------------------------------------------------------------------------
/gemm_perf.py:
--------------------------------------------------------------------------------
from matplotlib import pyplot as plt
import os
import sys
import subprocess
import numpy as np

os.environ['TVM_LOG_DEBUG'] = '0'
plt.style.use('ggplot')

def run_benchmark(path, str_list):
    result = [0.0] * len(str_list)
    try:
        proc = subprocess.Popen(path, shell=True, stdout=subprocess.PIPE)
        while True:
            line = proc.stdout.readline()
            if not line:
                break
            for i, s in enumerate(str_list):
                if line.startswith(bytes(s, encoding="utf8")):
                    result[i] = float(line.decode().split(':')[-1])
    except Exception:
        return [0.0] * len(str_list)
    # print(path, str_list, result)
    return result

def set_num_threads(num):
    os.environ['TVM_NUM_THREADS'] = str(num)

    os.environ['BLIS_NUM_THREADS'] = str(num)
    # os.environ['OMP_PLACES'] = "cores"
    # os.environ['OMP_PROC_BIND'] = "close"

    os.environ['MKL_NUM_THREADS'] = str(num)
    os.environ['MKL_DYNAMIC'] = "FALSE"
    os.environ['KMP_AFFINITY'] = "granularity=fine,compact,1,0"

USAGE = """
python gemm_perf.py float32
or
python gemm_perf.py float64
"""

def main(argv):
    if (len(argv) != 2) or (argv[1] != 'float32' and argv[1] != 'float64'):
        print(USAGE)
        sys.exit(255)

    dtype = argv[1]

    threads = []
    impls = ['numpy', 'tvm_without_tune', 'tvm_autotvm', 'tvm_autoscheduler', 'blis']
    speeds = {'numpy': [], 'tvm_without_tune': [], 'tvm_autotvm': [], 'tvm_autoscheduler': [], 'blis': []}
    # command to run and the prefix of the stdout line that carries its timing
    exec = {'numpy': (f'python tvm_without_tune.py {dtype}', 'Numpy'),
            'tvm_without_tune': (f'python tvm_without_tune.py {dtype}', 'TVM'),
            'tvm_autotvm': (f'python tvm_autotvm_tune.py {dtype}', 'TVM'),
            'tvm_autoscheduler': (f'python tvm_autoscheduler_tune.py {dtype}', 'TVM'),
            'blis': (f'./blis_{dtype}.x', 'BLIS')
            }

    for num_threads in [1, 4, 8, 12, 16]:
        threads.append(num_threads)
        set_num_threads(num_threads)
        for impl in impls:
            # GFLOPS: 2 * M * N * K floating-point operations per GEMM, divided by seconds
            speeds[impl].append(
                2.0 * 1024 * 1024 * 1024 / run_benchmark(exec[impl][0], [exec[impl][1]])[0] / 1000000000.0
            )

    with open(f'results_{dtype}.txt', 'w') as rf:
        rf.write("threads:")
        rf.write(str(threads))
        rf.write('\n')
        rf.write("speeds:")
        rf.write(str(speeds))

    x = np.arange(len(threads))
    width = 0.1
    plt.bar(x - 2.0*width, speeds['numpy'], width, label='numpy')
    plt.bar(x - width, speeds['tvm_without_tune'], width, label='tvm_without_tune')
    plt.bar(x, speeds['tvm_autotvm'], width, label='tvm_autotvm')
    plt.bar(x + width, speeds['tvm_autoscheduler'], width, label='tvm_autoscheduler')
    plt.bar(x + 2.0*width, speeds['blis'], width, label='blis')
    plt.ylabel('GFlops')
    plt.xlabel('Number of Threads')
    plt.title(f'1024x1024x1024 {dtype} gemm perf test on Numpy(MKL), TVM and BLIS')
    plt.xticks(x, labels=threads)
    plt.legend()

    plt.savefig(f'results_{dtype}.png', dpi=400, bbox_inches='tight')
    plt.show()

if __name__ == '__main__':
    main(sys.argv)
--------------------------------------------------------------------------------
/references.md:
--------------------------------------------------------------------------------
# References:

1. Anatomy of High-Performance Matrix Multiplication https://www.cs.utexas.edu/users/flame/pubs/GotoTOMS_final.pdf

2. BLISlab tutorial https://github.com/flame/blislab/blob/master/tutorial.pdf

3. How to optimize GEMM on CPU https://tvm.apache.org/docs/how_to/optimize_operators/opt_gemm.html

4. TVM AutoTVM tutorial https://tvm.apache.org/docs/tutorial/autotvm_matmul_x86.html

5. TVM Auto-scheduling tutorial https://tvm.apache.org/docs/tutorial/auto_scheduler_matmul_x86.html

6. How to optimize gemm https://github.com/flame/how-to-optimize-gemm

7. Matrix multiplication and SIMD

8. Single-core matrix multiplication benchmarks https://github.com/mmperf/mmperf

9. Notes on floating-point peak performance https://zhuanlan.zhihu.com/p/28226956

--------------------------------------------------------------------------------
/results/results_float32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/billmuch/matmul_perf_test/81e6a7ce6c8237f48c5002c831029ce341d258c8/results/results_float32.png
--------------------------------------------------------------------------------
/results/results_float64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/billmuch/matmul_perf_test/81e6a7ce6c8237f48c5002c831029ce341d258c8/results/results_float64.png
--------------------------------------------------------------------------------
/tvm_autoscheduler_tune.py:
--------------------------------------------------------------------------------
import tvm
import tvm.testing
from tvm import te, auto_scheduler
import numpy
import timeit
import os

# The size of the matrix
# (M, K) x (K, N)
M = 1024
K = 1024
N = 1024

target = "llvm -mcpu=core-avx2"
dev = tvm.device(target, 0)

EVAL_REPEAT_TIME = 500

# Matrix multiplication defined as in reference [4]:
# compute C(M, N) = A(M, K) x B(K, N)
@auto_scheduler.register_workload
def matmul(M, N, K, dtype):
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)

    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="C",
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )

    return [A, B, C]

def autotune(M, N, K, dtype, target_name, log_file):
    print(target_name)
    target = tvm.target.Target(target_name)
    task = tvm.auto_scheduler.SearchTask(func=matmul, args=(M, N, K, dtype), target=target)

    # Inspect the computational graph
    print("Computational DAG:")
    print(task.compute_dag)

    tune_option = None
    measure_ctx = None
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=6000,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )

    task.tune(tune_option)
    sch, args = task.apply_best(log_file)

# Check that the matmul result is correct and return the compiled function
def get_matmul_func(M, N, K, dtype, target_name, log_file):
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    answer = numpy.dot(a.numpy(), b.numpy())

    target = tvm.target.Target(target_name)
    task = tvm.auto_scheduler.SearchTask(func=matmul, args=(M, N, K, dtype), target=target)
    sch, args = task.apply_best(log_file)
    func = tvm.build(sch, args, target=target, name="matmul")
    assert func

    # print(tvm.lower(sch, args, simple_mode=True))
    # print(func.get_source("asm"))
    # func.export_library("tvm_autoscheduler.so")
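    # Run the compiled kernel once and compare it against numpy.dot before
    # returning it, so only a verified function is handed to the benchmark.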

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    return func

def benchmark(matmul_func, dtype):
    # Random generated tensor for testing
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    np_repeat = EVAL_REPEAT_TIME
    np_running_time = timeit.timeit(
        setup="import numpy\n"
        "M = " + str(M) + "\n"
        "K = " + str(K) + "\n"
        "N = " + str(N) + "\n"
        'dtype = "' + str(dtype) + '"\n'
        "a = numpy.random.rand(M, K).astype(dtype)\n"
        "b = numpy.random.rand(K, N).astype(dtype)\n",
        stmt="answer = numpy.dot(a, b)",
        number=np_repeat,
    )
    print("Numpy running time: %f" % (np_running_time / np_repeat))

    answer = numpy.dot(a.numpy(), b.numpy())

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    matmul_func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    evaluator = matmul_func.time_evaluator(matmul_func.entry_name, dev, number=EVAL_REPEAT_TIME)
    print("TVM autoscheduler tuned: %f" % evaluator(a, b, c).mean)

def main(argv):
    if (len(argv) > 1 and argv[1] == 'float32'):
        dtype = "float32"
        log_file = "matmul_autoscheduler_32.json"
    else:
        dtype = "float64"
        log_file = "matmul_autoscheduler_64.json"

    if (len(argv) == 3 and argv[2] == 'tune'):
        autotune(M, N, K, dtype, target, log_file)

    func = get_matmul_func(M, N, K, dtype, target, log_file)
    benchmark(func, dtype)

import sys
if __name__ == '__main__':
    main(sys.argv)
--------------------------------------------------------------------------------
/tvm_autotvm_tune.py:
--------------------------------------------------------------------------------
import tvm
import tvm.testing
from tvm import te
import numpy
import timeit

from tvm import te, autotvm, auto_scheduler
import os
import sys
import logging

# The size of the matrix
# (M, K) x (K, N)
M = 1024
K = 1024
N = 1024

target = "llvm -mcpu=core-avx2"
dev = tvm.device(target, 0)

EVAL_REPEAT_TIME = 500

# Matrix multiplication and its tuning space, following references [2][3]:
# compute C(M, N) = A(M, K) x B(K, N)
@autotvm.template("tutorial/matmul")
def matmul(M, N, K, dtype):
    # Algorithm
    k = te.reduce_axis((0, K), "k")
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)

    cfg = autotvm.get_config()
    cfg.define_split("tile_x", M, num_outputs=2)
    cfg.define_split("tile_y", N, num_outputs=2)
    cfg.define_split("tile_k", K, num_outputs=2)

    bn = cfg["tile_y"].size[-1]
    packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
    C = te.compute((M, N),
                   lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
                   name='C')
    s = te.create_schedule(C.op)

    CC = s.cache_write(C, "global")

    mo, mi = cfg["tile_x"].apply(s, C, C.op.axis[0])
    no, ni = cfg["tile_y"].apply(s, C, C.op.axis[1])
    s[C].reorder(mo, no, mi, ni)

    s[CC].compute_at(s[C], no)

    mc, nc = s[CC].op.axis
    (kaxis,) = s[CC].op.reduce_axis
    ko, ki = cfg["tile_k"].apply(s, CC, kaxis)

    cfg.define_reorder("reorder", [mc, ki, nc], "all")
    cfg["reorder"].apply(s, CC, [mc, ki, nc])
    cfg.define_annotate('ann', [mc, ki, nc], policy='try_unroll_vec')
    cfg['ann'].apply(s, CC, [mc, ki, nc])

    # parallel
    s[C].parallel(mo)
    s[C].unroll(mi)
    s[C].vectorize(ni)

    bigN, _, littleN = s[packedB].op.axis
    s[packedB].vectorize(littleN)
    s[packedB].parallel(bigN)

    return s, [A, B, C]

def tune_matmul(dtype, log_file):
    if (dtype == 'float32'):
        log_tmp_file = "matmul_autotvm_32.log.tmp"
    else:
        log_tmp_file = "matmul_autotvm_64.log.tmp"

    task = autotvm.task.create("tutorial/matmul", args=(M, N, K, dtype), target=target)
    print(task.config_space)

    # logging config (for printing tuning log to the screen)
    logging.getLogger("autotvm").setLevel(logging.DEBUG)
    logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))


    measure_option = autotvm.measure_option(builder="local", runner=autotvm.LocalRunner(number=5))

    # Begin tuning with XGBTuner and log records to the temporary log file.
    # You can use alternatives such as GridSearchTuner.
    # tuner = autotvm.tuner.GridSearchTuner(task)
    tuner = autotvm.tuner.XGBTuner(task)
    n_trial = 6000
    early_stopping = 800
    if os.path.exists(log_tmp_file):
        os.remove(log_tmp_file)
    tuner.tune(n_trial=n_trial,
               early_stopping=early_stopping,
               measure_option=measure_option,
               callbacks=[autotvm.callback.progress_bar(n_trial),
                          autotvm.callback.log_to_file(log_tmp_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(log_tmp_file, log_file)


# Check that the matmul result is correct and return the compiled function
def get_matmul_func(dtype, log_file):
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    answer = numpy.dot(a.numpy(), b.numpy())

    with autotvm.apply_history_best(log_file):
        with tvm.target.Target(target):
            s, arg_bufs = matmul(M, N, K, dtype)
            func = tvm.build(s, arg_bufs, target=target, name="matmul")
    assert func

    # print(tvm.lower(s, arg_bufs, simple_mode=True))
    # print(func.get_source("asm"))

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    return func

def benchmark(matmul_func, dtype):
    # Random generated tensor for testing
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    np_repeat = EVAL_REPEAT_TIME
    np_running_time = timeit.timeit(
        setup="import numpy\n"
        "M = " + str(M) + "\n"
        "K = " + str(K) + "\n"
        "N = " + str(N) + "\n"
        'dtype = "' + str(dtype) + '"\n'
        "a = numpy.random.rand(M, K).astype(dtype)\n"
        "b = numpy.random.rand(K, N).astype(dtype)\n",
        stmt="answer = numpy.dot(a, b)",
        number=np_repeat,
    )
    print("Numpy running time: %f" % (np_running_time / np_repeat))

    answer = numpy.dot(a.numpy(), b.numpy())

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    matmul_func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    evaluator = matmul_func.time_evaluator(matmul_func.entry_name, dev, number=EVAL_REPEAT_TIME)
    print("TVM autotvm tuned: %f" % evaluator(a, b, c).mean)

def main(argv):
    if (len(argv) > 1 and argv[1] == 'float32'):
        dtype = "float32"
        log_file = "matmul_autotvm_32.log"
    else:
        dtype = "float64"
        log_file = "matmul_autotvm_64.log"

    if (len(argv) == 3 and argv[2] == 'tune'):
        tune_matmul(dtype, log_file)

    func = get_matmul_func(dtype, log_file)
    benchmark(func, dtype)

import sys
if __name__ == '__main__':
    main(sys.argv)
--------------------------------------------------------------------------------
/tvm_without_tune.py:
--------------------------------------------------------------------------------
import tvm
import tvm.testing
from tvm import te
import numpy
import timeit

# The size of the matrix
# (M, K) x (K, N)
M = 1024
K = 1024
N = 1024

target = "llvm -mcpu=core-avx2"
dev = tvm.device(target, 0)

# Matrix multiplication using the schedule from reference [1], opt1 through opt6:
# compute C(M, N) = A(M, K) x B(K, N)
def matmul(M, N, K, dtype):
    # Algorithm
    k = te.reduce_axis((0, K), "k")
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)

    bn = 32
    kfactor = 4

    packedB = te.compute(
        (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
    )
    C = te.compute(
        (M, N),
        lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
        name="C",
    )
    s = te.create_schedule(C.op)
    CC = s.cache_write(C, "global")

    mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
    s[CC].compute_at(s[C], no)

    mc, nc = s[CC].op.axis
    (kaxis,) = s[CC].op.reduce_axis
    ko, ki = s[CC].split(kaxis, factor=kfactor)
    s[CC].reorder(ko, mc, ki, nc)
    s[CC].vectorize(nc)
    s[CC].unroll(ki)

    # parallel
    s[C].parallel(mo)

    bigN, _, littleN = s[packedB].op.axis
    s[packedB].vectorize(littleN)
    s[packedB].parallel(bigN)

    return s, [A, B, C]

# Check that the matmul result is correct and return the compiled function
def get_matmul_func(sch, args, dtype):
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    answer = numpy.dot(a.numpy(), b.numpy())

    opt_level = 3
    with tvm.transform.PassContext(opt_level=opt_level):
        func = tvm.build(sch, args, target=target, name="mmult")
    assert func

    # print(tvm.lower(sch, args, simple_mode=True))
    # print(func.get_source("asm"))

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    return func

def benchmark(matmul_func, dtype):
    # Random generated tensor for testing
    a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

    np_repeat = 500
    np_running_time = timeit.timeit(
        setup="import numpy\n"
        "M = " + str(M) + "\n"
        "K = " + str(K) + "\n"
        "N = " + str(N) + "\n"
        'dtype = "' + str(dtype) + '"\n'
        "a = numpy.random.rand(M, K).astype(dtype)\n"
        "b = numpy.random.rand(K, N).astype(dtype)\n",
        stmt="answer = numpy.dot(a, b)",
        number=np_repeat,
    )
    print("Numpy running time: %f" % (np_running_time / np_repeat))

    answer = numpy.dot(a.numpy(), b.numpy())

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    matmul_func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

    evaluator = matmul_func.time_evaluator(matmul_func.entry_name, dev, number=500)
    print("TVM without tune: %f" % evaluator(a, b, c).mean)

def main(argv):
    if (len(argv) > 1 and argv[1] == 'float32'):
        dtype = "float32"
    else:
        dtype = "float64"
    sch, args = matmul(M, N, K, dtype)
    func = get_matmul_func(sch, args, dtype)
    benchmark(func, dtype)
    # print(tvm.lower(sch, args, simple_mode=True))

import sys
if __name__ == '__main__':
    main(sys.argv)
--------------------------------------------------------------------------------