├── test_equivalence.py ├── utils.py ├── autotune_cpu_random_int8.cpu ├── autotune_cpu_random.sh ├── generate_test_matrix.py ├── README.md ├── code_fragments.py ├── fastsparse.py ├── driver_cpu.cpp ├── LICENSE ├── code_gen_cpu_int8_block.py ├── code_gen_cpu_int8.py ├── code_gen_cpu.py └── code_gen_cpu_conv.py /test_equivalence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | a = np.load(sys.argv[1]) 4 | b = np.load(sys.argv[2]) 5 | print("Difference: ",np.sum(np.abs(a.squeeze()-b.squeeze()))) 6 | print(np.where(np.abs(a-b) > 0.01)) 7 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import numpy as np 3 | EPS = 0.0000001 4 | ST = 1 5 | 6 | def dec_to_2s(number, width): 7 | if number > 0: 8 | result = bin(number) 9 | pad_l = width - (len(result) - 2) 10 | if pad_l < 0: 11 | return "Error not enough bits" 12 | else: 13 | return "0" * pad_l + result[2:] 14 | elif number == 0: 15 | return "0" * width 16 | else: 17 | return dec_to_2s(number + 2 ** width, width) 18 | 19 | def float_to_hex(number): 20 | s = struct.pack('>f',number) 21 | bits = struct.unpack('>l',s)[0] 22 | return hex(int(dec_to_2s(bits,32),2)).upper().replace("X","f") 23 | 24 | def half_to_hex(number): 25 | s = hex(np.float16(number).view('H'))[2:].zfill(4) 26 | return "0x" + s + s 27 | 28 | def hex_to_bin(hex_string): 29 | scale = 16 ## equals to hexadecimal 30 | num_of_bits = 16 31 | return bin(int(hex_string, scale))[2:].zfill(num_of_bits) 32 | 33 | def bin_to_half(bin_string): 34 | sign = int(bin_string[0]) 35 | exponent = int(bin_string[1:6],2) 36 | mantissa = int(bin_string[6:],2) 37 | factor = 10 ** np.ceil(np.log10(mantissa)) 38 | return (-1) ** sign * 2 ** (exponent - 15) * (1 + mantissa /factor) 39 | -------------------------------------------------------------------------------- /autotune_cpu_random_int8.cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | A_dim=$1 4 | B_dim=$2 5 | C_dim=$3 6 | infile=matrix_transposed.npy 7 | BLOCK=1 8 | python generate_test_matrix.py $A_dim $B_dim $C_dim 1 $BLOCK 9 | biasfile=bias.npy 10 | C_blocks=1 11 | Gy=1 12 | besttime=100000 13 | for AT in $BLOCK; do 14 | for B_blocks in 1; do 15 | for CT in 4; do 16 | 17 | 18 | python code_gen_cpu_int8_block.py --A_dim $A_dim --B_dim $B_dim --C_dim $C_dim --AT $AT --CT $CT --B_blocks $B_blocks --C_blocks $C_blocks --infile $infile --outfile testing.cpp --outfile_asm test1.s --x86 --infile_bias $biasfile 19 | icc -march=icelake -fPIC -shared -g test1.s -o test.so 20 | icc -I . 
-O3 -march=native -D AT=$AT -D CT=$1 -D C_Blocks=$C_blocks -DA_dim=$A_dim -DINFILE=$infile -D B_dim=$B_dim -D C_dim=$C_dim -D C_blocks=$C_blocks -D X86=1 -D MULTI=0 -D INT8=1 driver_cpu.cpp -lcnpy -o test -std=c++17 21 | /home/ziheng/Downloads/sde-external-8.56.0-2020-07-05-lin/sde64 -- ./test #> runtime 22 | python test_equivalence.py cpu_output.npy ref.npy 23 | #runtime=$(grep "millisecond" runtime | awk '{print $3}') 24 | #cat runtime 25 | #if (( $(echo "$runtime < $besttime" | bc -l) )) ; then 26 | # besttime=$runtime 27 | # bestset=${AT}_${B_blocks}_${CT} 28 | #fi 29 | 30 | done 31 | done 32 | 33 | done 34 | echo Best Runtime $besttime 35 | echo $bestset 36 | -------------------------------------------------------------------------------- /autotune_cpu_random.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | A_dim=$1 4 | B_dim=$2 5 | C_dim=$3 6 | mode=$4 7 | infile=matrix_transposed.npy 8 | python generate_test_matrix.py $A_dim $B_dim $C_dim $mode 9 | 10 | C_blocks=4 11 | Gy=1 12 | besttime=100000 13 | for AT in 1 2 4; do 14 | for B_blocks in 4 8 16; do 15 | for CT in 1 2 4; do 16 | 17 | if [ $(($AT * $CT)) -gt 12 ]; then 18 | continue 19 | fi 20 | python code_gen_cpu.py --A_dim $A_dim --B_dim $B_dim --C_dim $C_dim --AT $AT --CT $CT --B_blocks $B_blocks --C_blocks $C_blocks --Gy $Gy --infile $infile --outfile testing.cpp --outfile_asm test1.s --x86 --no_relu --infile_bias bias.npy --fuse 21 | #icc -fopenmp -shared -fPIC -O3 -march=native testing.cpp -o test1.s -S 22 | gcc -shared -g test1.s -o test.so 23 | icc -fopenmp -I . -mkl -O3 -march=native -D AT=$AT -D CT=$1 -D C_Blocks=$C_blocks -DA_dim=$A_dim -DINFILE=$infile -D B_dim=$B_dim -D C_dim=$C_dim -D C_blocks=$C_blocks -D X86=1 -D MULTI=0 driver_cpu.cpp -lcnpy -o test -std=c++17 24 | ./test > runtime 25 | python test_equivalence.py cpu_output.npy ref.npy 26 | runtime=$(grep "millisecond" runtime | awk '{print $3}') 27 | cat runtime 28 | if (( $(echo "$runtime < $besttime" | bc -l) )) ; then 29 | besttime=$runtime 30 | bestset=${AT}_${B_blocks}_${CT} 31 | fi 32 | 33 | done 34 | done 35 | 36 | done 37 | echo Best Runtime $besttime 38 | echo $bestset 39 | -------------------------------------------------------------------------------- /generate_test_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | M = int(sys.argv[1]) 5 | K = int(sys.argv[2]) 6 | N = int(sys.argv[3]) 7 | int8 = True if int(sys.argv[4]) == 1 else False 8 | if int8: 9 | #scale = np.random.normal(size=(M)).astype(np.float32) 10 | scale = np.ones((M)).astype(np.float32) 11 | else: 12 | scale = 1 13 | bias = np.random.normal(size=(M)) 14 | #bias = np.zeros((M)) 15 | AB = 1 + np.abs(np.random.normal(size=(K,M)).astype(np.float32) * 3) 16 | 17 | BLOCK = int(sys.argv[5]) 18 | locs = [i for i in range(M* K) if i %BLOCK == 0] 19 | zero_locs = np.random.choice(M*K//BLOCK, M * K // BLOCK // 10 * 9,replace=False) * BLOCK 20 | 21 | for i in range(BLOCK): 22 | indices0 = np.unravel_index(zero_locs + i,(K,M)) 23 | AB[indices0] = 0 24 | # 25 | AB = AB.transpose().copy() 26 | 27 | #mask = (AB > 0) * 3 28 | #AB = AB - mask 29 | #print(AB) 30 | #AB = AB * (AB > 2.7) 31 | #BC = np.random.normal(size=(K,N)).astype(np.float32) 32 | BC = np.ones((K,N)).astype(np.float32) 33 | for i in range(K): 34 | BC[i] = np.random.randint(2) 35 | 36 | if int8: 37 | AB = AB.astype(np.int8) 38 | BC = BC.astype(np.uint8) 39 | bias = bias.astype(np.int32) 40 | 41 | 
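# Note on the int8 path (annotation grounded in the generated kernels): BC, the dense
# activations, is cast to uint8 while the sparse weights AB stay signed int8 with an int32
# bias. This pairing matches the AVX512-VNNI vpdpbusd instruction emitted by the int8 code
# generators, which multiplies unsigned bytes by signed bytes and accumulates into 32-bit lanes.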
AB = -AB 42 | 43 | 44 | print("density",np.count_nonzero(AB) / M / K) 45 | AC = np.dot(AB,BC) + np.expand_dims(bias,1) 46 | AC = AC.astype(np.float32) * np.expand_dims(scale,1) 47 | 48 | if int8: 49 | #AC = AC.astype(np.int32) 50 | AC = AC.astype(np.int8) 51 | if int8: 52 | np.save("bias.npy",bias) 53 | else: 54 | np.save("bias.npy",bias.astype(np.float32)) 55 | np.save("matrix.npy",AB) 56 | if int8: 57 | np.save("scale.npy",scale) 58 | np.save("matrix_transposed.npy",AB.transpose()) 59 | np.save("BC.npy",BC) 60 | np.save("ref.npy",AC) 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPARSEDNN 2 | 3 | **If you want to use this repo, please send me an email: zihengw@stanford.edu, or raise a GitHub issue.** 4 | 5 | 6 | **Sparse INT8 kernels are here.** 7 | 8 | 9 | Fast sparse deep learning on CPUs. This is the kernel library generator described in the paper: https://arxiv.org/abs/2101.07948 10 | My other repo on sparse deep learning on GPUs: https://github.com/marsupialtail/gpu-sparsert. Will merge at some point when I'm feeling less lazy. 11 | 12 | Python API: python fastsparse.py. Minimal required dependencies. Should work anywhere. 13 | 14 | C++ API: check out driver_cpu.cpp, or run autotune_cpu_random.sh 128 128 128 0. This requires cnpy to read numpy files, so make sure that you can link to cnpy. 15 | C++ API only: for block sparse int8 matrix multiply, run autotune_cpu_random_int8.cpu 512 512 128. 16 | 17 | The Python API has some overhead due to using ctypes; it is noticeable for smaller matrices but not really for large matrices. The benchmarking in the Arxiv paper was all done with the C++ API. 18 | 19 | **Work that is not yet open sourced: a kernel generator for sparse convolutions (as described in the Arxiv paper) using implicit convolution, and a lightweight inference engine to get end-to-end results. If you are interested in any of this, please email.** 20 | 21 | FAQs: 22 | 1) How does this compare to Neuralmagic? Last time I checked, the deepsparse library does not allow you to run kernel-level benchmarks. If you care about end-to-end neural network acceleration, you should definitely go with Neuralmagic if they happen to support your model. 23 | 2) Future work? This is not exactly along the lines of my PhD thesis, so I work on this sparingly. If you want to contribute to this repo, you could make a PyTorch or TensorFlow custom op with the Python or C++ API. However, it's unclear how gradients would work, and you would have to compile the op with a fixed sparsity pattern, something the current PyTorch/TensorFlow frameworks might not support that well.
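A minimal sketch of the Python API, based on the demo at the bottom of fastsparse.py (note that fastsparse.py currently runs that demo at import time, so guard it or copy the `SpMM` class out before importing it as a module):

```python
import numpy as np
from fastsparse import SpMM  # assumes the module-level demo in fastsparse.py has been guarded

matrix = np.load("matrix.npy").astype(np.float32)  # dense (A_dim, B_dim) array with explicit zeros
spmm = SpMM(matrix, C_dim=128)                     # C_dim = columns of the dense right-hand side
spmm.compile()                                     # generate, assemble and load the kernel
BC = np.random.normal(size=(matrix.shape[1], 128)).astype(np.float32)
AC = spmm.run(BC)                                  # sparse x dense multiply via the generated kernel
assert np.abs((AC - spmm.ref_run(BC)).sum()) < 0.1 # same sanity check the repo demo uses
```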
24 | -------------------------------------------------------------------------------- /code_fragments.py: -------------------------------------------------------------------------------- 1 | START_NONFUSED=""" 2 | #include 3 | #include "mkl.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | // we are doing AC = AB * BC, reduce across the B dimension 16 | // binding B to the x dimension, A to the y dimension and C to the z dimension 17 | 18 | #define Tsy 1 19 | #define Tsz (C_dim / C_BLOCKS) 20 | #define ST 1 21 | #define Fx 1 22 | #define Fy (Tsz/Fx) 23 | 24 | //#define 64 (64 / 1 / Tsy) 25 | 26 | #define Usy (Tsy * Fy) 27 | #define Gsy Usy 28 | 29 | #define Gy 1 30 | #define Block_size (Gy * Gsy) 31 | #define X86 X86_DEF 32 | #define ARM ARM_DEF 33 | #include 34 | #include 35 | 36 | 37 | struct thread_data { 38 | const float * __restrict__ AB_val; 39 | const float * __restrict__ AB_bias; 40 | const float * __restrict__ BC; 41 | float * AC; 42 | int start; 43 | int end; 44 | }; 45 | 46 | void * mm(void * threadarg) 47 | { 48 | struct thread_data *my_data = (struct thread_data * ) threadarg; 49 | const float * __restrict__ AB_val = my_data->AB_val; 50 | const float * __restrict__ AB_bias = my_data->AB_bias; 51 | const float * __restrict__ BC = my_data->BC; 52 | float * AC = my_data->AC; 53 | int start = my_data->start; 54 | int end = my_data->end; 55 | 56 | #if X86 57 | __m256 ACC[Ny]; 58 | __m256 RC, val; 59 | #elif ARM 60 | float32x4_t ACC[Ny]; 61 | float32x4_t RC, val; 62 | #endif 63 | __m256 zero = _mm256_setzero_ps(); 64 | // #pragma omp parallel for schedule(static) private(ACC,RC,val,zero) 65 | 66 | for(int C_block = start; C_block < end; C_block ++){ 67 | 68 | int C_offset = C_block * (C_dim / C_BLOCKS); 69 | 70 | 71 | """ 72 | 73 | 74 | 75 | BLOCK_CONTROL_START= """ 76 | #if X86 77 | for(int j=0; j < Ny; j++) 78 | { 79 | ACC[j] = _mm256_setzero_ps(); 80 | } 81 | 82 | #pragma vector aligned 83 | for(int lane =0; lane < Tsz; lane += 8){ 84 | #elif ARM 85 | for(int j=0; j < Ny; j++) 86 | { 87 | ACC[j] = vdupq_n_f32(0.0f); 88 | } 89 | 90 | for(int lane =0; lane < Tsz; lane += 4){ 91 | #endif 92 | """ 93 | 94 | BLOCK_END_REDUCTION=""" 95 | #if X86 96 | 97 | _mm256_store_ps(&AC[OFFSET + C_offset + lane],_mm256_max_ps(zero,_mm256_add_ps(ACC[IDX] , _mm256_broadcast_ss(AB_bias + BIAS)))); 98 | ACC[IDX] = _mm256_setzero_ps(); 99 | 100 | 101 | #elif ARM 102 | 103 | vst1q_f32(&AC[OFFSET + C_offset + lane], ACC[IDX]); 104 | ACC[IDX] = vdupq_n_f32(0.0f); 105 | 106 | #endif 107 | """ 108 | 109 | BLOCK_END = """ 110 | #if X86 111 | #pragma vector aligned 112 | for(int i =0; i < Ny; i++) 113 | { 114 | _mm256_store_ps(&AC[(A_offset + i) * C_dim + C_offset + lane],ACC[i]); 115 | ACC[i] = _mm256_setzero_ps(); 116 | } 117 | } 118 | #elif ARM 119 | for(int i =0; i < Ny; i++) 120 | { 121 | vst1q_f32(&AC[(A_offset + i) * C_dim + C_offset + lane], ACC[i]); 122 | ACC[i] = vdupq_n_f32(0.0f); 123 | } 124 | } 125 | #endif 126 | 127 | 128 | """ 129 | 130 | END_NONFUSED = """ 131 | 132 | } 133 | //pthread_exit(NULL); 134 | } 135 | 136 | """ 137 | -------------------------------------------------------------------------------- /fastsparse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from ctypes import * 4 | import time 5 | 6 | 7 | class Input(Structure): 8 | _fields_ = [ 9 | ("AB_vals",c_void_p), 10 | ("AB_bias",c_void_p), 11 | 
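        # These fields mirror the float `thread_data` struct in driver_cpu.cpp: pointers to the
        # packed nonzero values, the bias, the dense input BC and the output AC, followed by the
        # start/end C-block indices that the generated _spmm kernel iterates over.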
("BC",c_void_p), 12 | ("AC",c_void_p), 13 | ("start",c_int32), 14 | ("end",c_int32) 15 | ] 16 | 17 | 18 | class SpMM: 19 | 20 | # takes in a numpy sparse matrix in the dense array format with 0s, and the C_dimension of the dense matrix 21 | def __init__(self, matrix, C_dim, bias=None): 22 | self.A_dim = matrix.shape[0] 23 | self.B_dim = matrix.shape[1] 24 | self.matrix = matrix 25 | self.C_dim = C_dim 26 | self.bias = bias 27 | if self.bias is not None: 28 | assert len(self.bias) == self.A_dim 29 | 30 | def compile(self,name = "spmm", val_name = "vals.npy", bias_name = "bias.npy", AT = 6, CT = 2, B_blocks = 1, C_blocks = 1, no_relu=True,epi="NONE"): 31 | 32 | import code_gen_cpu 33 | if not "avx2" in open("/proc/cpuinfo","r").read(): 34 | print("We need at least AVX2.") 35 | raise Exception 36 | if "avx512" in open("/proc/cpuinfo","r").read(): 37 | code_gen_cpu.AVX512 = True 38 | code_gen_cpu.VEC = 16 39 | else: 40 | code_gen_cpu.AVX512 = False 41 | code_gen_cpu.VEC = 8 42 | 43 | code_gen_cpu.FUNC_NAME = name 44 | code_gen_cpu.EPI = epi 45 | code_gen_cpu.IN_FORMAT = "NCHW" 46 | code_gen_cpu.OUT_FORMAT = "NCHW" 47 | code_gen_cpu.GY = 1 48 | code_gen_cpu.FUSE_END = False 49 | code_gen_cpu.NO_RELU = no_relu 50 | code_gen_cpu.A_dim = self.A_dim 51 | code_gen_cpu.B_dim = self.B_dim 52 | code_gen_cpu.C_dim = self.C_dim 53 | code_gen_cpu.AT = AT 54 | code_gen_cpu.CT = CT 55 | code_gen_cpu.B_blocks = B_blocks 56 | code_gen_cpu.C_blocks = C_blocks 57 | code_gen_cpu.outfile = "out.cpp" 58 | code_gen_cpu.outfile_asm = "out.s" 59 | code_gen_cpu.bias = self.bias 60 | assert self.C_dim % C_blocks == 0 61 | 62 | code_gen_cpu.TSZ = self.C_dim // C_blocks if self.C_dim % C_blocks == 0 else self.C_dim // C_blocks + 1 63 | code_gen_cpu.X86 = True 64 | code_gen_cpu.ARM = False 65 | NRS = False 66 | 67 | BA = self.matrix.transpose() 68 | #print(BA.shape) 69 | BA = BA.squeeze() 70 | 71 | code_gen_cpu.AB_vals = [] 72 | code_gen_cpu.A_idx = [] 73 | code_gen_cpu.B_idx = [] 74 | code_gen_cpu.AB_block_offs = [0] 75 | #global off 76 | code_gen_cpu.off = 0 77 | 78 | """ 79 | We are going to redo BA here to remove some empty rows 80 | """ 81 | 82 | nnz_cols = np.unique(np.where(BA)[1]) 83 | code_gen_cpu.mapping = {i : nnz_cols[i] for i in range(len(nnz_cols))} 84 | #print(mapping) 85 | BA = BA[:,nnz_cols] 86 | code_gen_cpu.A_dim = len(nnz_cols) 87 | 88 | if code_gen_cpu.A_dim % AT == 0: 89 | A_blocks = code_gen_cpu.A_dim // AT 90 | else: 91 | A_blocks = code_gen_cpu.A_dim // AT + 1 92 | 93 | code_gen_cpu.gencode(BA,self.C_dim,A_blocks,C_blocks,name="bump") 94 | 95 | self.AB_vals = np.array(code_gen_cpu.AB_vals) 96 | np.save(val_name,np.array(self.AB_vals)) 97 | if self.bias is not None: 98 | np.save(bias_name,np.array(self.bias)) 99 | else: 100 | self.bias = np.ones((self.A_dim)) 101 | #np.save(bias_name,np.array(self.bias)) 102 | os.system("gcc -c out.s") 103 | os.system("ar rvs " + name + ".a out.o >/dev/null 2>&1") 104 | os.system("gcc -shared out.s -o " + name + ".so ") 105 | os.system("rm out.o out.s out.cpp") 106 | self.libc = CDLL(name + ".so") 107 | 108 | def load(self,sl_name, vec_name, bias_name = None): 109 | self.libc = CDLL(sl_name) 110 | self.AB_vals = np.load(vec_name) 111 | assert self.AB_vals.dtype == np.float32 112 | if bias_name: 113 | self.bias = np.load(bias_name) 114 | else: 115 | # we will not be using the values in the kernel anyways 116 | self.bias = np.ones((self.A_dim)) 117 | assert len(self.bias) == self.A_dim 118 | 119 | 120 | 121 | def run(self,BC): 122 | self.AC = 
np.empty((self.A_dim,self.C_dim),dtype=np.float32) 123 | w = self.AC.ctypes.data 124 | z = BC.ctypes.data 125 | x = self.AB_vals.ctypes.data 126 | AB_bias = self.bias 127 | y = AB_bias.ctypes.data 128 | self.arg = pointer(Input(x,y,z,w,0,1)) 129 | self.libc._spmm(self.arg) 130 | return self.AC 131 | 132 | def ref_run(self,BC): 133 | return np.dot(self.matrix,BC).astype(np.float32) 134 | 135 | 136 | a = np.load("matrix.npy") 137 | b = SpMM(a,128) 138 | b.compile() 139 | test_input = np.random.normal(size=(128,128)).astype(np.float32) 140 | b.run(test_input) 141 | reference = b.ref_run(test_input) 142 | assert np.abs(np.sum(np.sum(b.AC-reference))) < 0.1 143 | # 144 | start = time.time() 145 | for i in range(1000): 146 | b.run(test_input) 147 | print((time.time()-start) * 1000) 148 | 149 | -------------------------------------------------------------------------------- /driver_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | // we are doing AC = AB * BC, reduce across the B dimension 14 | // binding B to the x dimension, A to the y dimension and C to the z dimension 15 | #include 16 | //#define 64 (64 / 1 / Tsy) 17 | #include 18 | #include 19 | using namespace std; 20 | 21 | const int COUNT = 4; 22 | const int WORK = 10'000'000; 23 | 24 | 25 | #define BOUND ((C_blocks - 1 ) / 4 + 1) 26 | 27 | #if INT8 28 | struct thread_data { 29 | const int8_t * __restrict__ AB_val; 30 | const int * __restrict__ AB_bias; 31 | const int8_t * __restrict__ BC; 32 | int8_t * AC; 33 | const float * scale; 34 | int start; 35 | int end; 36 | }; 37 | #else 38 | struct thread_data { 39 | const float * __restrict__ AB_val; 40 | const float * __restrict__ AB_bias; 41 | const float * __restrict__ BC; 42 | float * AC; 43 | int start; 44 | int end; 45 | }; 46 | #endif 47 | struct thread_data_reps { 48 | thread_data * arg; 49 | int reps; 50 | }; 51 | 52 | 53 | typedef uint16_t offset_t; 54 | #define PTR_OFFSET_SZ sizeof(offset_t) 55 | //taken from https://embeddedartistry.com/blog/2017/02/22/generating-aligned-memory/ 56 | #ifndef align_up 57 | #define align_up(num, align) \ 58 | (((num) + ((align) - 1)) & ~((align) - 1)) 59 | #endif 60 | void * aligned_malloc(size_t align, size_t size) 61 | { 62 | void * ptr = NULL; 63 | assert((align & (align - 1)) == 0); 64 | if(align && size) 65 | { 66 | uint32_t hdr_size = PTR_OFFSET_SZ + (align - 1); 67 | void * p = malloc(size + hdr_size); 68 | 69 | if(p) 70 | { 71 | ptr = (void *) align_up(((uintptr_t)p + PTR_OFFSET_SZ), align); 72 | 73 | *((offset_t *)ptr - 1) = 74 | (offset_t)((uintptr_t)ptr - (uintptr_t)p); 75 | 76 | } // else NULL, could not malloc 77 | } //else NULL, invalid arguments 78 | 79 | return ptr; 80 | } 81 | 82 | void aligned_free(void * ptr) 83 | { 84 | assert(ptr); 85 | 86 | offset_t offset = *((offset_t *)ptr - 1); 87 | 88 | void * p = (void *)((uint8_t *)ptr - offset); 89 | free(p); 90 | } 91 | 92 | #define THREADS 4 93 | 94 | #include 95 | 96 | void fillvector(float *data, int n) { 97 | for(int i=0; i(); 140 | assert(arr1.word_size == 1); 141 | #else 142 | float * BC_unaligned = arr1.data(); 143 | assert(arr1.word_size == sizeof(float)); 144 | #endif 145 | std::cout << B_dim << " " << C_dim << std::endl; 146 | assert(arr1.shape.size()==2 && arr1.shape[0] == B_dim && arr1.shape[1] == C_dim); 147 | 148 | cnpy::NpyArray arr2 = cnpy::npy_load("AB_vals.npy"); 149 | 
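    // AB_vals.npy holds the nonzero weight values packed by the code generator in the exact
    // order the generated kernel consumes them (see the AB_vals list built in code_gen_cpu*.py),
    // so the kernel can walk this array linearly while the corresponding column addresses are
    // inlined into the generated instructions.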
#if INT8 150 | int8_t * AB_vals = arr2.data(); 151 | assert(arr2.word_size == 1); 152 | #else 153 | float * AB_vals = arr2.data(); 154 | assert(arr2.word_size == sizeof(float)); 155 | #endif 156 | assert(arr2.shape.size() ==1); 157 | int nnzs = arr2.shape[0]; 158 | 159 | cnpy::NpyArray arr3 = cnpy::npy_load("bias.npy"); 160 | #if INT8 161 | int * AB_bias = arr3.data(); 162 | assert(arr3.word_size == 4); 163 | #else 164 | float * AB_bias = arr3.data(); 165 | assert(arr3.word_size == sizeof(float)); 166 | //assert(arr3.shape.size() ==1 && arr3.shape[0] == A_dim); 167 | #endif 168 | 169 | #if INT8 170 | cnpy::NpyArray arr4 = cnpy::npy_load("scale.npy"); 171 | float * scale = arr4.data(); 172 | #endif 173 | 174 | 175 | cnpy::NpyArray arr7 = cnpy::npy_load("ref.npy"); 176 | #if INT8 177 | int8_t * ref1_stack = arr7.data(); 178 | std::memcpy(ref1,ref1_stack,A_dim * C_dim); 179 | #else 180 | float * ref1_stack = arr7.data(); 181 | std::memcpy(ref1,ref1_stack,A_dim * C_dim * 4); 182 | #endif 183 | 184 | #if X86 185 | #if INT8 186 | int8_t * BCs = (int8_t*) aligned_alloc(128, B_dim * C_dim); 187 | std::memcpy(&BCs[0],BC_unaligned,B_dim * C_dim); 188 | #else 189 | float* BCs = (float*) aligned_alloc(128, B_dim * C_dim * 4); 190 | std::memcpy(&BCs[0],BC_unaligned,B_dim * C_dim * 4); 191 | 192 | #endif 193 | #elif ARM 194 | float * BC = (float*) aligned_malloc(128, B_dim * C_dim * 4); 195 | #endif 196 | 197 | 198 | #if X86 199 | 200 | #if INT8 201 | 202 | // the intermediate results are in int32 so they need more space 203 | int8_t * result; 204 | result = (int8_t *)aligned_alloc(128, A_dim * C_dim *4 ); 205 | memset(result,0,A_dim * C_dim *4 ); 206 | 207 | #else 208 | float *result; 209 | result = (float *)aligned_alloc(128,A_dim * C_dim *sizeof(result)); 210 | memset(result,0,A_dim * C_dim * sizeof(result)); 211 | 212 | // for(int i = 0; i < A_dim*C_dim; i ++) 213 | // { 214 | // result[i] = 1.0f; 215 | // } 216 | 217 | #endif 218 | 219 | #elif ARM 220 | result = (float *) aligned_malloc(128, A_dim * C_dim * sizeof(result)); 221 | memset(result,0,A_dim * C_dim * sizeof(result)); 222 | #endif 223 | 224 | 225 | // let's pre-write the bias to the result. this is acceptable. 
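    // (The prefill loop below is left commented out: when a bias file is supplied, the
    // generated kernels initialize their accumulators from the bias themselves, broadcasting
    // it from the AB_bias pointer before the first B-block.)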
226 | /* for(int i = 0; i < A_dim; i ++) 227 | { 228 | for(int j = 0; j < C_dim; j ++) 229 | { 230 | result[i * C_dim + j] = AB_bias[i]; 231 | } 232 | } 233 | */ 234 | void *handle; 235 | 236 | char *error_str; 237 | 238 | 239 | using std::chrono::high_resolution_clock; 240 | using std::chrono::duration_cast; 241 | using std::chrono::duration; 242 | using std::chrono::milliseconds; 243 | 244 | auto t1 = high_resolution_clock::now(); 245 | 246 | handle = dlopen ("./test.so", RTLD_LAZY); 247 | if (!handle) { 248 | fputs (dlerror(), stderr); 249 | exit(1); 250 | } 251 | 252 | mm =(void* (*)(void *)) dlsym(handle, "_spmm"); 253 | if ((error_str = dlerror()) != NULL) { 254 | fputs(error_str, stderr); 255 | exit(1); 256 | } 257 | auto t2 = high_resolution_clock::now(); 258 | duration ms_double = t2 - t1; 259 | printf (" == Load shared library == \n== at %.5f milliseconds == \n ", ms_double.count() ); 260 | 261 | //printf (" Load at %.5f milliseconds == \n\n", (s_elapsed * 1000)); 262 | 263 | 264 | struct thread_data td[THREADS]; 265 | for(int i = 0; i < THREADS; i ++) 266 | { 267 | td[i].AB_val = AB_vals; 268 | td[i].AB_bias= AB_bias; 269 | td[i].BC = &BCs[0]; 270 | td[i].AC = result; 271 | td[i].scale = scale; 272 | #if MULTI 273 | td[i].start = i * BOUND; 274 | td[i].end = min(i * BOUND + BOUND, C_blocks); 275 | #else 276 | td[i].start = 0;//i * BOUND ; 277 | td[i].end = C_blocks;//min(i * BOUND + BOUND, C_blocks); 278 | #endif 279 | 280 | } 281 | #if MULTI 282 | // warm up omp thread pool 283 | #pragma omp parallel for 284 | for(int i = 0; i < THREADS; i ++) 285 | { 286 | mm(&td[i]); 287 | } 288 | #endif 289 | auto issed = 0; 290 | void * status; 291 | /*t1 = high_resolution_clock::now(); 292 | #if MULTI 293 | while(issed < 10 ) { 294 | #pragma omp parallel for 295 | for(int i = 0; i < THREADS; i ++) 296 | { 297 | mm(&td[i]); 298 | } 299 | issed += 1; 300 | } 301 | #else 302 | while(issed < 10 ) { 303 | mm(&td[0]); 304 | issed += 1; 305 | } 306 | #endif 307 | 308 | t2 = high_resolution_clock::now(); 309 | ms_double = t2 - t1; 310 | //memset(result,0,A_dim * C_dim * sizeof(result)); 311 | */ 312 | std::cout << "using one rep for SDE. Please use higher count on acutal hardware" << std::endl; 313 | // int reps = 20000 / ms_double.count(); 314 | int reps = 1; 315 | 316 | std::cout << reps << std::endl; 317 | t1 = high_resolution_clock::now(); 318 | 319 | issed = 0; 320 | #if MULTI 321 | while(issed < reps ) { 322 | #pragma omp parallel for 323 | for(int i = 0; i < THREADS; i ++) 324 | { 325 | mm(&td[i]); 326 | } 327 | issed += 1; 328 | } 329 | #else 330 | while(issed < reps ) { 331 | mm(&td[0]); 332 | issed += 1; 333 | } 334 | #endif 335 | t2 = high_resolution_clock::now(); 336 | ms_double = t2 - t1; 337 | printf (" == spmm microkernel == \n== at %.5f milliseconds == \n == %d reps == ", (ms_double.count() / reps), reps); 338 | 339 | 340 | // dlclose(handle); 341 | #if INT8 342 | cnpy::npy_save("cpu_output.npy",(char*)(&result[0]),{A_dim, C_dim},"w"); 343 | #else 344 | cnpy::npy_save("cpu_output.npy",(float*)(&result[0]),{A_dim, C_dim},"w"); 345 | 346 | #endif 347 | 348 | std::cout << result[0] << result[1] << result[2] << std::endl; 349 | } 350 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /code_gen_cpu_int8_block.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | import numpy as np 3 | from code_fragments import * 4 | from utils import * 5 | 6 | import argparse 7 | parser = argparse.ArgumentParser(description='CodeGen V1') 8 | 9 | parser.add_argument('--A_dim', type=int, default=12) 10 | parser.add_argument('--B_dim', type=int, default=12) 11 | parser.add_argument('--C_dim', type=int, default=12) 12 | parser.add_argument('--AT', type=int, default=12) 13 | parser.add_argument('--C_blocks', type=int, default=12) 14 | parser.add_argument('--CT',type=int, default=1) 15 | parser.add_argument('--B_blocks',type=int,default=1) 16 | parser.add_argument('--infile', default=None, type=str) 17 | parser.add_argument('--infile_bias', default=None, type=str) 18 | parser.add_argument('--outfile_asm', default= None, type = str) 19 | parser.add_argument('--fuse',default=False,action='store_true') 20 | parser.add_argument('--x86',default=False,action='store_true') 21 | parser.add_argument('--arm',default=False,action='store_true') 22 | parser.add_argument('--relu',default=False,action='store_true') 23 | parser.add_argument('--no_row_skip',default=False,action='store_true') 24 | args = parser.parse_args() 25 | FUSE_END = args.fuse 26 | RELU = args.relu 27 | print(FUSE_END) 28 | A_dim = args.A_dim 29 | B_dim = args.B_dim 30 | C_dim = args.C_dim 31 | AT = args.AT 32 | C_blocks = args.C_blocks 33 | 34 | input_file = args.infile 35 | outfile_asm = args.outfile_asm 36 | #assert C_dim % C_blocks == 0 37 | TSZ = C_dim // C_blocks if C_dim % C_blocks == 0 else C_dim // C_blocks + 1 38 | 39 | X86 = args.x86 40 | ARM = args.arm 41 | NRS = args.no_row_skip 42 | CT = args.CT 43 | B_blocks = args.B_blocks 44 | BLOCK = AT 45 | 46 | assert not (X86 and ARM) 47 | if X86: 48 | print("Generating X86 vector intrinsics") 49 | elif ARM: 50 | print("Generating Arm intrinsics") 51 | else: 52 | assert False 53 | 54 | VEC=16 55 | 56 | IN_FORMAT = "NCHW" 57 | OUT_FORMAT = "NCHW" 58 | 59 | input_file_bias = args.infile_bias 60 | if input_file_bias: 61 | bias = np.load(input_file_bias) 62 | 63 | #global AB_vals 64 | AB_vals = [] 65 | A_idx = [] 66 | B_idx = [] 67 | AB_block_offs = [0] 68 | #global off 69 | off = 0 70 | 71 | if 
X86: 72 | LOAD_CACHE_ASM = """ 73 | vmovdqu8 IDX1(%r8,%r11,1), %xmmNUM; 74 | vmovdqu8 IDX3(%r8,%r11,1), %xmmDA; 75 | vbroadcasti32x4 IDX2(%r8,%r11,1),%ymmNUM {%k1}; 76 | vbroadcasti32x4 IDX4(%r8,%r11,1), %ymmDA {%k1}; 77 | vpermt2d %zmmDA, %zmm25, %zmmNUM; 78 | vpshufb %zmm26, %zmmNUM, %zmmNUM; 79 | """ 80 | elif ARM: 81 | LOAD_CACHE = """ 82 | RC = vld1q_f32(&BC[IDX + C_offset + lane]); 83 | """ 84 | 85 | 86 | if X86: 87 | 88 | LOAD_WEIGHT_ASM = """vpbroadcastd OFF(%rcx), %zmmIDX; 89 | """ 90 | MAIN_PROGRAM_ASM="""vpdpbusd %zmmNUM,%zmmIDX, %zmmTAR; 91 | """ 92 | 93 | elif ARM: 94 | MAIN_PROGRAM =""" 95 | val = vdupq_n_f32(VAL); 96 | ACC[IDX_1] = vmlaq_f32(ACC[IDX_1], RC, val); 97 | """ 98 | 99 | def emit_load_block(index, currloadreg): 100 | 101 | return LOAD_CACHE_ASM.replace("IDX1",str(index[0])).replace("IDX2",str(index[1])). \ 102 | replace("IDX3",str(index[2])).replace("IDX4",str(index[3])).replace("NUM",str(currloadreg)).replace("DA",str(17)) 103 | 104 | def emit_compute_block(Ny_idx,vals,currloadreg): 105 | global off 106 | 107 | new_block_asm = "" 108 | for i in range(BLOCK): 109 | new_block_asm += LOAD_WEIGHT_ASM.replace("OFF",str(off * 4 + i * 4)).replace("IDX",str(31-i)) 110 | 111 | for i in range(CT): 112 | for j in range(BLOCK): 113 | new_block_asm += MAIN_PROGRAM_ASM.replace("NUM",str(currloadreg - i)).replace("IDX",str(31-j)).replace("TAR",str(i * BLOCK + j)) 114 | global AB_vals 115 | AB_vals.extend(vals) 116 | global A_idx 117 | A_idx.extend([Ny_idx] * 4) 118 | off += BLOCK 119 | return new_block_asm 120 | 121 | 122 | def ny_to_a(ny_idx,groupId,blockId, A_dim = None, A_offset = None): 123 | if A_offset is None: 124 | A_offset = blockId * (AT) 125 | return A_offset + ny_idx 126 | 127 | 128 | def generate_from_B(Ny_indices, B_indices,BA,block,NY,BB_offset,A_offset=None): 129 | 130 | asm = """ 131 | ..B1.NUM1: 132 | xorl %r10d, %r10d; 133 | ..B1.NUM2: 134 | imul $16, %r10d, %r11d; 135 | add %r9d, %r11d; 136 | movslq %r11d, %r11; 137 | add $CT, %r10d; 138 | 139 | """.replace("NUM1",str(BB_offset + block*2+2)).replace("NUM2",str(BB_offset + block * 2 + 3)).replace("STRIDE",str(8)).replace("CT",str(CT)) 140 | 141 | 142 | #print(A_offset) 143 | if input_file_bias is not None: 144 | for i in range(NY): 145 | for j in range(CT): 146 | if BB_offset > 0: 147 | asm += "\t\tvmovdqu32 " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4) ,%zmm" + str(i + j * AT) + ";\n" 148 | else: 149 | asm += "\tvpbroadcastd " + str(mapping[A_offset+i] * 4) + "(%rsi), %zmm" + str(i + AT * j) + ";\n" 150 | 151 | else: 152 | for i in range(NY): 153 | for j in range(CT): 154 | if BB_offset > 0: 155 | asm += "\t\tvmovups " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4), %zmm" + str(i + j * AT) + ";\n" 156 | else: 157 | asm += "\tvxorps " + "%zmm" + str(i + AT * j) + ",%zmm" + str(i + AT * j) + ",%zmm" + str(i + AT * j) + ";\n" 158 | 159 | done = set() 160 | loads = "" 161 | computes = "" 162 | 163 | TOK = 24 164 | currloadreg = TOK 165 | # pad end of the zipped list 166 | 167 | padded_Ny_indices = [] 168 | padded_B_indices = [] 169 | 170 | # For some wierd legacy reasons, I call A_indices Ny_indices, which are A_indices relative to the block 171 | # But we are basically going to take all the Ny_indices and corresponding B_indices, and pad them with -1 so that length is multiple of 4 172 | # note we don't require 4 nonzero weights to be sequential, we can just pack them together and inline their corresponding read addresses in BC into the code 173 | 174 | 
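    # Example of the padding scheme (illustrative values): a row of the A tile with 6 nonzeros
    # is padded to two groups of four such as [b0, b1, b2, b3] and [b4, b5, -1, -1]. For each
    # group, 16 activation bytes from each of the four B rows are merged into one zmm register
    # and shuffled so that vpdpbusd can dot them against 4 packed int8 weights per A row; the
    # -1 slots fall through to the partial-load path below, which skips those loads and pads
    # the corresponding weights with zeros.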
counter = 0 175 | for ny_idx, b_idx in zip(Ny_indices,B_indices): 176 | #assert ny_idx == 0 # for now. We are going to handle AT for quantized at a later date if at all. 177 | if ny_idx != 0 : 178 | continue # we are going to just process the first element in each A tile. 179 | #print(ny_idx, b_idx) 180 | padded_Ny_indices.append(ny_idx) 181 | padded_B_indices.append(b_idx) 182 | counter += 1 183 | pad_len = ((counter - 1) // 4 + 1 ) * 4 - counter 184 | padded_Ny_indices.extend([-1] * pad_len) 185 | padded_B_indices.extend([-1] * pad_len) 186 | #print(padded_B_indices,len(padded_B_indices)) 187 | for pos in range(0,len(padded_Ny_indices),4): 188 | b_indices = padded_B_indices[pos:pos+4] 189 | currloadreg = TOK #(currloadreg - TOK + 1) % 6 + TOK 190 | asm += loads 191 | asm += computes 192 | loads = "" 193 | computes = "" 194 | ny_idx = 0 195 | a_idx = ny_to_a(ny_idx,0,block,A_dim = A_dim, A_offset=A_offset) 196 | global B_idx 197 | 198 | # this really complicated stuff is used to deal with cases when there are fewer than 4 nonzeros left in the B-row. 199 | # I suppose you could just pad the weights with zero and take away this custom logic -- you will generate marginally more instructions 200 | # but does it really matter 201 | 202 | if -1 in b_indices: 203 | assert(b_indices[0] != -1) 204 | for i in range(CT): 205 | load_block_asm = """ 206 | vxorps %zmm29, %zmm29, % zmm29; 207 | vxorps %zmmNUM, %zmmNUM, %zmmNUM; 208 | vmovdqu8 IDX1(%r8,%r11,1), %xmmNUM; 209 | """.replace("IDX1",str(b_indices[0] * C_dim + i * VEC)).replace("NUM",str(currloadreg-i)) 210 | if b_indices[2] != -1: 211 | load_block_asm += """ 212 | vmovdqu8 IDX3(%r8,%r11,1), %xmm29; 213 | """.replace("IDX3",str(b_indices[2] * C_dim + i * VEC)) 214 | if b_indices[1] != -1: 215 | load_block_asm += """ 216 | vbroadcasti32x4 IDX2(%r8,%r11,1),%ymmNUM {%k1}; 217 | """.replace("IDX2",str(b_indices[1] * C_dim + i * VEC)).replace("NUM",str(currloadreg-i)) 218 | if b_indices[3] != -1: 219 | load_block_asm += """ 220 | vbroadcasti32x4 IDX3(%r8,%r11,1),%ymm29 {%k1}; 221 | """.replace("IDX3",str(b_indices[3] * C_dim + i * VEC)) 222 | load_block_asm += """ 223 | vpermt2d %zmm29, %zmm25, %zmmNUM; 224 | vpshufb %zmm26, %zmmNUM, %zmmNUM; 225 | """.replace("NUM",str(currloadreg-i)) 226 | loads += load_block_asm 227 | #print(b_indices) 228 | num_vals = np.where(np.array(b_indices) == -1)[0][0] 229 | #print(num_vals) 230 | 231 | # we are going to keep track of the nonzero values in the program access order. This is the array we save to disk as the storage of our sparse matrix. 
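            # Layout of the packed values: for each of the BLOCK rows in this A tile we append
            # four int8 weights (one per B index in the group, zero-padded past num_vals), so
            # every step adds BLOCK * 4 bytes to AB_vals. This matches emit_compute_block, which
            # broadcasts one 4-byte dword per row (vpbroadcastd at offsets off*4, off*4+4, ...)
            # before issuing the vpdpbusd instructions.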
232 | values = [] 233 | for k in range(BLOCK): 234 | values.append( np.array([BA[b_indices[i],a_idx + k] for i in range(num_vals)] + [0 for j in range(4-num_vals)]).astype(np.int8)) 235 | values = np.hstack(values) 236 | B_idx.extend(b_indices) 237 | compute_block_asm = emit_compute_block(ny_idx , values, currloadreg ) 238 | computes += compute_block_asm 239 | 240 | 241 | else: 242 | for i in range(CT): 243 | load_block_asm = emit_load_block([k * C_dim + i * VEC for k in b_indices], currloadreg - i) 244 | loads += load_block_asm 245 | 246 | values = [] 247 | for i in range(BLOCK): 248 | values.append(BA[b_indices,a_idx + i]) 249 | values = np.hstack(values) 250 | 251 | B_idx.extend(b_indices) 252 | compute_block_asm = emit_compute_block(ny_idx , values, currloadreg ) 253 | computes += compute_block_asm 254 | 255 | done.add(ny_idx) 256 | 257 | asm += loads 258 | asm += computes 259 | 260 | #print(block,group) 261 | global AB_block_offs 262 | AB_block_offs.append(len(AB_vals)) 263 | 264 | return asm, done 265 | 266 | 267 | def get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [0,B_dim]): 268 | BA = BA[B_bounds[0]:B_bounds[1]] 269 | Ny_indices = [] 270 | B_indices = [] 271 | for B_idx in range(B_dim // B_blocks): 272 | for ny in range(block_NY): 273 | A_idx = ny_to_a(ny,0,block,A_dim = A_dim, A_offset=A_offset) 274 | if np.abs(BA[B_idx,A_idx]) > EPS: 275 | B_indices.append(B_idx + B_bounds[0]) 276 | Ny_indices.append(ny) 277 | 278 | return Ny_indices, B_indices 279 | 280 | def no_load_balance(BA): 281 | 282 | #assert A_dim % A_blocks == 0 283 | interval = AT 284 | 285 | bounds = [interval * i for i in range(A_blocks)] + [A_dim] 286 | return bounds , interval 287 | 288 | def load_balancer2(BA): 289 | 290 | total_nnz = (np.abs(BA) > EPS).sum() 291 | nnz_per_block = total_nnz / A_blocks 292 | sums = np.sum(np.abs(BA) > EPS, axis = 0) 293 | cs = np.cumsum(sums) 294 | bounds = [np.argmax(cs > nnz_per_block * i) for i in range(A_blocks)] 295 | bounds = bounds + [A_dim] 296 | nnzs = np.diff(bounds) 297 | NY = np.max(nnzs) 298 | return bounds, NY 299 | 300 | 301 | # name is the name of the numpy file 302 | def gencode(BA,C_dim,A_blocks,C_blocks,name=None): 303 | asm_program = """ 304 | # -- Begin _spmm 305 | .text 306 | # mark_begin; 307 | .align 16,0x90 308 | .globl _spmm 309 | # --- mm(void *) 310 | _spmm: 311 | # parameter 1: %rdi 312 | ..B1.1: # Preds ..B1.0 313 | # Execution count [9.00e-01] 314 | .cfi_startproc 315 | ..___tag_value__spmm.1: 316 | ..L2: 317 | #45.1 318 | pushq %rbp #45.1 319 | .cfi_def_cfa_offset 16 320 | movq %rsp, %rbp #45.1 321 | .cfi_def_cfa 6, 16 322 | .cfi_offset 6, -16 323 | andq $-32, %rsp #45.1 324 | subq $96, %rsp #45.1 325 | mov $0xf0 , %ebx; 326 | kmovb %ebx, %k1 327 | movq (%rdi), %rcx # the first argument which is packed nonzero values pointer 328 | movq 8(%rdi), %rsi # the second argument which is bias values pointer 329 | movq 16(%rdi), %r8 # the third argument which is input matrix pointer 330 | movq 24(%rdi), %rdx # the fourth argument which is output matrix pointer 331 | movq 32(%rdi), %rbx # the scale 332 | movl 44(%rdi), %eax # end iteration count in the C dimension, useful for multithreading 333 | movl 40(%rdi), %edi # start iteration count in the C dimension, useful for multithreading 334 | decl %eax 335 | decl %edi 336 | imul $TSZ, %eax, %r9d 337 | 338 | vpmovzxbd vpermt2d_control(%rip), % zmm25; # initialize the control avx vectors which we are going to use for permutes and shuffles 339 | vbroadcasti32x4 vpshufb_control(%rip), % zmm26; 340 | 341 | 
""".replace("TSZ",str(TSZ)) 342 | 343 | #assert A_dim % A_blocks == 0 344 | #assert C_dim % C_blocks == 0 345 | B_dim = BA.shape[0] 346 | 347 | # can try different load balancing schemes here. usually for random sparsity patterns no need to load balance at all. 348 | bounds, NY = no_load_balance(BA) 349 | #bounds, NY = load_balancer2(BA) 350 | 351 | assert B_dim % B_blocks == 0 352 | block_size = B_dim // B_blocks 353 | for b_block in range(B_blocks): 354 | 355 | # basic block offset. Also used to determine if B_blocks == 0, which means different basic block initialization/termination 356 | bb_offset = b_block * A_blocks * 2 357 | for block in range(A_blocks): 358 | A_offset = bounds[block] 359 | block_NY = bounds[block+1] - A_offset 360 | # if the bounds are fixed, i.e. no load balance, then block_NY should be the same every iteration. 361 | 362 | # this gets the indices of the nonzero values in this block 363 | Ny_indices, B_indices = get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [b_block * block_size, (b_block + 1) * block_size]) 364 | 365 | # generate the unrolled code for this basic block 366 | asm, done = generate_from_B(Ny_indices,B_indices,BA,block,block_NY,bb_offset, A_offset=A_offset) 367 | asm_program += textwrap.indent(asm,"\t") + "\n" 368 | 369 | # generate the epilogue logic. This is different depending on B_blocks value (should we cache intermediate results or write results with post-op to output) 370 | if b_block == B_blocks - 1: 371 | # any post op implementation goes here. The following sequence implements an optional value and then dequantization. 372 | for i in range(block_NY): 373 | asm_program += "\t\tvbroadcastss " + str(mapping[A_offset + i] * 4) + "(%rbx), %zmm20;\n" 374 | for j in range(CT): 375 | if RELU: 376 | asm_program += "\t\tvmaxsb %zmm" + str(i + j * AT) + ", %zmm27, %zmm" + str(i + j * AT) + ";\n" 377 | asm_program += "\t\tvcvtdq2ps {rn-sae}, %zmm" + str(i + j * AT) + ",%zmm" + str(i + j * AT) + ";\n" 378 | asm_program += "\t\tvmulps %zmm" + str(i + j * AT) + ",%zmm20, %zmm" + str(i + j * AT) + ";\n" 379 | asm_program += "\t\tvcvtps2dq {rn-sae}, %zmm" + str(i + j * AT) + ",%zmm" + str(i + j * AT) + ";\n" 380 | asm_program += "\t\tvpmovdb %zmm" + str(i + j * AT) + ",%xmm" + str(i + j * AT) + ";\n" 381 | 382 | asm_program += """ 383 | vinserti32x4 $1,%xmmONE,%zmmZERO,%zmmZERO; 384 | vinserti32x4 $2,%xmmTWO,%zmmZERO,%zmmZERO; 385 | vinserti32x4 $3,%xmmTHREE,%zmmZERO,%zmmZERO; 386 | """.replace("ZERO",str(i)).replace("ONE",str(i + AT)).replace("TWO",str(i + 2 * AT)).replace("THREE",str(i + 3 * AT)) 387 | asm_program += "vmovdqu32 %zmm" + str(i) + ", " + str(mapping[A_offset + i] * C_dim ) + "(%rdx,%r11,1);\n" 388 | else: 389 | for i in range(block_NY): 390 | for j in range(CT): 391 | asm_program += "\t\tvmovdqu32 %zmm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 392 | 393 | 394 | asm_program += """ 395 | cmpl $END, %r10d; 396 | jb ..B1.NUM; 397 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 398 | 399 | 400 | asm_program += """ 401 | ..B1.NUM1: # Preds ..B1.17 402 | # Execution count [2.80e+01] 403 | decl %eax #44.37 404 | subl $TSZ, %r9d #44.37 405 | cmpl %eax, %edi #44.33 406 | jl ..B1.2 # Prob 96% #44.33 407 | # LOE rcx rbx rbp rsi rdi r12 r13 r14 r15 eax dl ymm15 408 | ..B1.NUM2: # Preds ..B1.18 409 | # Execution count [1.00e+00] 410 | vzeroupper #2398.1 411 | movq %rbp, %rsp 412 | popq %rbp 413 | #call pthread_exit@PLT #2416.1 414 | ret 415 | 
..___tag_value__spmm.13: 416 | .align 16,0x90 417 | # LOE 418 | .cfi_endproc 419 | # mark_end; 420 | .type _spmm,@function 421 | #.size _spmm,-_spmm 422 | ..LN_spmm.0: 423 | .section .rodata 424 | .balign 32 425 | vpermt2d_control: .byte 0,4,16,20, 1,5,17,21, 2, 6, 18, 22,3,7,19,23 426 | vpshufb_control: .byte 0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15 427 | # -- End _spmm 428 | 429 | """.replace("TSZ",str(TSZ)).replace("CBLOCKS",str(C_blocks)).replace("NUM1",str(B_blocks * A_blocks *2 + 2)).replace("NUM2",str(B_blocks * A_blocks * 2 + 3)) 430 | 431 | open(outfile_asm,"w").write(asm_program) 432 | 433 | 434 | 435 | BA = np.load(input_file) 436 | print(BA.shape) 437 | BA = BA.squeeze() 438 | 439 | """ 440 | We are going to redo BA here to remove some empty rows 441 | """ 442 | if NRS: 443 | A_dim = BA.shape[1] 444 | mapping = {i:i for i in range(A_dim)} 445 | else: 446 | nnz_cols = np.unique(np.where(BA)[1]) 447 | mapping = {i : nnz_cols[i] for i in range(len(nnz_cols))} 448 | #print(mapping) 449 | BA = BA[:,nnz_cols] 450 | A_dim = len(nnz_cols) 451 | if A_dim % AT == 0: 452 | A_blocks = A_dim // AT 453 | else: 454 | A_blocks = A_dim // AT + 1 455 | 456 | 457 | print("Reduced A dimension " + str(A_dim)) 458 | gencode(BA,C_dim,A_blocks,C_blocks,name=input_file) 459 | np.save("AB_vals.npy",np.array(AB_vals)) 460 | np.save("AB_block_off.npy",np.array(AB_block_offs).astype(np.int32)) 461 | np.save("A_idx.npy",np.array(A_idx).astype(np.int32)) 462 | np.save("B_idx.npy",np.array(B_idx).astype(np.int32)) 463 | 464 | if input_file_bias: 465 | np.save("bias.npy",bias.squeeze()) 466 | -------------------------------------------------------------------------------- /code_gen_cpu_int8.py: -------------------------------------------------------------------------------- 1 | # this program basically does a constexpr and generates cuda code 2 | import textwrap 3 | import numpy as np 4 | from code_fragments import * 5 | from utils import * 6 | 7 | 8 | import argparse 9 | parser = argparse.ArgumentParser(description='CodeGen V1') 10 | 11 | parser.add_argument('--A_dim', type=int, default=12) 12 | parser.add_argument('--B_dim', type=int, default=12) 13 | parser.add_argument('--C_dim', type=int, default=12) 14 | parser.add_argument('--AT', type=int, default=12) 15 | parser.add_argument('--C_blocks', type=int, default=12) 16 | parser.add_argument('--CT',type=int, default=1) 17 | parser.add_argument('--B_blocks',type=int,default=1) 18 | parser.add_argument('--Gy', type=int, default=12) 19 | parser.add_argument('--infile', default=None, type=str) 20 | parser.add_argument('--infile_bias', default=None, type=str) 21 | parser.add_argument('--outfile', default=None, type=str) 22 | parser.add_argument('--outfile_asm', default= None, type = str) 23 | parser.add_argument('--in_format', default="NCHW",type=str) 24 | parser.add_argument('--out_format', default="NCHW",type=str) 25 | parser.add_argument('--Tsb',type=float,default=1) 26 | parser.add_argument('--fuse',default=False,action='store_true') 27 | parser.add_argument('--x86',default=False,action='store_true') 28 | parser.add_argument('--arm',default=False,action='store_true') 29 | parser.add_argument('--threads',type = int, default=4) 30 | parser.add_argument('--avx512',default=False,action='store_true') 31 | parser.add_argument('--no_relu',default=False,action='store_true') 32 | parser.add_argument('--no_row_skip',default=False,action='store_true') 33 | args = parser.parse_args() 34 | GY = args.Gy 35 | FUSE_END = args.fuse 36 | NO_RELU = args.no_relu 37 | 
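# This file is the unstructured (single-row) int8 generator: generate_from_B below asserts
# ny_idx == 0, so weights are packed one A row at a time, whereas code_gen_cpu_int8_block.py
# packs BLOCK = AT rows per nonzero position. CT must be a multiple of 4 (asserted below).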
print(FUSE_END) 38 | TSB_MULT = args.Tsb 39 | A_dim = args.A_dim 40 | B_dim = args.B_dim 41 | C_dim = args.C_dim 42 | THREADS = args.threads 43 | AT = args.AT 44 | C_blocks = args.C_blocks 45 | AVX512 = False 46 | if args.avx512: 47 | AVX512 = True 48 | input_file = args.infile 49 | outfile = args.outfile 50 | outfile_asm = args.outfile_asm 51 | #assert C_dim % C_blocks == 0 52 | GSY = C_dim // C_blocks 53 | TSB =int( GSY * TSB_MULT) 54 | TSZ = C_dim // C_blocks if C_dim % C_blocks == 0 else C_dim // C_blocks + 1 55 | 56 | X86 = args.x86 57 | ARM = args.arm 58 | NRS = args.no_row_skip 59 | CT = args.CT 60 | B_blocks = args.B_blocks 61 | 62 | assert CT %4 == 0 63 | 64 | assert not (X86 and ARM) 65 | if X86: 66 | print("Generating X86 vector intrinsics") 67 | elif ARM: 68 | print("Generating Arm intrinsics") 69 | else: 70 | assert False 71 | 72 | if AVX512: 73 | VEC = 16 74 | else: 75 | VEC = 8 76 | 77 | IN_FORMAT = args.in_format 78 | OUT_FORMAT = args.out_format 79 | 80 | if IN_FORMAT == "NHWC" or OUT_FORMAT == "NHWC": 81 | assert False 82 | 83 | input_file_bias = args.infile_bias 84 | if input_file_bias: 85 | bias = np.load(input_file_bias) 86 | 87 | #global AB_vals 88 | AB_vals = [] 89 | A_idx = [] 90 | B_idx = [] 91 | AB_block_offs = [0] 92 | #global off 93 | off = 0 94 | 95 | if X86: 96 | LOAD_CACHE_ASM = """ 97 | vmovdqu8 IDX1(%r8,%r11,1), %xmmNUM; 98 | vmovdqu8 IDX3(%r8,%r11,1), %xmm29; 99 | vbroadcasti32x4 IDX2(%r8,%r11,1),%ymmNUM {%k1}; 100 | vbroadcasti32x4 IDX4(%r8,%r11,1), %ymm29 {%k1}; 101 | vpermt2d %zmm29, %zmm27, %zmmNUM; 102 | vpshufb %zmm28, %zmmNUM, %zmmNUM; 103 | """ 104 | elif ARM: 105 | LOAD_CACHE = """ 106 | RC = vld1q_f32(&BC[IDX + C_offset + lane]); 107 | """ 108 | 109 | 110 | if X86: 111 | if AVX512: 112 | 113 | LOAD_WEIGHT_ASM = """vpbroadcastd OFF(%rcx), %zmm31; 114 | """ 115 | MAIN_PROGRAM_ASM="""vpdpbusd %zmmNUM,%zmm31,%zmmTAR; 116 | """ 117 | 118 | elif ARM: 119 | MAIN_PROGRAM =""" 120 | val = vdupq_n_f32(VAL); 121 | ACC[IDX_1] = vmlaq_f32(ACC[IDX_1], RC, val); 122 | """ 123 | 124 | def emit_load_block(index, currloadreg): 125 | 126 | return LOAD_CACHE_ASM.replace("IDX1",str(index[0])).replace("IDX2",str(index[1])). 
\ 127 | replace("IDX3",str(index[2])).replace("IDX4",str(index[3])).replace("NUM",str(currloadreg)) 128 | 129 | def emit_compute_block(Ny_idx,vals,currloadreg, virg=False): 130 | global off 131 | new_block_asm = LOAD_WEIGHT_ASM.replace("OFF",str(off * 4 )) 132 | for i in range(CT): 133 | new_block_asm += MAIN_PROGRAM_ASM.replace("TAR",str(Ny_idx +i * AT)).replace("NUM",str(currloadreg - i)) 134 | global AB_vals 135 | AB_vals.extend(vals) 136 | global A_idx 137 | A_idx.extend([Ny_idx] * 4) 138 | off += 1 139 | return new_block_asm 140 | 141 | 142 | def ny_to_a(ny_idx,groupId,blockId, A_dim = None, A_offset = None): 143 | if A_offset is None: 144 | A_offset = blockId * (AT) 145 | return A_offset + ny_idx 146 | 147 | 148 | def generate_from_B(Ny_indices, B_indices,BA,block,NY,BB_offset, GY = None,A_offset=None): 149 | 150 | program = "" 151 | 152 | asm = """ 153 | ..B1.NUM1: 154 | xorl %r10d, %r10d; 155 | ..B1.NUM2: 156 | imul $16, %r10d, %r11d; 157 | add %r9d, %r11d; 158 | movslq %r11d, %r11; 159 | add $CT, %r10d; 160 | 161 | """.replace("NUM1",str(BB_offset + block*2+2)).replace("NUM2",str(BB_offset + block * 2 + 3)).replace("STRIDE",str(8)).replace("CT",str(CT)) 162 | 163 | 164 | #print(A_offset) 165 | if input_file_bias is not None: 166 | for i in range(NY): 167 | for j in range(CT): 168 | if BB_offset > 0: 169 | asm += "\t\tvmovups " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4)" + (",%zmm" if AVX512 else ",%ymm") + str(i + j * AT) + ";\n" 170 | else: 171 | asm += "\tvpbroadcastd " + str(mapping[A_offset+i] * 4) + "(%rsi), %zmm" + str(i + AT * j) + ";\n" 172 | 173 | #asm += "\tvpbroadcastd " + "%xmm" + str(i + AT * j) + (", %zmm" if AVX512 else "(%rsi), %ymm") + str(i + AT * j) + ";\n" 174 | else: 175 | for i in range(NY): 176 | for j in range(CT): 177 | if BB_offset > 0: 178 | asm += "\t\tvmovups " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4)" + (",%zmm" if AVX512 else ",%ymm") + str(i + j * AT) + ";\n" 179 | else: 180 | asm += "\tvxorps " + ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + "," + \ 181 | ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + "," + ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + ";\n" 182 | 183 | done = set() 184 | loads = "" 185 | computes = "" 186 | 187 | if AVX512: 188 | TOK = 24 189 | else: 190 | TOK = 13 191 | currloadreg = TOK 192 | # pad end of the zipped list 193 | 194 | padded_Ny_indices = [] 195 | padded_B_indices = [] 196 | 197 | counter = 0 198 | old_B_idx = -1 199 | for ny_idx, b_idx in zip(Ny_indices[0],B_indices[0]): 200 | assert ny_idx == 0 # for now. We are going to handle AT for quantized at a later date if at all. 
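# --- Sketch of what one padded group of four rows computes (comments only, not
# generator logic). vpdpbusd accumulates a 4-element dot product into every
# int32 lane, so the nonzero B-row indices collected below are padded with -1
# up to a multiple of 4. For one group b0..b3 with the packed int8 weights
# w0..w3 that emit_compute_block broadcasts, each column j of a 16-wide tile
# receives
#     acc[j] += w0*BC[b0][j] + w1*BC[b1][j] + w2*BC[b2][j] + w3*BC[b3][j]
# The vmovdqu8/vbroadcasti32x4 loads plus the vpermt2d/vpshufb interleave in
# emit_load_block arrange the four rows so that this pairing holds; slots padded
# with -1 get a zero weight and zeroed registers, so they contribute nothing.
# --- end sketch ---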
201 | if(b_idx < ny_idx): 202 | pad_len = ((counter - 1) // 4 + 1 ) * 4 - counter 203 | padded_Ny_indices.extend([-1] * pad_len) 204 | padded_B_indices.extend([-1] * pad_len) 205 | padded_Ny_indices.append(ny_idx) 206 | padded_B_indices.append(b_idx) 207 | counter = 1 208 | else: 209 | padded_Ny_indices.append(ny_idx) 210 | padded_B_indices.append(b_idx) 211 | counter += 1 212 | pad_len = ((counter - 1) // 4 + 1 ) * 4 - counter 213 | padded_Ny_indices.extend([-1] * pad_len) 214 | padded_B_indices.extend([-1] * pad_len) 215 | #print(padded_B_indices,len(padded_B_indices)) 216 | #for ny_idx, b_idx in zip(padded_Ny_indices[0],padded_B_indices[0]): 217 | for pos in range(0,len(padded_Ny_indices),4): 218 | b_indices = padded_B_indices[pos:pos+4] 219 | currloadreg = TOK #(currloadreg - TOK + 1) % 6 + TOK 220 | asm += loads 221 | asm += computes 222 | loads = "" 223 | computes = "" 224 | ny_idx = 0 225 | a_idx = ny_to_a(ny_idx,0,block,A_dim = A_dim, A_offset=A_offset) 226 | global B_idx 227 | 228 | if -1 in b_indices: 229 | assert(b_indices[0] != -1) 230 | for i in range(CT): 231 | load_block_asm = """ 232 | vxorps %zmm29, %zmm29, % zmm29; 233 | vxorps %zmmNUM, %zmmNUM, %zmmNUM; 234 | vmovdqu8 IDX1(%r8,%r11,1), %xmmNUM; 235 | """.replace("IDX1",str(b_indices[0] * C_dim + i * VEC)).replace("NUM",str(currloadreg-i)) 236 | if b_indices[2] != -1: 237 | load_block_asm += """ 238 | vmovdqu8 IDX3(%r8,%r11,1), %xmm29; 239 | """.replace("IDX3",str(b_indices[2] * C_dim + i * VEC)) 240 | if b_indices[1] != -1: 241 | load_block_asm += """ 242 | vbroadcasti32x4 IDX2(%r8,%r11,1),%ymmNUM {%k1}; 243 | """.replace("IDX2",str(b_indices[1] * C_dim + i * VEC)).replace("NUM",str(currloadreg-i)) 244 | if b_indices[3] != -1: 245 | load_block_asm += """ 246 | vbroadcasti32x4 IDX3(%r8,%r11,1),%ymm29 {%k1}; 247 | """.replace("IDX3",str(b_indices[3] * C_dim + i * VEC)) 248 | load_block_asm += """ 249 | vpermt2d %zmm29, %zmm27, %zmmNUM; 250 | vpshufb %zmm28, %zmmNUM, %zmmNUM; 251 | """.replace("NUM",str(currloadreg-i)) 252 | loads += load_block_asm 253 | #print(b_indices) 254 | num_vals = np.where(np.array(b_indices) == -1)[0][0] 255 | #print(num_vals) 256 | values = np.array([BA[b_indices[i],a_idx] for i in range(num_vals)] + [0 for j in range(4-num_vals)]).astype(np.int8) 257 | #print(values) 258 | B_idx.extend(b_indices) 259 | compute_block_asm = emit_compute_block(ny_idx , values, currloadreg , virg = ny_idx not in done) 260 | computes += compute_block_asm 261 | 262 | 263 | else: 264 | for i in range(CT): 265 | load_block_asm = emit_load_block([k * C_dim + i * VEC for k in b_indices], currloadreg - i) 266 | loads += load_block_asm 267 | 268 | values = BA[b_indices,a_idx] 269 | B_idx.extend(b_indices) 270 | compute_block_asm = emit_compute_block(ny_idx , values, currloadreg , virg = ny_idx not in done) 271 | computes += compute_block_asm 272 | 273 | done.add(ny_idx) 274 | 275 | asm += loads 276 | asm += computes 277 | 278 | 279 | 280 | 281 | 282 | #print(block,group) 283 | #program += GROUP_CONTROL_END + "\n" 284 | global AB_block_offs 285 | AB_block_offs.append(len(AB_vals)) 286 | 287 | return program, asm, done 288 | 289 | 290 | def get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [0,B_dim], GY=None): 291 | 292 | BA = BA[B_bounds[0]:B_bounds[1]] 293 | Ny_indices = [[] for i in range(GY)] 294 | B_indices = [[] for i in range(GY)] 295 | nnz = np.sum(np.abs(BA[:,A_offset:A_offset + block_NY]) > EPS ) 296 | nnz_per_group = nnz // GY 297 | curr_group = 0 298 | curr_nnz = 0 299 | for B_idx in range(B_dim // B_blocks): 
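# Note on the split below (comments only): the nonzeros of this A block are
# scanned over the B rows of the current slice and dealt out so that each of
# the GY groups ends up with roughly nnz // GY entries. The int8 emitter above
# only consumes group 0, so GY is effectively expected to be 1 on this path.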
300 | for ny in range(block_NY): 301 | assert curr_group < GY 302 | A_idx = ny_to_a(ny,curr_group,block,A_dim = A_dim, A_offset=A_offset) 303 | if np.abs(BA[B_idx,A_idx]) > EPS: 304 | B_indices[curr_group].append(B_idx + B_bounds[0]) 305 | Ny_indices[curr_group].append(ny) 306 | curr_nnz += 1 307 | if curr_nnz > nnz_per_group: 308 | curr_group += 1 309 | curr_nnz = 0 310 | 311 | return Ny_indices, B_indices 312 | 313 | def no_load_balance(BA): 314 | 315 | #assert A_dim % A_blocks == 0 316 | interval = AT 317 | 318 | bounds = [interval * i for i in range(A_blocks)] + [A_dim] 319 | 320 | return bounds , interval 321 | 322 | def load_balancer2(BA): 323 | 324 | total_nnz = (np.abs(BA) > EPS).sum() 325 | nnz_per_block = total_nnz / A_blocks 326 | sums = np.sum(np.abs(BA) > EPS, axis = 0) 327 | cs = np.cumsum(sums) 328 | bounds = [np.argmax(cs > nnz_per_block * i) for i in range(A_blocks)] 329 | bounds = bounds + [A_dim] 330 | nnzs = np.diff(bounds) 331 | NY = np.max(nnzs) 332 | return bounds, NY 333 | 334 | 335 | # name is the name of the numpy file 336 | def gencode(BA,outfile,C_dim,A_blocks,C_blocks,GY,name=None): 337 | program = "" 338 | asm_program = """ 339 | # -- Begin _spmm 340 | .text 341 | # mark_begin; 342 | .align 16,0x90 343 | .globl _spmm 344 | # --- mm(void *) 345 | _spmm: 346 | # parameter 1: %rdi 347 | ..B1.1: # Preds ..B1.0 348 | # Execution count [9.00e-01] 349 | .cfi_startproc 350 | ..___tag_value__spmm.1: 351 | ..L2: 352 | #45.1 353 | pushq %rbp #45.1 354 | .cfi_def_cfa_offset 16 355 | movq %rsp, %rbp #45.1 356 | .cfi_def_cfa 6, 16 357 | .cfi_offset 6, -16 358 | andq $-32, %rsp #45.1 359 | subq $96, %rsp #45.1 360 | movq (%rdi), %rcx #47.38 361 | movq 8(%rdi), %rsi #48.46 362 | movq 16(%rdi), %r8 #49.41 363 | movq 24(%rdi), %rdx #50.22 364 | movl 36(%rdi), %eax 365 | movl 32(%rdi), %edi #51.21 366 | vxorps ZERO, ZERO, ZERO #59.19 367 | decl %eax 368 | decl %edi 369 | imul $TSZ, %eax, %r9d 370 | mov $0xf0 , %ebx; 371 | kmovb %ebx, %k1 372 | vpmovzxbd vpermt2d_control(%rip), % zmm27; 373 | vbroadcasti32x4 vpshufb_control(%rip), % zmm28; 374 | 375 | 376 | 377 | """.replace("BOUND",str(C_blocks//THREADS)).replace("TSZ",str(TSZ)).replace("ZERO","%zmm30" if AVX512 else "%ymm14") 378 | 379 | #assert A_dim % A_blocks == 0 380 | #assert C_dim % C_blocks == 0 381 | B_dim = BA.shape[0] 382 | 383 | # if IN_FORMAT == "NCHW" and OUT_FORMAT == "NCHW": 384 | # bounds, NY = load_balancer2(BA) 385 | # else: 386 | bounds, NY = no_load_balance(BA) 387 | 388 | program += START_NONFUSED.replace("OUTPUT_FORMAT",OUT_FORMAT).replace("INPUT_FORMAT",IN_FORMAT).replace("Ny",str(NY)).replace("GY",str(GY)).replace("A_dim",str(A_dim)).replace( 389 | "C_dim",str(C_dim)).replace("B_dim",str(B_dim)).replace("A_BLOCKS",str(A_blocks)).replace("C_BLOCKS",str(C_blocks)).replace("BOUND",str(C_blocks//4)).replace("X86_DEF",str(int(X86))).replace("ARM_DEF",str(int(ARM))) + "\n" 390 | 391 | assert B_dim % B_blocks == 0 392 | block_size = B_dim // B_blocks 393 | for b_block in range(B_blocks): 394 | bb_offset = b_block * A_blocks * 2 395 | for block in range(A_blocks): 396 | A_offset = bounds[block] 397 | block_NY = bounds[block+1] - A_offset 398 | program += BLOCK_CONTROL_START.replace("BLOCK", str(block)).replace("Ny",str(block_NY)) + "\n" 399 | 400 | 401 | Ny_indices, B_indices = get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [b_block * block_size, (b_block + 1) * block_size],GY=GY) 402 | #import pdb;pdb.set_trace() 403 | ccode, asm, done = generate_from_B(Ny_indices,B_indices,BA,block,block_NY,bb_offset, 
GY=GY,A_offset=A_offset) 404 | #ccode = generate_c_stem(block_NY) 405 | 406 | program += textwrap.indent(ccode,"\t") + "\n" 407 | asm_program += textwrap.indent(asm,"\t") + "\n" 408 | if OUT_FORMAT == "NCHW": 409 | if FUSE_END: 410 | if GY > 1: 411 | print("End fusion strategy not valid.") 412 | for i in range(block_NY): 413 | program += BLOCK_END_REDUCTION.replace("OFFSET",str(mapping[A_offset + i] * C_dim)).replace("IDX",str(i)).replace("BIAS",str(A_offset+i)) 414 | if not NO_RELU: 415 | for j in range(CT): 416 | asm_program += "\t\tvmaxps %ymm" + str(i + j * AT) + (", %zmm30," if AVX512 else ", %ymm14,") + "%ymm" + str(i + j * AT) + ";\n" 417 | for j in range(CT): 418 | asm_program += "\t\tvmovdqu32 %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 419 | asm_program += """ 420 | cmpl $END, %r10d; 421 | jb ..B1.NUM; 422 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 423 | program += "\t}" 424 | else: 425 | if b_block == B_blocks - 1: 426 | for i in range(block_NY): 427 | for j in range(CT): 428 | asm_program += "\t\tvpmovdb %zmm" + str(i + j * AT) + "," + "%xmm" + str(i + j * AT) + ";\n" 429 | 430 | asm_program += """ 431 | vinserti32x4 $1,%xmm1,%zmm0,%zmm0; 432 | vinserti32x4 $2,%xmm2,%zmm0,%zmm0; 433 | vinserti32x4 $3,%xmm3,%zmm0,%zmm0; 434 | 435 | """ 436 | asm_program += "vmovdqu32 %zmm" + str(i) + ", " + str(mapping[A_offset + i] * C_dim ) + "(%rdx,%r11,1);\n" 437 | else: 438 | for i in range(block_NY): 439 | for j in range(CT): 440 | asm_program += "\t\tvmovdqu32 %zmm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 441 | 442 | 443 | asm_program += """ 444 | cmpl $END, %r10d; 445 | jb ..B1.NUM; 446 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 447 | else: 448 | program += BLOCK_END_NHWC.replace("A_offset",str(A_offset)).replace("Ny",str(block_NY)).replace("A_BLOCKS",str(A_blocks)).replace( 449 | "C_BLOCKS", str(C_blocks)).replace("A_dim",str(A_dim)).replace("C_dim",str(C_dim)).replace("B_dim",str(B_dim)) + "\n" 450 | # program += BLOCK_CONTROL_END 451 | 452 | program += END_NONFUSED.replace("AB_sparse_tidy.npy",name) 453 | open(outfile,"w").write(program.replace("B_dim",str(B_dim))) 454 | asm_program += """ 455 | ..B1.NUM1: # Preds ..B1.17 456 | # Execution count [2.80e+01] 457 | decl %eax #44.37 458 | subl $TSZ, %r9d #44.37 459 | cmpl %eax, %edi #44.33 460 | jl ..B1.2 # Prob 96% #44.33 461 | # LOE rcx rbx rbp rsi rdi r12 r13 r14 r15 eax dl ymm15 462 | ..B1.NUM2: # Preds ..B1.18 463 | # Execution count [1.00e+00] 464 | vzeroupper #2398.1 465 | movq %rbp, %rsp 466 | popq %rbp 467 | #call pthread_exit@PLT #2416.1 468 | ret 469 | ..___tag_value__spmm.13: 470 | .align 16,0x90 471 | # LOE 472 | .cfi_endproc 473 | # mark_end; 474 | .type _spmm,@function 475 | #.size _spmm,-_spmm 476 | ..LN_spmm.0: 477 | .section .rodata 478 | .balign 32 479 | vpermt2d_control: .byte 0,4,16,20, 1,5,17,21, 2, 6, 18, 22,3,7,19,23 480 | vpshufb_control: .byte 0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15 481 | # -- End _spmm 482 | 483 | 484 | 485 | """.replace("TSZ",str(TSZ)).replace("CBLOCKS",str(C_blocks)).replace("NUM1",str(B_blocks * A_blocks *2 + 2)).replace("NUM2",str(B_blocks * A_blocks * 2 + 3)) 486 | 487 | if AVX512: 488 | asm_program = asm_program.replace("ymm","zmm") 489 | 490 | open(outfile_asm,"w").write(asm_program) 491 | 492 | 493 | 494 | BA = np.load(input_file) 495 | print(BA.shape) 496 | BA = BA.squeeze() 497 | 498 
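# --- Illustrative sketch, not generator logic: what the row-skipping pass below
# does. Output columns of BA that are entirely zero generate no code, so they are
# dropped and `mapping` keeps each surviving column's original index for store
# offsets and bias loads. The _demo names are invented for this sketch only.
_demo_BA = np.array([[1, 0, 0, 2],
                     [0, 0, 3, 0]])                  # B_dim=2, A_dim=4; column 1 is all zero
_demo_nnz_cols = np.unique(np.where(_demo_BA)[1])    # array([0, 2, 3])
_demo_mapping = {i: _demo_nnz_cols[i] for i in range(len(_demo_nnz_cols))}
_demo_A_blocks = -(-len(_demo_nnz_cols) // 2)        # ceil-divide surviving columns into AT-wide blocks (AT=2 here)
# --- end sketch ---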
| """ 499 | We are going to redo BA here to remove some empty rows 500 | """ 501 | if NRS: 502 | A_dim = BA.shape[1] 503 | mapping = {i:i for i in range(A_dim)} 504 | else: 505 | nnz_cols = np.unique(np.where(BA)[1]) 506 | mapping = {i : nnz_cols[i] for i in range(len(nnz_cols))} 507 | #print(mapping) 508 | BA = BA[:,nnz_cols] 509 | A_dim = len(nnz_cols) 510 | if A_dim % AT == 0: 511 | A_blocks = A_dim // AT 512 | else: 513 | A_blocks = A_dim // AT + 1 514 | 515 | 516 | print("Reduced A dimension " + str(A_dim)) 517 | gencode(BA,outfile,C_dim,A_blocks,C_blocks,GY,name=input_file) 518 | np.save("AB_vals.npy",np.array(AB_vals)) 519 | np.save("AB_block_off.npy",np.array(AB_block_offs).astype(np.int32)) 520 | np.save("A_idx.npy",np.array(A_idx).astype(np.int32)) 521 | np.save("B_idx.npy",np.array(B_idx).astype(np.int32)) 522 | 523 | if input_file_bias: 524 | np.save("bias.npy",bias.squeeze()) 525 | -------------------------------------------------------------------------------- /code_gen_cpu.py: -------------------------------------------------------------------------------- 1 | # this program basically does a constexpr and generates cuda code 2 | import textwrap 3 | import numpy as np 4 | from code_fragments import * 5 | from utils import * 6 | 7 | FUNC_NAME = "spmm" 8 | EPI = "NONE" 9 | 10 | def emit_load_block(index, currloadreg): 11 | if X86: 12 | if AVX512: 13 | LOAD_CACHE = """ 14 | RC = _mm512_load_ps(&BC[IDX + C_offset + lane]); 15 | 16 | //RC = _mm256_load_ps(&BC[(C_offset + lane) * B_dim + IDX]); 17 | """ 18 | LOAD_CACHE_ASM = """vmovups IDX(%r8,%r11,4), %zmmNUM; 19 | """ 20 | else: 21 | LOAD_CACHE = """ 22 | RC = _mm256_load_ps(&BC[IDX + C_offset + lane]); 23 | 24 | //RC = _mm256_load_ps(&BC[(C_offset + lane) * B_dim + IDX]); 25 | """ 26 | LOAD_CACHE_ASM = """vmovups IDX(%r8,%r11,4), %ymmNUM; 27 | """ 28 | elif ARM: 29 | LOAD_CACHE = """ 30 | RC = vld1q_f32(&BC[IDX + C_offset + lane]); 31 | """ 32 | new_block = LOAD_CACHE.replace("IDX",str(index)) 33 | new_block_asm = LOAD_CACHE_ASM.replace("IDX",str(index * 4)).replace("NUM",str(currloadreg)) 34 | #new_block = LOAD_CACHE.replace("IDX",str(B_idx * 8)) 35 | return new_block, new_block_asm 36 | 37 | def emit_compute_block(Ny_idx,val,currloadreg, virg=False): 38 | 39 | global off 40 | if X86: 41 | if AVX512: 42 | LOAD_WEIGHT=""" 43 | val = _mm512_broadcast_ss(AB_val + OFF); 44 | """ 45 | MAIN_PROGRAM =""" 46 | //val = _mm256_set1_ps(VAL); 47 | ACC[IDX_1] = _mm512_fmadd_ps(RC, val, ACC[IDX_1]); 48 | """ 49 | LOAD_WEIGHT_ASM = """vbroadcastss OFF(%rcx), %zmm31; 50 | """ 51 | MAIN_PROGRAM_ASM="""vfmadd231ps %zmmNUM, %zmm31, %zmmIDX_1; 52 | """ 53 | 54 | else: 55 | LOAD_WEIGHT = """ 56 | val = _mm256_broadcast_ss(AB_val + OFF); 57 | """ 58 | MAIN_PROGRAM =""" 59 | ACC[IDX_1] = _mm256_fmadd_ps(RC, val, ACC[IDX_1]); 60 | """ 61 | LOAD_WEIGHT_ASM = """vbroadcastss OFF(%rcx), %ymm15; 62 | """ 63 | # LOAD_WEIGHT_ASM = """mov $VAL, %r12d; 64 | # movd %r12d, %xmm0; 65 | # vbroadcastss %xmm0, %ymm15; 66 | # """ 67 | MAIN_PROGRAM_ASM="""vfmadd231ps %ymmNUM, %ymm15, %ymmIDX_1; 68 | """ 69 | MAIN_PROGRAM_ASM_SUB="""vsubps %ymmNUM, %ymmIDX_1, %ymmIDX_1; 70 | """ 71 | MAIN_PROGRAM_ASM_ADD="""vaddps %ymmNUM, %ymmIDX_1, %ymmIDX_1; 72 | """ 73 | 74 | # MAIN_PROGRAM_ASM_SUB="""vpsubp %ymmNUM, %ymmIDX_1, %ymmIDX_1; 75 | # """ 76 | # MAIN_PROGRAM_ASM_ADD="""vpaddb %ymmNUM, %ymmIDX_1, %ymmIDX_1; 77 | # """ 78 | 79 | MAIN_PROGRAM_ASM_VIRG="""vbroadcastss OFF(%rcx), %ymm15; 80 | vmul231ps %ymmNUM, %ymm15, %ymmIDX_1; 81 | """ 82 | 83 | elif ARM: 84 | 
MAIN_PROGRAM =""" 85 | val = vdupq_n_f32(VAL); 86 | ACC[IDX_1] = vmlaq_f32(ACC[IDX_1], RC, val); 87 | """ 88 | 89 | if val != 1 and val != -1: 90 | new_block = LOAD_WEIGHT.replace("OFF",str(off )) 91 | new_block_asm = LOAD_WEIGHT_ASM.replace("OFF",str(off * 4 )).replace("VAL","0x" + str(float_to_hex(val)[2:])) 92 | for i in range(CT): 93 | new_block += MAIN_PROGRAM.replace("IDX_1",str(Ny_idx+ i * AT)) 94 | new_block_asm += MAIN_PROGRAM_ASM.replace("IDX_1",str(Ny_idx +i * AT)).replace("NUM",str(currloadreg - i)) 95 | elif val == 1: 96 | new_block = LOAD_WEIGHT.replace("OFF",str(off )) 97 | new_block_asm = "" 98 | for i in range(CT): 99 | new_block += MAIN_PROGRAM.replace("IDX_1",str(Ny_idx+ i * AT)) 100 | new_block_asm += MAIN_PROGRAM_ASM_ADD.replace("IDX_1",str(Ny_idx +i * AT)).replace("NUM",str(currloadreg - i)) 101 | elif val == -1: 102 | new_block = LOAD_WEIGHT.replace("OFF",str(off )) 103 | new_block_asm = "" 104 | for i in range(CT): 105 | new_block += MAIN_PROGRAM.replace("IDX_1",str(Ny_idx+ i * AT)) 106 | new_block_asm += MAIN_PROGRAM_ASM_SUB.replace("IDX_1",str(Ny_idx +i * AT)).replace("NUM",str(currloadreg - i)) 107 | global AB_vals 108 | AB_vals.append(val) 109 | global A_idx 110 | A_idx.append(Ny_idx) 111 | off += 1 112 | return new_block, new_block_asm 113 | 114 | 115 | def ny_to_a(ny_idx,groupId,blockId, A_dim = None, A_offset = None): 116 | if A_offset is None: 117 | A_offset = blockId * (AT) 118 | return A_offset + ny_idx 119 | 120 | 121 | def generate_from_B(Ny_indices, B_indices,BA,block,NY,BB_offset,GY=None, A_offset=None): 122 | 123 | program = "" 124 | asm = "" 125 | 126 | assert GY == 1 127 | for group in range(GY): 128 | #program += GROUP_CONTROL_START.replace("GROUP",str(group)) + "\n" 129 | 130 | next_tile_start = 0 131 | old_b_idx = -1 132 | 133 | if AVX512: 134 | asm += """ 135 | ..B1.NUM1: 136 | xorl %r10d, %r10d; 137 | ..B1.NUM2: 138 | imul $16, %r10d, %r11d; 139 | add %r9d, %r11d; 140 | movslq %r11d, %r11; 141 | add $CT, %r10d; 142 | """.replace("NUM1",str(BB_offset + block*2+2)).replace("NUM2",str(BB_offset + block * 2 + 3)).replace("STRIDE",str(8)).replace("CT",str(CT)) 143 | else: 144 | asm += """ 145 | ..B1.NUM1: 146 | xorl %r10d, %r10d; 147 | ..B1.NUM2: 148 | lea (%r9,%r10,8), %r11d; 149 | movslq %r11d, %r11; 150 | add $CT, %r10d; 151 | """.replace("NUM1",str(BB_offset + block*2+2)).replace("NUM2",str(BB_offset + block * 2 + 3)).replace("CT",str(CT)) 152 | 153 | 154 | #print(A_offset) 155 | if bias is not None: 156 | for i in range(NY): 157 | for j in range(CT): 158 | if BB_offset > 0: 159 | asm += "\t\tvmovups " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4)" + (",%zmm" if AVX512 else ",%ymm") + str(i + j * AT) + ";\n" 160 | else: 161 | asm += "\tvbroadcastss " + str(mapping[A_offset+i] * 4) + ("(%rsi), %zmm" if AVX512 else "(%rsi), %ymm") + str(i + AT * j) + ";\n" 162 | else: 163 | for i in range(NY): 164 | for j in range(CT): 165 | if BB_offset > 0: 166 | asm += "\t\tvmovups " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4)" + (",%zmm" if AVX512 else ",%ymm") + str(i + j * AT) + ";\n" 167 | else: 168 | asm += "\tvxorps " + ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + "," + \ 169 | ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + "," + ("%zmm" if AVX512 else "%ymm") + str(i + AT * j) + ";\n" 170 | 171 | done = set() 172 | loads = "" 173 | computes = "" 174 | 175 | if AVX512: 176 | TOK = 24 177 | else: 178 | TOK = 13 179 | currloadreg = TOK 180 | for ny_idx, b_idx in 
zip(Ny_indices[group],B_indices[group]): 181 | 182 | if IN_FORMAT == "NHWC": 183 | if old_b_idx < next_tile_start and b_idx >= next_tile_start: 184 | smem_block = emit_load_smem_block(min(TSB,B_dim - next_tile_start),next_tile_start // TSB) 185 | program += textwrap.indent(smem_block,"\t") 186 | next_tile_start += TSB 187 | 188 | if b_idx != old_b_idx: 189 | if IN_FORMAT == "NCHW": 190 | currloadreg = TOK #(currloadreg - TOK + 1) % 6 + TOK 191 | if currloadreg == TOK: 192 | asm += loads 193 | asm += computes 194 | loads = "" 195 | computes = "" 196 | for i in range(CT): 197 | load_block_cuda, load_block_asm = emit_load_block(b_idx * C_dim + i * VEC, currloadreg - i) 198 | loads += textwrap.indent(load_block_asm,"\t") 199 | program += textwrap.indent(load_block_cuda,"\t") 200 | else: 201 | load_block_cuda, load_block_asm = emit_load_block(b_idx,next_tile_start - TSB) 202 | 203 | 204 | 205 | old_b_idx = b_idx 206 | 207 | a_idx = ny_to_a(ny_idx,group,block,A_dim = A_dim, A_offset=A_offset) 208 | value = BA[b_idx,a_idx] 209 | global B_idx 210 | B_idx.append(b_idx) 211 | compute_block_cuda, compute_block_asm = emit_compute_block(ny_idx , value, currloadreg , virg = ny_idx not in done) 212 | computes += textwrap.indent(compute_block_asm, "\t") 213 | program += textwrap.indent(compute_block_cuda, "\t") 214 | 215 | done.add(ny_idx) 216 | 217 | 218 | 219 | asm += loads 220 | asm += computes 221 | #print(block,group) 222 | #program += GROUP_CONTROL_END + "\n" 223 | global AB_block_offs 224 | AB_block_offs.append(len(AB_vals)) 225 | 226 | return program, asm, done 227 | 228 | 229 | def get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = None, GY=None): 230 | 231 | BA = BA[B_bounds[0]:B_bounds[1]] 232 | Ny_indices = [[] for i in range(GY)] 233 | B_indices = [[] for i in range(GY)] 234 | nnz = np.sum(np.abs(BA[:,A_offset:A_offset + block_NY]) > EPS ) 235 | nnz_per_group = nnz // GY 236 | curr_group = 0 237 | curr_nnz = 0 238 | for B_idx in range(B_dim // B_blocks): 239 | for ny in range(block_NY): 240 | assert curr_group < GY 241 | A_idx = ny_to_a(ny,curr_group,block,A_dim = A_dim, A_offset=A_offset) 242 | if np.abs(BA[B_idx,A_idx]) > EPS: 243 | B_indices[curr_group].append(B_idx + B_bounds[0]) 244 | Ny_indices[curr_group].append(ny) 245 | curr_nnz += 1 246 | if curr_nnz > nnz_per_group: 247 | curr_group += 1 248 | curr_nnz = 0 249 | 250 | return Ny_indices, B_indices 251 | 252 | def no_load_balance(BA,A_blocks): 253 | 254 | #assert A_dim % A_blocks == 0 255 | interval = AT 256 | 257 | bounds = [interval * i for i in range(A_blocks)] + [A_dim] 258 | 259 | return bounds , interval 260 | 261 | def load_balancer2(BA): 262 | 263 | total_nnz = (np.abs(BA) > EPS).sum() 264 | nnz_per_block = total_nnz / A_blocks 265 | sums = np.sum(np.abs(BA) > EPS, axis = 0) 266 | cs = np.cumsum(sums) 267 | bounds = [np.argmax(cs > nnz_per_block * i) for i in range(A_blocks)] 268 | bounds = bounds + [A_dim] 269 | nnzs = np.diff(bounds) 270 | NY = np.max(nnzs) 271 | return bounds, NY 272 | 273 | 274 | # name is the name of the numpy file 275 | def gencode(BA,C_dim,A_blocks,C_blocks,name=None): 276 | program = "" 277 | asm_program = """ 278 | # -- Begin _FUNCNAME 279 | .text 280 | # mark_begin; 281 | .align 16,0x90 282 | .globl _FUNCNAME 283 | # --- mm(void *) 284 | _FUNCNAME: 285 | # parameter 1: %rdi 286 | ..B1.1: # Preds ..B1.0 287 | # Execution count [9.00e-01] 288 | .cfi_startproc 289 | ..___tag_value__FUNCNAME.1: 290 | ..L2: 291 | #45.1 292 | pushq %rbp #45.1 293 | .cfi_def_cfa_offset 16 294 | movq %rsp, %rbp #45.1 
295 | .cfi_def_cfa 6, 16 296 | .cfi_offset 6, -16 297 | andq $-32, %rsp #45.1 298 | subq $96, %rsp #45.1 299 | movq (%rdi), %rcx #47.38 300 | movq 8(%rdi), %rsi #48.46 301 | movq 16(%rdi), %r8 #49.41 302 | movq 24(%rdi), %rdx #50.22 303 | movl 36(%rdi), %eax 304 | movl 32(%rdi), %edi #51.21 305 | vxorps ZERO, ZERO, ZERO #59.19 306 | decl %eax 307 | decl %edi 308 | imul $TSZ, %eax, %r9d 309 | 310 | 311 | 312 | """.replace("TSZ",str(TSZ)).replace("ZERO","%zmm30" if AVX512 else "%ymm14").replace("FUNCNAME",FUNC_NAME) 313 | 314 | #assert A_dim % A_blocks == 0 315 | #assert C_dim % C_blocks == 0 316 | B_dim = BA.shape[0] 317 | 318 | # if IN_FORMAT == "NCHW" and OUT_FORMAT == "NCHW": 319 | # bounds, NY = load_balancer2(BA) 320 | # else: 321 | bounds, NY = no_load_balance(BA,A_blocks) 322 | 323 | program += START_NONFUSED.replace("OUTPUT_FORMAT",OUT_FORMAT).replace("INPUT_FORMAT",IN_FORMAT).replace("Ny",str(NY)).replace("GY",str(GY)).replace("A_dim",str(A_dim)).replace( 324 | "C_dim",str(C_dim)).replace("B_dim",str(B_dim)).replace("A_BLOCKS",str(A_blocks)).replace("C_BLOCKS",str(C_blocks)).replace("X86_DEF",str(int(X86))).replace("ARM_DEF",str(int(ARM))) + "\n" 325 | 326 | assert B_dim % B_blocks == 0 327 | block_size = B_dim // B_blocks 328 | for b_block in range(B_blocks): 329 | bb_offset = b_block * A_blocks * 2 330 | for block in range(A_blocks): 331 | A_offset = bounds[block] 332 | block_NY = bounds[block+1] - A_offset 333 | program += BLOCK_CONTROL_START.replace("BLOCK", str(block)).replace("Ny",str(block_NY)) + "\n" 334 | 335 | 336 | Ny_indices, B_indices = get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [b_block * block_size, (b_block + 1) * block_size],GY=GY) 337 | #import pdb;pdb.set_trace() 338 | ccode, asm, done = generate_from_B(Ny_indices,B_indices,BA,block,block_NY,bb_offset, GY=GY,A_offset=A_offset) 339 | #ccode = generate_c_stem(block_NY) 340 | 341 | program += textwrap.indent(ccode,"\t") + "\n" 342 | asm_program += textwrap.indent(asm,"\t") + "\n" 343 | if OUT_FORMAT == "NCHW": 344 | if FUSE_END: 345 | if GY > 1: 346 | print("End fusion strategy not valid.") 347 | for i in range(block_NY): 348 | program += BLOCK_END_REDUCTION.replace("OFFSET",str(mapping[A_offset + i] * C_dim)).replace("IDX",str(i)).replace("BIAS",str(A_offset+i)) 349 | if not NO_RELU: 350 | for j in range(CT): 351 | asm_program += "\t\tvmaxps %ymm" + str(i + j * AT) + (", %zmm30," if AVX512 else ", %ymm14,") + "%ymm" + str(i + j * AT) + ";\n" 352 | for j in range(CT): 353 | # vcmpltps %ymm0, %ymm14, %ymm0 354 | # vmovups msg(%rip), %ymm14 355 | # vminps %ymm14, %ymm0, %ymm0 356 | 357 | asm_program += "\t\tvmovups %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 358 | 359 | asm_program += """ 360 | cmpl $END, %r10d; 361 | jb ..B1.NUM; 362 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 363 | program += "\t}" 364 | else: 365 | if EPI == "LT": 366 | for i in range(block_NY): 367 | for j in range(CT): 368 | asm_program += "\t\tvbroadcastss " + str(mapping[A_offset+i] * 4) + "(%rsi), " + ("%zmm30;" if AVX512 else " %ymm14;") + "\n" 369 | asm_program += "\t\tvcmpltps %ymm14, %ymm" + str(i + j * AT) + ",%ymm" + str(i + j * AT) + ";\n" 370 | asm_program += "\t\tvmovups msg(%rip), %ymm14;\n" 371 | asm_program += "\t\tvminps %ymm14, %ymm" +str(i + j * AT) + ", %ymm" + str(i + j * AT) + ";\n" 372 | asm_program += "\t\tvmovups %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + 
"(%rdx,%r11,4);\n" 373 | elif EPI == "EQ": 374 | for i in range(block_NY): 375 | for j in range(CT): 376 | asm_program += "\t\tvbroadcastss " + str(mapping[A_offset+i] * 4) + "(%rsi), " + ("%zmm30;" if AVX512 else " %ymm14;") + "\n" 377 | asm_program += "\t\tvcmpeqps %ymm14, %ymm" + str(i + j * AT) + ",%ymm" + str(i + j * AT) + ";\n" 378 | asm_program += "\t\tvmovups msg(%rip), %ymm14;\n" 379 | asm_program += "\t\tvminps %ymm14, %ymm" +str(i + j * AT) + ", %ymm" + str(i + j * AT) + ";\n" 380 | asm_program += "\t\tvmovups %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 381 | elif EPI == "NONE": 382 | for i in range(block_NY): 383 | for j in range(CT): 384 | asm_program += "\t\tvmovups %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * C_dim * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 385 | asm_program += """ 386 | cmpl $END, %r10d; 387 | jb ..B1.NUM; 388 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 389 | else: 390 | program += BLOCK_END_NHWC.replace("A_offset",str(A_offset)).replace("Ny",str(block_NY)).replace("A_BLOCKS",str(A_blocks)).replace( 391 | "C_BLOCKS", str(C_blocks)).replace("A_dim",str(A_dim)).replace("C_dim",str(C_dim)).replace("B_dim",str(B_dim)) + "\n" 392 | # program += BLOCK_CONTROL_END 393 | 394 | program += END_NONFUSED.replace("AB_sparse_tidy.npy",name) 395 | open(outfile,"w").write(program.replace("B_dim",str(B_dim))) 396 | asm_program += """ 397 | ..B1.NUM1: # Preds ..B1.17 398 | # Execution count [2.80e+01] 399 | decl %eax #44.37 400 | subl $TSZ, %r9d #44.37 401 | cmpl %eax, %edi #44.33 402 | jl ..B1.2 # Prob 96% #44.33 403 | # LOE rcx rbx rbp rsi rdi r12 r13 r14 r15 eax dl ymm15 404 | ..B1.NUM2: # Preds ..B1.18 405 | # Execution count [1.00e+00] 406 | vzeroupper #2398.1 407 | movq %rbp, %rsp 408 | popq %rbp 409 | #call pthread_exit@PLT #2416.1 410 | ret 411 | ..___tag_value__FUNCNAME.13: 412 | .align 16,0x90 413 | # LOE 414 | .cfi_endproc 415 | # mark_end; 416 | .type _FUNCNAME,@function 417 | .size _FUNCNAME,.-_FUNCNAME 418 | ..LN_FUNCNAME.0: 419 | .section .rodata 420 | .balign 32 421 | msg: 422 | .long 0x3f800000,0x3f800000,0x3f800000,0x3f800000,0x3f800000,0x3f800000,0x3f800000,0x3f800000 423 | # -- End _FUNCNAME 424 | 425 | 426 | 427 | """.replace("FUNCNAME",FUNC_NAME).replace("TSZ",str(TSZ)).replace("CBLOCKS",str(C_blocks)).replace("NUM1",str(B_blocks * A_blocks *2 + 2)).replace("NUM2",str(B_blocks * A_blocks * 2 + 3)) 428 | 429 | if AVX512: 430 | asm_program = asm_program.replace("ymm","zmm") 431 | 432 | open(outfile_asm,"w").write(asm_program) 433 | 434 | 435 | if __name__ == "__main__": 436 | 437 | import argparse 438 | parser = argparse.ArgumentParser(description='CodeGen V1') 439 | 440 | parser.add_argument('--A_dim', type=int, default=12) 441 | parser.add_argument('--B_dim', type=int, default=12) 442 | parser.add_argument('--C_dim', type=int, default=12) 443 | parser.add_argument('--AT', type=int, default=12) 444 | parser.add_argument('--C_blocks', type=int, default=12) 445 | parser.add_argument('--CT',type=int, default=1) 446 | parser.add_argument('--B_blocks',type=int,default=1) 447 | parser.add_argument('--Gy', type=int, default=12) 448 | parser.add_argument('--infile', default=None, type=str) 449 | parser.add_argument('--infile_bias', default=None, type=str) 450 | parser.add_argument('--outfile', default=None, type=str) 451 | parser.add_argument('--outfile_asm', default= None, type = str) 452 | parser.add_argument('--in_format', 
default="NCHW",type=str) 453 | parser.add_argument('--out_format', default="NCHW",type=str) 454 | parser.add_argument('--Tsb',type=float,default=1) 455 | parser.add_argument('--fuse',default=False,action='store_true') 456 | parser.add_argument('--x86',default=False,action='store_true') 457 | parser.add_argument('--arm',default=False,action='store_true') 458 | parser.add_argument('--avx512',default=False,action='store_true') 459 | parser.add_argument('--no_relu',default=False,action='store_true') 460 | parser.add_argument('--no_row_skip',default=False,action='store_true') 461 | args = parser.parse_args() 462 | GY = args.Gy 463 | FUSE_END = args.fuse 464 | NO_RELU = args.no_relu 465 | TSB_MULT = args.Tsb 466 | A_dim = args.A_dim 467 | B_dim = args.B_dim 468 | C_dim = args.C_dim 469 | AT = args.AT 470 | C_blocks = args.C_blocks 471 | AVX512 = False 472 | if args.avx512: 473 | AVX512 = True 474 | input_file = args.infile 475 | outfile = args.outfile 476 | outfile_asm = args.outfile_asm 477 | #assert C_dim % C_blocks == 0 478 | 479 | TSZ = C_dim // C_blocks if C_dim % C_blocks == 0 else C_dim // C_blocks + 1 480 | 481 | X86 = args.x86 482 | ARM = args.arm 483 | NRS = args.no_row_skip 484 | CT = args.CT 485 | B_blocks = args.B_blocks 486 | 487 | if AVX512: 488 | VEC = 16 489 | else: 490 | VEC = 8 491 | 492 | IN_FORMAT = args.in_format 493 | OUT_FORMAT = args.out_format 494 | 495 | input_file_bias = args.infile_bias 496 | if input_file_bias: 497 | bias = np.load(input_file_bias) 498 | else: 499 | bias = None 500 | 501 | #global AB_vals 502 | AB_vals = [] 503 | A_idx = [] 504 | B_idx = [] 505 | AB_block_offs = [0] 506 | #global off 507 | off = 0 508 | 509 | assert not (X86 and ARM) 510 | if X86: 511 | print("Generating X86 vector intrinsics") 512 | elif ARM: 513 | print("Generating Arm intrinsics") 514 | else: 515 | assert False 516 | 517 | if IN_FORMAT == "NHWC" or OUT_FORMAT == "NHWC": 518 | assert False 519 | 520 | BA = np.load(input_file) 521 | print(BA.shape) 522 | BA = BA.squeeze() 523 | 524 | """ 525 | We are going to redo BA here to remove some empty rows 526 | """ 527 | if NRS: 528 | A_dim = BA.shape[1] 529 | mapping = {i:i for i in range(A_dim)} 530 | else: 531 | nnz_cols = np.unique(np.where(BA)[1]) 532 | mapping = {i : nnz_cols[i] for i in range(len(nnz_cols))} 533 | #print(mapping) 534 | BA = BA[:,nnz_cols] 535 | A_dim = len(nnz_cols) 536 | 537 | if A_dim % AT == 0: 538 | A_blocks = A_dim // AT 539 | else: 540 | A_blocks = A_dim // AT + 1 541 | 542 | 543 | print("Reduced A dimension " + str(A_dim)) 544 | gencode(BA,C_dim,A_blocks,C_blocks,name=input_file) 545 | np.save("AB_vals.npy",np.array(AB_vals)) 546 | np.save("AB_block_off.npy",np.array(AB_block_offs).astype(np.int32)) 547 | np.save("A_idx.npy",np.array(A_idx).astype(np.int32)) 548 | np.save("B_idx.npy",np.array(B_idx).astype(np.int32)) 549 | 550 | if input_file_bias: 551 | np.save("AB_bias.npy",bias.squeeze()) 552 | -------------------------------------------------------------------------------- /code_gen_cpu_conv.py: -------------------------------------------------------------------------------- 1 | # this program basically does a constexpr and generates cuda code 2 | import textwrap 3 | import numpy as np 4 | from code_fragments import * 5 | from utils import * 6 | 7 | 8 | import argparse 9 | parser = argparse.ArgumentParser(description='CodeGen V1') 10 | 11 | parser.add_argument('--OC', type=int, default=12) 12 | parser.add_argument('--IC', type=int, default=12) 13 | parser.add_argument('--image_x', type=int, default=12) 14 
| parser.add_argument('--image_y', type=int, default=12) 15 | parser.add_argument('--B_blocks', type=int, default=1) 16 | parser.add_argument('--C_blocks', type=int, default=1) 17 | parser.add_argument('--Gy', type=int, default=12) 18 | parser.add_argument('--infile', default=None, type=str) 19 | parser.add_argument('--infile_bias', default=None, type=str) 20 | parser.add_argument('--outfile', default=None, type=str) 21 | parser.add_argument('--outfile_asm', default= None, type = str) 22 | parser.add_argument('--in_format', default="NCHW",type=str) 23 | parser.add_argument('--out_format', default="NCHW",type=str) 24 | parser.add_argument('--Tsb',type=float,default=1) 25 | parser.add_argument('--fuse',default=False,action='store_true') 26 | parser.add_argument('--x86',default=False,action='store_true') 27 | parser.add_argument('--arm',default=False,action='store_true') 28 | parser.add_argument('--threads',type = int, default=4) 29 | parser.add_argument('--avx512',default=False,action='store_true') 30 | parser.add_argument('--no_relu',default=False,action='store_true') 31 | parser.add_argument('--no_row_skip',default=False,action='store_true') 32 | args = parser.parse_args() 33 | FILTER_X = 3 34 | FILTER_Y = 3 35 | GY = args.Gy 36 | FUSE_END = args.fuse 37 | NO_RELU = args.no_relu 38 | print(FUSE_END) 39 | TSB_MULT = args.Tsb 40 | A_dim = args.OC 41 | OC = args.OC 42 | B_dim = args.IC * FILTER_X * FILTER_Y 43 | IC = args.IC 44 | IMAGE_X = args.image_x 45 | IMAGE_Y = args.image_y 46 | AVX512 = False 47 | if args.avx512: 48 | AVX512 = True 49 | 50 | if AVX512: 51 | VEC = 16 52 | else: 53 | VEC = 8 54 | 55 | if IMAGE_Y == 56: 56 | if AVX512: 57 | Y_PAD = 8 58 | else: 59 | Y_PAD = 2 60 | elif IMAGE_Y == 28: 61 | Y_PAD = 4 62 | elif IMAGE_Y == 14: 63 | Y_PAD = 2 64 | elif IMAGE_Y == 7: 65 | Y_PAD = 1 66 | 67 | if IMAGE_Y % VEC != 0: 68 | C_dim = (IMAGE_X ) * (IMAGE_Y + Y_PAD) 69 | else: 70 | C_dim = (IMAGE_X) * IMAGE_Y 71 | 72 | THREADS = args.threads 73 | C_blocks = args.C_blocks 74 | B_blocks = args.B_blocks 75 | 76 | input_file = args.infile 77 | outfile = args.outfile 78 | outfile_asm = args.outfile_asm 79 | #assert C_dim % C_blocks == 0 80 | GSY = C_dim // C_blocks 81 | TSB =int( GSY * TSB_MULT) 82 | TSZ = C_dim // C_blocks if C_dim % C_blocks == 0 else C_dim // C_blocks + 1 83 | 84 | X86 = args.x86 85 | ARM = args.arm 86 | NRS = args.no_row_skip 87 | 88 | if AVX512: 89 | if IMAGE_Y == 56: 90 | AT = 6 91 | CT = 4 92 | elif IMAGE_Y == 28: 93 | AT = 12 94 | CT = 2 95 | elif IMAGE_Y == 14: 96 | AT = 24 97 | CT = 1 98 | elif IMAGE_Y == 7: 99 | print("Not supported") 100 | else: 101 | print("Not supported") 102 | exit() 103 | else: 104 | if IMAGE_Y == 56: 105 | AT = 1 106 | CT = 7 107 | elif IMAGE_Y == 28: 108 | AT = 2 109 | CT = 4 110 | elif IMAGE_Y == 14: 111 | AT = 6 112 | CT = 2 113 | elif IMAGE_Y == 7: 114 | AT = 12 115 | CT = 1 116 | else: 117 | print("Not supported") 118 | exit() 119 | 120 | 121 | assert not (X86 and ARM) 122 | if X86: 123 | print("Generating X86 vector intrinsics") 124 | elif ARM: 125 | print("Generating Arm intrinsics") 126 | else: 127 | assert False 128 | 129 | IN_FORMAT = args.in_format 130 | OUT_FORMAT = args.out_format 131 | 132 | if IN_FORMAT == "NHWC" or OUT_FORMAT == "NHWC": 133 | assert False 134 | 135 | input_file_bias = args.infile_bias 136 | if input_file_bias: 137 | bias = np.load(input_file_bias) 138 | 139 | #global AB_vals 140 | AB_vals = [] 141 | #global off 142 | off = 0 143 | 144 | if X86: 145 | if AVX512: 146 | LOAD_CACHE = """ 147 | RC = 
_mm512_load_ps(&BC[IDX + C_offset + lane ]); 148 | 149 | //RC = _mm256_load_ps(&BC[(C_offset + lane) * B_dim + IDX]); 150 | """ 151 | LOAD_CACHE_ASM = """vmovups IDX(%r8,%r11,4), %zmmNUM; 152 | """ 153 | else: 154 | LOAD_CACHE = """ 155 | RC = _mm256_load_ps(&BC[IDX + C_offset + lane + lane / 28 * 2]); 156 | 157 | //RC = _mm256_load_ps(&BC[(C_offset + lane) * B_dim + IDX]); 158 | """ 159 | LOAD_CACHE_ASM = """vmovups IDX(%r8,%r11,4), %ymmNUM; 160 | """ 161 | elif ARM: 162 | LOAD_CACHE = """ 163 | RC = vld1q_f32(&BC[IDX + C_offset + lane]); 164 | """ 165 | 166 | 167 | if X86: 168 | if AVX512: 169 | LOAD_WEIGHT=""" 170 | val = _mm512_broadcast_ss(AB_val + OFF); 171 | """ 172 | MAIN_PROGRAM =""" 173 | //val = _mm256_set1_ps(VAL); 174 | ACC[IDX_1] = _mm512_fmadd_ps(RC, val, ACC[IDX_1]); 175 | """ 176 | LOAD_WEIGHT_ASM = """vbroadcastss OFF(%rcx), %zmm31; 177 | """ 178 | MAIN_PROGRAM_ASM="""vfmadd231ps %zmmNUM, %zmm31, %zmmIDX_1; 179 | """ 180 | 181 | else: 182 | LOAD_WEIGHT = """ 183 | val = _mm256_broadcast_ss(AB_val + OFF); 184 | """ 185 | MAIN_PROGRAM =""" 186 | ACC[IDX_1] = _mm256_fmadd_ps(RC, val, ACC[IDX_1]); 187 | """ 188 | LOAD_WEIGHT_ASM = """vbroadcastss OFF(%rcx), %ymm15; 189 | """ 190 | MAIN_PROGRAM_ASM="""vfmadd231ps %ymmNUM, %ymm15, %ymmIDX_1; 191 | """ 192 | MAIN_PROGRAM_ASM_VIRG="""vbroadcastss OFF(%rcx), %ymm15; 193 | vmul231ps %ymmNUM, %ymm15, %ymmIDX_1; 194 | """ 195 | 196 | elif ARM: 197 | MAIN_PROGRAM =""" 198 | val = vdupq_n_f32(VAL); 199 | ACC[IDX_1] = vmlaq_f32(ACC[IDX_1], RC, val); 200 | """ 201 | 202 | 203 | if IN_FORMAT == "NCHW": 204 | def emit_load_block(index, currloadreg): 205 | new_block = LOAD_CACHE.replace("IDX",str(index)) 206 | new_block_asm = LOAD_CACHE_ASM.replace("IDX",str(index * 4)).replace("NUM",str(currloadreg)) 207 | #new_block = LOAD_CACHE.replace("IDX",str(B_idx * 8)) 208 | return new_block, new_block_asm 209 | else: 210 | def emit_load_block(B_idx,B_offset): 211 | new_block = LOAD_CACHE.replace("IDX",str(B_idx - B_offset)) 212 | return new_block 213 | 214 | def emit_load_smem_block(local_TSB, tile_id): 215 | return LOAD_SHARED.replace("TSB",str(local_TSB)).replace("TILE",str(tile_id * TSB)) 216 | 217 | def emit_compute_block(Ny_idx,val,currloadreg, virg=False): 218 | global off 219 | new_block = LOAD_WEIGHT.replace("OFF",str(off )) 220 | new_block_asm = LOAD_WEIGHT_ASM.replace("OFF",str(off * 4 )) 221 | for i in range(CT): 222 | new_block += MAIN_PROGRAM.replace("IDX_1",str(Ny_idx+ i * AT)) 223 | new_block_asm += MAIN_PROGRAM_ASM.replace("IDX_1",str(Ny_idx +i * AT)).replace("NUM",str(currloadreg - i)) 224 | global AB_vals 225 | AB_vals.append(val) 226 | off += 1 227 | return new_block, new_block_asm 228 | 229 | 230 | def ny_to_a(ny_idx,groupId,blockId, A_dim = None, A_offset = None): 231 | if A_offset is None: 232 | A_offset = blockId * (AT) 233 | return A_offset + ny_idx 234 | 235 | 236 | def generate_from_B(Ny_indices, B_indices,BA,block,NY,BB_offset, GY = None,A_offset=None): 237 | 238 | program = "" 239 | asm = "" 240 | 241 | assert GY == 1 242 | for group in range(GY): 243 | #program += GROUP_CONTROL_START.replace("GROUP",str(group)) + "\n" 244 | 245 | next_tile_start = 0 246 | old_b_idx = -1 247 | 248 | if AVX512: 249 | asm += """ 250 | ..B1.NUM1: 251 | xorl %r10d, %r10d; 252 | #xorl %r13d, %r13d; 253 | ..B1.NUM2: 254 | imul $16, %r10d, %r11d; 255 | add %r9d, %r11d; 256 | movslq %r11d, %r11; 257 | #add $CT, %r10d; 258 | """.replace("NUM1",str(BB_offset + block*2+2)).replace("NUM2",str(BB_offset + block * 2 + 3)) 259 | else: 260 | asm += """ 
261 | ..B1.NUM1: 262 | #xorl %r13d, %r13d; 263 | xorl %r10d, %r10d; 264 | ..B1.NUM2: 265 | lea (%r9,%r10,8), %r11d; 266 | movslq %r11d, %r11; 267 | """.replace("NUM1",str(BB_offset + block*2 + 2)).replace("NUM2",str(BB_offset + block * 2 + 3)).replace("CT",str(CT)) 268 | 269 | if not AVX512 and IMAGE_Y == 56: 270 | asm += """ 271 | mov %r10d, %r13d; 272 | mov %r10d, %r14d; 273 | imul $613566757, %r13, %r13; 274 | shr $32, %r13; 275 | sub %r13d, %r14d; 276 | shr %r14d; 277 | add %r14d, %r13d; 278 | shr $2, %r13d; 279 | shl $1, %r13d; 280 | add %r13d, %r11d; 281 | """ 282 | 283 | #print(A_offset) 284 | 285 | if AVX512: 286 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 56: 287 | asm += """sub %r10d, %r11d; 288 | sub %r10d, %r11d; 289 | """ 290 | else: 291 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 7: 292 | asm += """sub %r10d, %r11d; 293 | """ 294 | for i in range(NY): 295 | for j in range(CT): 296 | if BB_offset > 0: 297 | asm += "\tvmovups " + str(mapping[A_offset + i] * IMAGE_Y * IMAGE_Y * 4 + j * VEC * 4) + "(%rdx,%r11,4)" + "," + "%ymm" + str(i + AT * j) + ";\n" 298 | else: 299 | asm += "\tvbroadcastss " + str(mapping[A_offset+i] * 4) + ("(%rsi), %zmm" if AVX512 else "(%rsi), %ymm") + str(i + AT * j) + ";\n" 300 | if AVX512: 301 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 56: 302 | asm += """add %r10d, %r11d; 303 | add %r10d, %r11d; 304 | """ 305 | else: 306 | 307 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 7: 308 | asm += """add %r10d, %r11d; 309 | """ 310 | done = set() 311 | loads = "" 312 | computes = "" 313 | 314 | if AVX512: 315 | TOK = 29 316 | else: 317 | TOK = 13 318 | currloadreg = TOK 319 | for ny_idx, b_idx in zip(Ny_indices[group],B_indices[group]): 320 | 321 | if IN_FORMAT == "NHWC": 322 | if old_b_idx < next_tile_start and b_idx >= next_tile_start: 323 | smem_block = emit_load_smem_block(min(TSB,B_dim - next_tile_start),next_tile_start // TSB) 324 | program += textwrap.indent(smem_block,"\t") 325 | next_tile_start += TSB 326 | 327 | if b_idx != old_b_idx: 328 | if IN_FORMAT == "NCHW": 329 | currloadreg = TOK #(currloadreg - TOK + 1) % 6 + TOK 330 | if currloadreg == TOK: 331 | asm += loads 332 | asm += computes 333 | loads = "" 334 | computes = "" 335 | 336 | channel = b_idx // (FILTER_X * FILTER_Y) 337 | x = (b_idx // FILTER_Y) % FILTER_X 338 | y = b_idx % FILTER_Y 339 | channel_offset = channel * (IMAGE_X+1) * (IMAGE_Y+Y_PAD) - IMAGE_Y -Y_PAD - 1 340 | for i in range(CT): 341 | load_block_cuda, load_block_asm = emit_load_block(channel_offset + x * (IMAGE_Y+Y_PAD) + y + i * VEC, currloadreg - i) 342 | loads += textwrap.indent(load_block_asm,"\t") 343 | program += textwrap.indent(load_block_cuda,"\t") 344 | else: 345 | load_block_cuda, load_block_asm = emit_load_block(b_idx,next_tile_start - TSB) 346 | 347 | 348 | 349 | old_b_idx = b_idx 350 | 351 | a_idx = ny_to_a(ny_idx,group,block,A_dim = A_dim, A_offset=A_offset) 352 | value = BA[b_idx,a_idx] 353 | 354 | compute_block_cuda, compute_block_asm = emit_compute_block(ny_idx , value, currloadreg , virg = ny_idx not in done) 355 | computes += textwrap.indent(compute_block_asm, "\t") 356 | program += textwrap.indent(compute_block_cuda, "\t") 357 | 358 | done.add(ny_idx) 359 | 360 | 361 | 362 | asm += loads 363 | asm += computes 364 | #print(block,group) 365 | #program += GROUP_CONTROL_END + "\n" 366 | 367 | return program, asm, done 368 | 369 | 370 | def get_idx_balanced(block,BA,A_offset,block_NY,B_bounds = [0,B_dim], GY=None): 371 | 372 | BA = BA[B_bounds[0]:B_bounds[1]] 373 | Ny_indices = [[] for i in 
range(GY)] 374 | B_indices = [[] for i in range(GY)] 375 | nnz = np.sum(np.abs(BA[:,A_offset:A_offset + block_NY]) > EPS ) 376 | nnz_per_group = nnz // GY 377 | curr_group = 0 378 | curr_nnz = 0 379 | for B_idx in range(B_dim // B_blocks): 380 | for ny in range(block_NY): 381 | assert curr_group < GY 382 | A_idx = ny_to_a(ny,curr_group,block,A_dim = A_dim, A_offset=A_offset) 383 | if np.abs(BA[B_idx,A_idx]) > EPS: 384 | B_indices[curr_group].append(B_idx + B_bounds[0]) 385 | Ny_indices[curr_group].append(ny) 386 | curr_nnz += 1 387 | if curr_nnz > nnz_per_group: 388 | curr_group += 1 389 | curr_nnz = 0 390 | 391 | return Ny_indices, B_indices 392 | 393 | def no_load_balance(BA): 394 | 395 | #assert A_dim % A_blocks == 0 396 | interval = AT 397 | 398 | bounds = [interval * i for i in range(A_blocks)] + [A_dim] 399 | 400 | return bounds , interval 401 | 402 | def load_balancer2(BA): 403 | 404 | total_nnz = (np.abs(BA) > EPS).sum() 405 | nnz_per_block = total_nnz / A_blocks 406 | sums = np.sum(np.abs(BA) > EPS, axis = 0) 407 | cs = np.cumsum(sums) 408 | bounds = [np.argmax(cs > nnz_per_block * i) for i in range(A_blocks)] 409 | bounds = bounds + [A_dim] 410 | nnzs = np.diff(bounds) 411 | NY = np.max(nnzs) 412 | return bounds, NY 413 | 414 | 415 | # name is the name of the numpy file 416 | def gencode(BA,outfile,C_dim,A_blocks,C_blocks,GY,name=None): 417 | program = "" 418 | asm_program = """ 419 | # -- Begin _Z2mmPv 420 | .text 421 | # mark_begin; 422 | .align 16,0x90 423 | .globl _Z2mmPv 424 | # --- mm(void *) 425 | _Z2mmPv: 426 | # parameter 1: %rdi 427 | ..B1.1: # Preds ..B1.0 428 | # Execution count [9.00e-01] 429 | .cfi_startproc 430 | ..___tag_value__Z2mmPv.1: 431 | ..L2: 432 | #45.1 433 | pushq %rbp #45.1 434 | .cfi_def_cfa_offset 16 435 | movq %rsp, %rbp #45.1 436 | .cfi_def_cfa 6, 16 437 | .cfi_offset 6, -16 438 | andq $-32, %rsp #45.1 439 | subq $96, %rsp #45.1 440 | movq (%rdi), %rcx #47.38 441 | movq 8(%rdi), %rsi #48.46 442 | movq 16(%rdi), %r8 #49.41 443 | movq 24(%rdi), %rdx #50.22 444 | movl 36(%rdi), %eax 445 | movl 32(%rdi), %edi #51.21 446 | vxorps ZERO, ZERO, ZERO #59.19 447 | decl %eax 448 | decl %edi 449 | imul $TSZ, %eax, %r9d 450 | 451 | 452 | 453 | """.replace("BOUND",str(C_blocks//THREADS)).replace("TSZ",str(TSZ)).replace("ZERO","%zmm30" if AVX512 else "%ymm14") 454 | 455 | #assert A_dim % A_blocks == 0 456 | #assert C_dim % C_blocks == 0 457 | B_dim = BA.shape[0] 458 | 459 | # if IN_FORMAT == "NCHW" and OUT_FORMAT == "NCHW": 460 | # bounds, NY = load_balancer2(BA) 461 | # else: 462 | bounds, NY = no_load_balance(BA) 463 | 464 | program += START_NONFUSED.replace("OUTPUT_FORMAT",OUT_FORMAT).replace("INPUT_FORMAT",IN_FORMAT).replace("Ny",str(NY)).replace("GY",str(GY)).replace("A_dim",str(A_dim)).replace( 465 | "C_dim",str(C_dim)).replace("B_dim",str(B_dim)).replace("A_BLOCKS",str(A_blocks)).replace("C_BLOCKS",str(C_blocks)).replace("BOUND",str(C_blocks//4)).replace("X86_DEF",str(int(X86))).replace("ARM_DEF",str(int(ARM))) + "\n" 466 | 467 | assert B_dim % B_blocks == 0 468 | block_size = B_dim // B_blocks 469 | for b_block in range(B_blocks): 470 | bb_offset = b_block * A_blocks * 2 471 | for block in range(A_blocks): 472 | A_offset = bounds[block] 473 | block_NY = bounds[block+1] - A_offset 474 | program += BLOCK_CONTROL_START.replace("BLOCK", str(block)).replace("Ny",str(block_NY)) + "\n" 475 | 476 | 477 | Ny_indices, B_indices = get_idx_balanced(block,BA,A_offset,block_NY,B_bounds=[b_block * block_size, (b_block+1) * block_size ],GY=GY) 478 | #import 
pdb;pdb.set_trace() 479 | ccode, asm, done = generate_from_B(Ny_indices,B_indices,BA,block,block_NY,bb_offset,GY=GY,A_offset=A_offset) 480 | #ccode = generate_c_stem(block_NY) 481 | 482 | program += textwrap.indent(ccode,"\t") + "\n" 483 | asm_program += textwrap.indent(asm,"\t") + "\n" 484 | if OUT_FORMAT == "NCHW": 485 | if FUSE_END: 486 | 487 | if AVX512: 488 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 56: 489 | asm_program += "sub %r10d, %r11d;\n\tsub %r10d, %r11d;" 490 | else: 491 | if IMAGE_Y == 28 or IMAGE_Y == 14 or IMAGE_Y == 7: 492 | asm_program += "sub %r10d, %r11d;" 493 | 494 | asm_program += """ 495 | add $CT, %r10d; 496 | 497 | """.replace("CT",str(CT)) 498 | if IMAGE_Y % VEC != 0: 499 | if AVX512: 500 | asm_program += " mov msg(%rip), %ebx;\n kmovw %ebx, %k1;" 501 | else: 502 | asm_program += " vmovdqu msg(%rip), %ymm13;\n" 503 | else: 504 | asm_program += " sub %r13d, %r11d;\n" 505 | 506 | if GY > 1: 507 | print("End fusion strategy not valid.") 508 | for i in range(block_NY): 509 | program += BLOCK_END_REDUCTION.replace("OFFSET",str(mapping[A_offset + i] * C_dim)).replace("IDX",str(i)).replace("BIAS",str(A_offset+i)) 510 | 511 | if not NO_RELU and b_block == B_blocks - 1: 512 | for j in range(CT): 513 | asm_program += "\t\tvmaxps %ymm" + str(i + j * AT) + (", %zmm30," if AVX512 else ", %ymm14,") + "%ymm" + str(i + j * AT) + ";\n" 514 | for j in range(CT): 515 | if j == CT - 1 and IMAGE_Y % VEC != 0: 516 | if AVX512: 517 | asm_program += "\t\tvmovups %zmm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * IMAGE_Y * IMAGE_Y * 4 + j * VEC * 4) + "(%rdx,%r11,4) {%k1} ; \n" 518 | else: 519 | asm_program += "\t\tvpmaskmovd %ymm" + str(i + j * AT) + ",%ymm13," + str(mapping[A_offset + i] * IMAGE_Y * IMAGE_Y * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 520 | else: 521 | asm_program += "\t\tvmovups %ymm" + str(i + j * AT) + ", " + str(mapping[A_offset + i] * IMAGE_Y * IMAGE_Y * 4 + j * VEC * 4) + "(%rdx,%r11,4);\n" 522 | asm_program += """ 523 | cmp $END, %r10d; 524 | jb ..B1.NUM; 525 | """.replace("NUM",str(bb_offset + block * 2 + 3)).replace("END",str(TSZ // VEC)) 526 | 527 | program += "\t}" 528 | else: 529 | program += BLOCK_END.replace("A_offset",str(A_offset)).replace("Ny",str(block_NY)).replace("A_BLOCKS",str(A_blocks)).replace( 530 | "C_BLOCKS", str(C_blocks)).replace("A_dim",str(A_dim)).replace("C_dim",str(C_dim)).replace("B_dim",str(B_dim)) + "\n" 531 | else: 532 | program += BLOCK_END_NHWC.replace("A_offset",str(A_offset)).replace("Ny",str(block_NY)).replace("A_BLOCKS",str(A_blocks)).replace( 533 | "C_BLOCKS", str(C_blocks)).replace("A_dim",str(A_dim)).replace("C_dim",str(C_dim)).replace("B_dim",str(B_dim)) + "\n" 534 | # program += BLOCK_CONTROL_END 535 | 536 | program += END_NONFUSED.replace("AB_sparse_tidy.npy",name) 537 | open(outfile,"w").write(program.replace("B_dim",str(B_dim))) 538 | if AVX512: 539 | if IMAGE_Y == 28: 540 | MSG = ".short 0x0fff" 541 | elif IMAGE_Y == 14: 542 | MSG = ".short 0x3fff" 543 | elif IMAGE_Y == 56: 544 | MSG = ".short 0x00ff" 545 | else: 546 | MSG = ".short 0xffff" 547 | 548 | else: 549 | if IMAGE_Y == 28: 550 | MSG = ".int 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000" 551 | elif IMAGE_Y == 14: 552 | MSG = ".int 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0x00000000, 0x00000000" 553 | elif IMAGE_Y == 7: 554 | MSG = ".int 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0x00000000" 555 | else: 556 | MSG = ".int 
0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000, 0xf0000000" 557 | 558 | 559 | asm_program += """ 560 | ..B1.NUM1: # Preds ..B1.17 561 | # Execution count [2.80e+01] 562 | decl %eax #44.37 563 | subl $TSZ, %r9d #44.37 564 | cmpl %eax, %edi #44.33 565 | jl ..B1.2 # Prob 96% #44.33 566 | # LOE rcx rbx rbp rsi rdi r12 r13 r14 r15 eax dl ymm15 567 | ..B1.NUM2: # Preds ..B1.18 568 | # Execution count [1.00e+00] 569 | vzeroupper #2398.1 570 | movq %rbp, %rsp 571 | popq %rbp 572 | #call pthread_exit@PLT #2416.1 573 | ret 574 | ..___tag_value__Z2mmPv.13: 575 | .align 16,0x90 576 | # LOE 577 | .cfi_endproc 578 | # mark_end; 579 | .type _Z2mmPv,@function 580 | .size _Z2mmPv,.-_Z2mmPv 581 | ..LN_Z2mmPv.0: 582 | .section .rodata 583 | .balign 32 584 | msg: 585 | MSG 586 | 587 | # -- End _Z2mmPv 588 | """.replace("MSG",MSG).replace("TSZ",str(TSZ)).replace("CBLOCKS",str(C_blocks)).replace("NUM1",str(A_blocks *2 * B_blocks + 2)).replace("NUM2",str(A_blocks * 2 * B_blocks + 3)) 589 | 590 | if AVX512: 591 | asm_program = asm_program.replace("ymm","zmm") 592 | 593 | open(outfile_asm,"w").write(asm_program) 594 | 595 | 596 | 597 | BA = np.load(input_file) 598 | BA =BA 599 | print(BA.shape) 600 | BA = BA.squeeze() 601 | 602 | """ 603 | We are going to redo BA here to remove some empty rows 604 | """ 605 | if NRS: 606 | A_dim = BA.shape[1] 607 | mapping = {i:i for i in range(A_dim)} 608 | else: 609 | nnz_cols = np.unique(np.where(BA)[1]) 610 | mapping = {i : nnz_cols[i] for i in range(len(nnz_cols))} 611 | #print(mapping) 612 | BA = BA[:,nnz_cols] 613 | A_dim = len(nnz_cols) 614 | if A_dim % AT == 0: 615 | A_blocks = A_dim // AT 616 | else: 617 | A_blocks = A_dim // AT + 1 618 | 619 | 620 | print("Reduced A dimension " + str(A_dim)) 621 | gencode(BA,outfile,C_dim,A_blocks,C_blocks,GY,name=input_file) 622 | np.save("AB_vals.npy",np.array(AB_vals)) 623 | if input_file_bias: 624 | np.save("AB_bias.npy",bias.squeeze()) 625 | --------------------------------------------------------------------------------
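A minimal NumPy model of the arithmetic the float32 generators above emit for the default NCHW path with a bias file and end fusion enabled (--fuse without --no_relu): every surviving output row accumulates its sparse weight column against the dense input, adds its bias, and is clamped at zero by the vmaxps epilogue. This is an orientation sketch rather than a drop-in reference; the function name and the row-major (B_dim x C_dim) layout of the dense operand are assumptions made for the sketch.

import numpy as np

def spmm_reference(BA, BC, bias, relu=True):
    # BA: (B_dim, A_dim) sparse weights, BC: (B_dim, C_dim) dense input,
    # bias: (A_dim,). All-zero columns of BA are skipped by the generated
    # code, so the corresponding output rows are simply left at zero here.
    out = np.zeros((BA.shape[1], BC.shape[1]), dtype=BC.dtype)
    nnz_cols = np.unique(np.where(BA)[1])
    out[nnz_cols] = BA[:, nnz_cols].T @ BC + bias[nnz_cols][:, None]
    return np.maximum(out, 0.0) if relu else out

The int8 generators follow the same accumulate / bias / store structure, but with int32 accumulators fed by vpdpbusd instead of float fused multiply-adds.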