├── test ├── test.sh ├── testhelpers.hpp └── test.cpp ├── benchmark ├── benchmark.sh └── benchmark.cpp ├── jni ├── Application.mk └── Android.mk ├── .gitignore ├── LICENSE ├── README.md ├── arch-generic.hpp ├── arch-neon.hpp └── gemmbitserial.hpp /test/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -f test.out ]; then 4 | rm test.out 5 | fi 6 | g++ -march=native -g -std=c++11 -I.. test.cpp -o test.out 7 | ./test.out 8 | -------------------------------------------------------------------------------- /benchmark/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -f benchmark.out ]; then 4 | rm benchmark.out 5 | fi 6 | g++ -march=native -O3 -std=c++11 -I.. benchmark.cpp -o benchmark.out 7 | ./benchmark.out 8 | -------------------------------------------------------------------------------- /jni/Application.mk: -------------------------------------------------------------------------------- 1 | APP_STL := c++_static 2 | APP_ABI := arm64-v8a 3 | APP_CPPFLAGS := -O3 -funroll-loops -std=c++11 -Wall -Wextra -pedantic -Wno-unused-variable -Wno-unused-parameter -I. 4 | APP_LDFLAGS := -L$(SYSROOT)/usr/lib -lstdc++ -latomic 5 | APP_PIE := true 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | 31 | libs/ 32 | obj/ 33 | -------------------------------------------------------------------------------- /jni/Android.mk: -------------------------------------------------------------------------------- 1 | LOCAL_PATH := $(call my-dir) 2 | 3 | include $(CLEAR_VARS) 4 | 5 | LOCAL_ARM_NEON := true 6 | LOCAL_MODULE := benchmark 7 | LOCAL_SRC_FILES := ../benchmark/benchmark.cpp 8 | #LOCAL_CFLAGS += -save-temps 9 | include $(BUILD_EXECUTABLE) 10 | 11 | include $(CLEAR_VARS) 12 | 13 | LOCAL_ARM_NEON := true 14 | LOCAL_MODULE := test 15 | LOCAL_SRC_FILES := ../test/test.cpp 16 | LOCAL_CFLAGS += -UNDEBUG 17 | include $(BUILD_EXECUTABLE) 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Norwegian University of Science and Technology (NTNU) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gemmbitserial 2 | 3 | gemmbitserial is a simple, header-only C++ library for fast multiplication of few-bit integer matrices. It is primarily intended for running quantized neural network inference on CPUs, which requires fast few-bit integer matrix multiplication. 4 | 5 | Documentation is currently underway, and all contributions/suggestions are welcome. 6 | 7 | ## Preliminaries 8 | 9 | The library computes the product of two integer matrices A and B, either signed or unsigned. For 1-bit matrices, it supports both bipolar {-1, +1} and regular unsigned {0, 1} encodings. The input matrices must first be converted into bit-serial form via the importRegular function call. 10 | Note that the right-hand-side matrix must be provided in transposed (column-major) form, and the result is also produced in transposed form. 11 | A short paper describing the underlying operation principle can be found [here](http://www.idi.ntnu.no/~yamanu/2017-cases-wip-quantizedmm-preprint.pdf). 12 | 13 | ## Quickstart 14 | 1) Include "gemmbitserial.hpp" 15 | 2) Instantiate a GEMMContext by calling the allocGEMMContext function. 16 | 3) Import the left-hand-side and right-hand-side matrices by calling gemmcontext.lhs.importRegular and gemmcontext.rhs.importRegular. 17 | 4) Call the gemmBitSerial function with the gemmcontext as the argument. 18 | 5) Done! You can now read out the result from gemmcontext.res. 19 | 6) Release the context by calling deallocGEMMContext. 20 |
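A minimal end-to-end sketch of these steps (the matrix shape, bitwidths and all-ones input values below are purely illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>
#include "gemmbitserial.hpp"

using namespace gemmbitserial;

int main() {
  const int rows = 4, depth = 64, cols = 2;   // result will be rows x cols
  std::vector<uint8_t> lhs(rows * depth, 1);  // row-major left-hand side, 2-bit values
  std::vector<uint8_t> rhs(cols * depth, 1);  // right-hand side in transposed (column-major) form
  // 2) allocate a context for this shape, bitwidth and signedness
  GEMMContext ctx = allocGEMMContext(rows, depth, cols, 2, 2, false, false);
  // 3) convert both operands into bit-serial form
  ctx.lhs.importRegular(lhs.data());
  ctx.rhs.importRegular(rhs.data());
  // 4) multiply
  gemmBitSerial(ctx);
  // 5) the result is transposed: element (r, c) is at ctx.res[c * rows + r]
  std::cout << ctx.res[0] << std::endl;       // prints 64 (= depth) for the all-ones inputs
  // 6) release the buffers
  deallocGEMMContext(ctx);
  return 0;
}
```

Compile with C++11 and native architecture flags, e.g. g++ -O3 -march=native -std=c++11, as in benchmark/benchmark.sh.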
21 | ## Running benchmarks 22 | There is a small benchmarking tool to quickly evaluate performance for different matrix sizes and bitwidths. So far, only the ARM (NEON) and x86 code paths are optimized, and only to some extent. To build the benchmarking tool: 23 | 24 | For Android: run ndk-build in the root directory, then adb push the binary to e.g. /data/local/tmp and run it from there 25 | For x86: cd benchmark; ./benchmark.sh 26 | 27 | Once the interactive benchmark is running, it will ask for the following parameters on stdin: 28 | rows depth columns lhs_bitwidth rhs_bitwidth lhs_signed rhs_signed number_of_seconds_to_run 29 | 30 | For instance, entering the following on stdin will run an 8x8192x8 binary unsigned matrix multiply for 20 seconds and report the GOPS: 31 | 8 8192 8 1 1 0 0 20 32 | -------------------------------------------------------------------------------- /test/testhelpers.hpp: -------------------------------------------------------------------------------- 1 | #include <iostream> 2 | #include <cassert> 3 | #include <cstdlib> 4 | 5 | namespace gemmbitserial { 6 | 7 | // Generate a random vector of -1 and +1 values of given dimension 8 | template <typename T> 9 | void generateRandomVector_Bipolar(size_t dim, T * ret) { 10 | for(size_t i = 0; i < dim; i++) { 11 | ret[i] = (rand() % 2 == 0) ? 1 : -1; 12 | } 13 | } 14 | 15 | /** 16 | * Generate a random vector with given dimension and number of bits 17 | */ 18 | template <typename T> 19 | void generateRandomVector(size_t bits, size_t dim, T * ret, bool allowNeg = false) { 20 | assert(bits <= (sizeof(T) * 8)); 21 | if(bits == 1 && allowNeg) { 22 | // generate bipolar values 23 | generateRandomVector_Bipolar(dim, ret); 24 | return; 25 | } 26 | int32_t minVal = 0; 27 | int32_t maxVal = (1 << bits); 28 | for(size_t i = 0; i < dim; i++) { 29 | ret[i] = (rand() % maxVal) - (allowNeg ? maxVal/2 : 0); 30 | } 31 | } 32 | 33 | template <typename LHSType, typename RHSType> 34 | void naive_int_gemm(LHSType * lhs, RHSType * rhs, int32_t * res, int rows, int depth, int cols) { 35 | for(int k = 0; k < cols; k++) { 36 | for(int i = 0; i < rows; i++) { 37 | int32_t acc = 0; 38 | for(int j = 0; j < depth; j++) { 39 | acc += lhs[i * depth + j] * rhs[k * depth + j]; 40 | } 41 | res[k * rows + i] = acc; 42 | } 43 | } 44 | } 45 | 46 | template <typename T> 47 | void naive_sum_rows(T * m, int32_t * res, int rows, int cols) { 48 | for(int i = 0; i < rows; i++) { 49 | int32_t acc = 0; 50 | for(int k = 0; k < cols; k++) { 51 | acc += m[i * cols + k]; 52 | } 53 | res[i] = acc; 54 | } 55 | } 56 | 57 | template <typename T> 58 | void printmatrix(T * mat, int rows, int cols) { 59 | for(int i = 0; i < rows; i++) { 60 | for(int j = 0; j < cols; j++) { 61 | std::cout << (int) mat[i * cols + j] << " "; 62 | } 63 | std::cout << std::endl; 64 | } 65 | std::cout << std::endl; 66 | } 67 | 68 | template <typename T> 69 | void printmatrixdiff(const T * mat1, const T * mat2, int rows, int cols) { 70 | for(int i = 0; i < rows; i++) { 71 | for(int j = 0; j < cols; j++) { 72 | if(mat1[i * cols + j] != mat2[i * cols + j]) { 73 | std::cout << "Difference at (i,j) = " << i << " " << j << " Mat1: " << (int)mat1[i * cols + j] << " Mat2: " << (int)mat2[i * cols + j] << std::endl; 74 | } 75 | } 76 | } 77 | std::cout << std::endl; 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /benchmark/benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "gemmbitserial.hpp" 6 | 7 | using namespace std; 8 | using namespace gemmbitserial; 9 | 10 | /** 11 | * Generate a random vector with given dimension and number of bits <= 8 12 | */ 13 | template <typename T> 14 | void generateRandomVector(size_t bits, size_t dim, T* ret) { 15 | uint8_t minVal = 0; 16 | uint8_t maxVal = (1 << bits); 17 | for(size_t i = 0; i < dim; i++) { 18 | ret[i] = (T) (rand() % maxVal); 19 | } 20 | } 21 | 22 | void benchmark_unrolledpopcount(size_t
numBits, float secs) { 23 | string bench_name = "popcount"; 24 | double opcount = numBits; 25 | uint8_t * rnd_vec = new uint8_t[numBits/8]; 26 | uint64_t * ppcnt = (uint64_t *) rnd_vec; 27 | uint64_t res_pcnt = 0; 28 | cout << "======================================================================" << endl; 29 | cout << numBits << "-bit popcount using __builtin_popcntll, for " << secs << " seconds..." << endl; 30 | unsigned int reps = 0; 31 | auto start = chrono::high_resolution_clock::now(); 32 | auto end = chrono::high_resolution_clock::now(); 33 | while (chrono::duration_cast(end-start).count() < secs) { 34 | // =============== start of benchmark kernel ============= 35 | res_pcnt = 0; 36 | for(unsigned int r = 0; r < numBits/64; r+=4) { 37 | res_pcnt += __builtin_popcountll(ppcnt[r]); 38 | res_pcnt += __builtin_popcountll(ppcnt[r+1]); 39 | res_pcnt += __builtin_popcountll(ppcnt[r+2]); 40 | res_pcnt += __builtin_popcountll(ppcnt[r+3]); 41 | } 42 | // =============== end of benchmark kernel ================ 43 | reps += 1; 44 | end = chrono::high_resolution_clock::now(); 45 | } 46 | cout << "Returned result: " << res_pcnt << endl; 47 | double nscount = chrono::duration_cast(end-start).count() / (double)reps; 48 | double perf = opcount / (nscount); // billion bit operations per second 49 | cout << "Time for a single " << bench_name << ": " << nscount << " nanoseconds" << endl; 50 | cout << "Performance for " << bench_name << ": " << perf << " GOPS per second" << endl; 51 | cout << (numBits/8) / (nscount) << " GB/s" << endl; 52 | delete [] rnd_vec; 53 | } 54 | 55 | void benchmark_gemm_interactive() { 56 | while(1) { 57 | int rows, depth, cols, lhsbits, rhsbits, lhssigned, rhssigned; 58 | float secs; 59 | cout << "Enter rows depth cols, 0 for next benchmark, -1 to exit " << endl; 60 | cin >> rows; 61 | if(rows == 0) { 62 | break; 63 | } else if (rows < 0) { 64 | exit(0); 65 | } 66 | cin >> depth >> cols; 67 | cout << "Enter lhs and rhs bits: " << endl; 68 | cin >> lhsbits >> rhsbits; 69 | cout << "Enter signedness (1 or 0) for lhs and rhs: " << endl; 70 | cin >> lhssigned >> rhssigned; 71 | cout << "Enter number of seconds to benchmark: " << endl; 72 | cin >> secs; 73 | // prepare workload 74 | uint8_t * rnd_matA = new uint8_t[rows*depth]; 75 | uint8_t * rnd_matB = new uint8_t[depth*cols]; 76 | int32_t * res = new int32_t[rows*cols]; 77 | generateRandomVector(lhsbits, rows*depth, rnd_matA); 78 | generateRandomVector(rhsbits, depth*cols, rnd_matB); 79 | 80 | GEMMContext ctx = allocGEMMContext(rows, depth, cols, lhsbits, rhsbits, (bool) lhssigned, (bool) rhssigned); 81 | ctx.lhs.importRegular(rnd_matA); 82 | ctx.rhs.importRegular(rnd_matB); 83 | ctx.printSummary(); 84 | 85 | 86 | delete [] rnd_matA; 87 | delete [] rnd_matB; 88 | cout << "======================================================================" << endl; 89 | char bench_name[1024]; 90 | sprintf(bench_name, "gemm-%d x %d x %d (%d bit x %d bit)", rows, depth, cols, lhsbits, rhsbits); 91 | cout << "Running " << bench_name << " for " << secs << " seconds..." 
<< endl; 92 | unsigned int reps = 0; 93 | auto start = chrono::high_resolution_clock::now(); 94 | auto end = chrono::high_resolution_clock::now(); 95 | while (chrono::duration_cast(end-start).count() < secs) { 96 | // =============== start of benchmark kernel ============= 97 | gemmBitSerial(ctx); 98 | // =============== end of benchmark kernel ================ 99 | reps += 1; 100 | end = chrono::high_resolution_clock::now(); 101 | // ignore the first iteration, it's just for warmup 102 | if(reps == 1) { 103 | start = end; 104 | } 105 | } 106 | cout << "Completed " << reps << " iterations" << endl; 107 | float opcount = 2.0*(float)rows*(float)depth*(float)cols; 108 | float nscount = chrono::duration_cast(end-start).count() / (float)reps; 109 | float perf = opcount / nscount; // billion bit operations per second 110 | cout << "Time for a single " << bench_name << ": " << nscount << " nanoseconds" << endl; 111 | cout << "Performance for " << bench_name << ": " << perf << " GOPS per second" << endl; 112 | 113 | deallocGEMMContext(ctx); 114 | delete [] res; 115 | } 116 | } 117 | 118 | void benchmark_import_interactive() { 119 | string bench_name = "Regular-to-bitserial conversion"; 120 | while(1) { 121 | int rows, cols, nbits; 122 | float secs; 123 | cout << "Enter rows cols nbits, 0 for next benchmark, -1 to exit " << endl; 124 | cin >> rows; 125 | if(rows == 0) { 126 | break; 127 | } else if (rows < 0) { 128 | exit(0); 129 | } 130 | cin >> cols >> nbits; 131 | int nthres = (1 << nbits) - 1; 132 | cout << "Enter 0 for regular import, 1 for thresholding import: " << endl; 133 | int use_thres; 134 | cin >> use_thres; 135 | if(use_thres) { 136 | cout << "Benchmark will use " << nthres << " thresholds" << endl; 137 | } 138 | cout << "Enter 0 for no transpose, 1 for tranposed import: " << endl; 139 | int do_transpose; 140 | cin >> do_transpose; 141 | cout << "Enter number of seconds to benchmark: " << endl; 142 | cin >> secs; 143 | BitSerialMatrix bsm = BitSerialMatrix::alloc(nbits, rows, cols, false); 144 | uint8_t * rand_mat = new uint8_t[rows*cols]; 145 | uint8_t * rand_thres = new uint8_t[rows*nthres]; 146 | generateRandomVector(nbits, rows*cols, rand_mat); 147 | generateRandomVector(nbits, rows*nthres, rand_thres); 148 | unsigned int reps = 0; 149 | auto start = chrono::high_resolution_clock::now(); 150 | auto end = chrono::high_resolution_clock::now(); 151 | while (chrono::duration_cast(end-start).count() < secs) { 152 | // =============== start of benchmark kernel ============= 153 | if(use_thres) { 154 | bsm.importRegularAndQuantize(rand_mat, rand_thres, nthres, (bool) do_transpose); 155 | } else { 156 | bsm.importRegular(rand_mat, (bool) do_transpose); 157 | } 158 | // =============== end of benchmark kernel ================ 159 | reps += 1; 160 | end = chrono::high_resolution_clock::now(); 161 | // ignore the first iteration, it's just for warmup 162 | if(reps == 1) { 163 | start = end; 164 | } 165 | } 166 | float mscount = chrono::duration_cast(end-start).count() / (float)reps; 167 | cout << "Completed " << reps << " iterations, " << mscount << " ms per iteration" << endl; 168 | BitSerialMatrix::dealloc(bsm); 169 | delete [] rand_mat; 170 | delete [] rand_thres; 171 | } 172 | } 173 | 174 | void benchmark_caffenet(float secs) { 175 | string bench_name = "CaffeNet matrices"; 176 | const int caffenet_gemm_sizes[] = { 177 | 96, 363, 3025, 178 | 256, 2400, 729, 179 | 384, 2304, 169, 180 | 384, 3456, 169, 181 | 256, 3456, 169, 182 | 4096, 9216, 1, 183 | 4096, 4096, 1, 184 | 1000, 4096, 1 185 
| }; 186 | double opcount = 0; 187 | size_t wbits = 2; 188 | size_t abits = 2; 189 | const std::size_t num_caffenet_gemms = 190 | sizeof(caffenet_gemm_sizes) / (3 * sizeof(caffenet_gemm_sizes[0])); 191 | // prepare workload 192 | vector caffenet_gemms; 193 | for (std::size_t i = 0; i < num_caffenet_gemms; i++) { 194 | size_t rows = caffenet_gemm_sizes[3 * i + 0]; 195 | size_t depth = caffenet_gemm_sizes[3 * i + 1]; 196 | size_t cols = caffenet_gemm_sizes[3 * i + 2]; 197 | opcount += 2*rows*depth*cols; 198 | uint8_t * rnd_matA = new uint8_t[rows*depth]; 199 | uint8_t * rnd_matB = new uint8_t[depth*cols]; 200 | generateRandomVector(wbits, rows*depth, rnd_matA); 201 | generateRandomVector(abits, depth*cols, rnd_matB); 202 | GEMMContext g = allocGEMMContext(rows, depth, cols, wbits, abits, false, false); 203 | g.lhs.importRegular(rnd_matA); 204 | g.rhs.importRegular(rnd_matB); 205 | caffenet_gemms.push_back(g); 206 | delete [] rnd_matA; 207 | delete [] rnd_matB; 208 | } 209 | 210 | cout << "======================================================================" << endl; 211 | cout << bench_name << " for " << secs << " seconds..." << endl; 212 | unsigned int reps = 0; 213 | auto start = chrono::high_resolution_clock::now(); 214 | auto end = chrono::high_resolution_clock::now(); 215 | while (chrono::duration_cast(end-start).count() < secs) { 216 | // =============== start of benchmark kernel ============= 217 | for(size_t i = 0; i < num_caffenet_gemms; i++) { 218 | gemmBitSerial(caffenet_gemms[i]); 219 | } 220 | // =============== end of benchmark kernel ================ 221 | reps += 1; 222 | end = chrono::high_resolution_clock::now(); 223 | // ignore the first iteration, it's just for warmup 224 | if(reps == 1) { 225 | start = end; 226 | } 227 | } 228 | cout << "Completed " << reps << " iterations" << endl; 229 | double nscount = chrono::duration_cast(end-start).count() / (double)reps; 230 | double perf = opcount / (nscount); // billion bit operations per second 231 | cout << "Time for a single " << bench_name << ": " << nscount << " nanoseconds" << endl; 232 | cout << "Performance for " << bench_name << ": " << perf << " GOPS per second" << endl; 233 | 234 | for (std::size_t i = 0; i < num_caffenet_gemms; i++) { 235 | deallocGEMMContext(caffenet_gemms[i]); 236 | } 237 | } 238 | 239 | int main(int argc, char const *argv[]) { 240 | benchmark_gemm_interactive(); 241 | benchmark_import_interactive(); 242 | benchmark_caffenet(20); 243 | 244 | vector dims {256, 512, 1024, 2048, 4096, 8192, 16384}; 245 | for(auto &d: dims) { 246 | benchmark_unrolledpopcount(d, 5); 247 | } 248 | 249 | return 0; 250 | } 251 | -------------------------------------------------------------------------------- /test/test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "gemmbitserial.hpp" 8 | #include "mnistdata.h" 9 | #include "testhelpers.hpp" 10 | 11 | using namespace std; 12 | using namespace gemmbitserial; 13 | 14 | #define VERBOSE_TEST(x) ; 15 | //#define VERBOSE_TEST(x) x 16 | 17 | void printBitSerialMatrix(BitSerialMatrix * bsm) { 18 | std::cout << "BitSerialMatrix with bits " << bsm->nbits << " rows " << bsm->nrows << " cols " << bsm->ncols << std::endl; 19 | for(int b = 0; b < bsm->nbits; b++) { 20 | std::cout << "bit " << b << ":" << std::endl; 21 | for(int r = 0; r < bsm->nrows; r++) { 22 | for(int c = 0; c < bsm->ncols; c++) { 23 | std::cout << (bsm->get(b,r,c) ? 
1 : 0) << " "; 24 | } 25 | std::cout << std::endl << std::endl; 26 | } 27 | } 28 | } 29 | 30 | bool test_rowwise_sum() { 31 | vector param_bits {1, 2, 3, 4}; 32 | vector param_dims {4, 16, 17, 32, 77, 100, 1023}; 33 | vector param_signed {1, 0}; 34 | unsigned int numConfigs = 0, ok = 0, nok = 0; 35 | for(auto & b: param_bits) { 36 | for(auto & d: param_dims) { 37 | for(auto & sgnd: param_signed) { 38 | bool isSigned = (bool) sgnd; 39 | int8_t * rnd_mat = new int8_t[d*d]; 40 | int32_t * res_ret = new int32_t[d]; 41 | int32_t * res_golden = new int32_t[d]; 42 | generateRandomVector(b, d*d, rnd_mat, isSigned); 43 | // TODO add aligned version of BitSerialMatrix::alloc 44 | GEMMContext ctx = allocGEMMContext( 45 | d, d, 1, b, 1, isSigned, false 46 | ); 47 | BitSerialMatrix bsm = ctx.lhs; 48 | bsm.importRegular(rnd_mat); 49 | sumRows(bsm, res_ret); 50 | naive_sum_rows(rnd_mat, res_golden, d, d); 51 | int res = memcmp(res_ret, res_golden, d); 52 | if(res == 0) { 53 | ok++; 54 | } else { 55 | nok++; 56 | } 57 | /*printmatrix(rnd_mat, d, d); 58 | printmatrix(res_golden, d, 1); 59 | printmatrix(res_ret, d, 1);*/ 60 | deallocGEMMContext(ctx); 61 | delete [] rnd_mat; 62 | delete [] res_golden; 63 | delete [] res_ret; 64 | numConfigs++; 65 | VERBOSE_TEST(cout << "Bits = " << b << " dim = " << d << " result = " << res << endl); 66 | } 67 | } 68 | } 69 | cout << "Row-wise sum tests: " << ok << " OK, " << nok << " NOK" << endl; 70 | return ok == numConfigs; 71 | } 72 | 73 | bool test_conversions() { 74 | vector param_bits {1, 2, 3, 7}; 75 | vector param_dims {4, 16, 17, 32, 77, 100, 1023}; 76 | vector param_signed {1, 0}; 77 | unsigned int numConfigs = 0, ok = 0, nok = 0; 78 | 79 | for(auto & b: param_bits) { 80 | for(auto & d: param_dims) { 81 | for(auto & sgnd: param_signed) { 82 | int8_t * res_chk = new int8_t[d*d]; 83 | int8_t * rnd_vec = new int8_t[d*d]; 84 | assert(res_chk != 0 && rnd_vec != 0); 85 | generateRandomVector(b, d*d, rnd_vec, (bool) sgnd); 86 | 87 | BitSerialMatrix bsm = BitSerialMatrix::alloc(b, d, d, (bool) sgnd); 88 | bsm.importRegular(rnd_vec); 89 | bsm.exportRegular(res_chk); 90 | //printmatrix(rnd_vec, d, d); 91 | //printmatrix(res_chk, d, d); 92 | int res = memcmp(rnd_vec, res_chk, d); 93 | if(res == 0) { 94 | ok++; 95 | } else { 96 | nok++; 97 | } 98 | delete [] rnd_vec; 99 | delete [] res_chk; 100 | BitSerialMatrix::dealloc(bsm); 101 | numConfigs++; 102 | VERBOSE_TEST(cout << "Bits = " << b << " dim = " << d << " result = " << res << endl); 103 | } 104 | } 105 | } 106 | cout << "Conversion tests: " << ok << " OK, " << nok << " NOK" << endl; 107 | return ok == numConfigs; 108 | } 109 | 110 | bool test_matrix_matrix() { 111 | vector param_bits {2, 3, 4}; 112 | vector param_dims {3, 5, 7, 16, 17, 18, 30, 31, 32, 100, 177, 256}; 113 | vector do_matrix_vector {0, 1}; 114 | 115 | unsigned int numConfigs = 0, ok = 0, nok = 0; 116 | for(auto & b: param_bits) { 117 | for(auto & d: param_dims) { 118 | for(auto & mv: do_matrix_vector) { 119 | uint8_t * rnd_mat_a = new uint8_t[d*d*2]; 120 | uint8_t * rnd_mat_b = new uint8_t[2*d*(mv ? 1 : d*3)]; 121 | int32_t * res_mat_golden = new int32_t[d*(mv ? 1 : d*3)]; 122 | generateRandomVector(b, d*d*2, rnd_mat_a); 123 | generateRandomVector(b, 2*d*(mv ? 1 : d*3), rnd_mat_b); 124 | naive_int_gemm(rnd_mat_a, rnd_mat_b, res_mat_golden, d, 2*d, (mv ? 1 : d*3)); 125 | GEMMContext ctx = allocGEMMContext(d, 2*d, (mv ? 
1 : d*3), b, b, false, false); 126 | ctx.lhs.importRegular(rnd_mat_a); 127 | ctx.rhs.importRegular(rnd_mat_b); 128 | 129 | gemmBitSerial(ctx); 130 | //ctx.printSummary(); 131 | //printmatrix(rnd_mat_a, d, d*2); 132 | //printmatrix(rnd_mat_b, d*3, d*2); 133 | //printmatrix(res_mat_golden, d*3, d); 134 | //printmatrix(ctx.res, d*3, d); 135 | 136 | int rbytes = d*(mv ? 1 : d*3)*sizeof(int32_t); 137 | int res = memcmp(ctx.res, res_mat_golden, rbytes); 138 | if(res == 0) { 139 | ok++; 140 | } else { 141 | nok++; 142 | //printmatrixdiff(res_mat, res_mat_golden, 3*d, d); 143 | } 144 | delete [] rnd_mat_a; 145 | delete [] rnd_mat_b; 146 | delete [] res_mat_golden; 147 | deallocGEMMContext(ctx); 148 | numConfigs++; 149 | VERBOSE_TEST(cout << "Bits = " << b << " dim = " << d << " result = " << res << endl); 150 | } 151 | } 152 | } 153 | cout << "Matrix matrix multiplication tests: " << ok << " OK, " << nok << " NOK" << endl; 154 | return ok == numConfigs; 155 | } 156 | 157 | bool test_mnist() { 158 | // test bit serial gemm using real-life matrix data from a MNIST neural net 159 | GEMMContext ctx = allocGEMMContext( 160 | MNIST_OUT, MNIST_IN, 1, MNIST_WBITS, MNIST_ABITS, MNIST_WSIGN, MNIST_ASIGN 161 | ); 162 | ctx.lhs.importRegular(mnist_weights); 163 | ctx.rhs.importRegular(mnist_in); 164 | gemmBitSerial(ctx); 165 | int res = memcmp(ctx.res, mnist_res_golden, MNIST_OUT*sizeof(int32_t)); 166 | cout << "MNIST matrix-vector using bipolar times regular: " << (res == 0 ? "OK" : "NOK") << endl; 167 | if(res != 0) { 168 | printmatrixdiff(ctx.res, mnist_res_golden, 1, MNIST_OUT); 169 | } 170 | deallocGEMMContext(ctx); 171 | return res == 0; 172 | } 173 | 174 | bool test_bipolar_times_regular() { 175 | vector param_regularmatrix_bits {2, 3, 4}; 176 | vector param_dims {3, 5, 7, 16, 17, 18, 30, 31, 32, 100, 177, 256}; 177 | vector param_signed {1, 0}; 178 | vector param_switch_lhsrhs {0, 1}; 179 | 180 | unsigned int numConfigs = 0, ok = 0, nok = 0; 181 | // TODO when bipolar times bipolar is covered, merge into matrix matrix 182 | for(auto & sw_lhsrhs: param_switch_lhsrhs) { 183 | for(auto & rhs_bits: param_regularmatrix_bits) { 184 | for(auto & d: param_dims) { 185 | for(auto & sgnd: param_signed) { 186 | const size_t lhs_bits = 1; 187 | const bool lhs_sign = true; 188 | const bool rhs_sign = (bool) sgnd; 189 | int8_t * bipolar_mat = new int8_t[d*d]; 190 | int8_t * regular_mat = new int8_t[d*d]; 191 | int32_t * res_golden = new int32_t[d*d]; 192 | int32_t * res_chk = new int32_t[d*d]; 193 | generateRandomVector_Bipolar(d*d, bipolar_mat); 194 | generateRandomVector(rhs_bits, d*d, regular_mat, rhs_sign); 195 | GEMMContext ctx; 196 | if(sw_lhsrhs) { 197 | ctx = allocGEMMContext( 198 | d, d, d, rhs_bits, lhs_bits, rhs_sign, lhs_sign 199 | ); 200 | ctx.rhs.importRegular(bipolar_mat); 201 | ctx.lhs.importRegular(regular_mat); 202 | naive_int_gemm(regular_mat, bipolar_mat, res_golden, d, d, d); 203 | } else { 204 | ctx = allocGEMMContext( 205 | d, d, d, lhs_bits, rhs_bits, lhs_sign, rhs_sign 206 | ); 207 | ctx.lhs.importRegular(bipolar_mat); 208 | ctx.rhs.importRegular(regular_mat); 209 | naive_int_gemm(bipolar_mat, regular_mat, res_golden, d, d, d); 210 | } 211 | gemmBitSerial(ctx); 212 | //printmatrix(bipolar_mat, d, d); 213 | //printmatrix(regular_mat, d, d); 214 | //printmatrix(res_golden, d, d); 215 | //printmatrix(ctx.res, d, d); 216 | int res = memcmp(res_golden, ctx.res, sizeof(int32_t)*d*d); 217 | if(res == 0) { 218 | ok++; 219 | } else { 220 | nok++; 221 | } 222 | numConfigs++; 223 | delete [] 
bipolar_mat; 224 | delete [] regular_mat; 225 | delete [] res_golden; 226 | delete [] res_chk; 227 | deallocGEMMContext(ctx); 228 | VERBOSE_TEST(cout << "Bits = " << rhs_bits << " dim = " << d << " result = " << res << endl); 229 | } 230 | } 231 | } 232 | } 233 | cout << "Bipolar times regular tests: " << ok << " OK, " << nok << " NOK" << endl; 234 | return ok == numConfigs; 235 | } 236 | 237 | bool test_bipolar_times_bipolar() { 238 | vector param_dims {3, 5, 7, 16, 17, 18, 30, 31, 32, 100, 177, 256}; 239 | vector do_matrix_vector {0, 1}; 240 | unsigned int numConfigs = 0, ok = 0, nok = 0; 241 | 242 | for(auto & d: param_dims) { 243 | for(auto & mv: do_matrix_vector) { 244 | int8_t * lhs_mat = new int8_t[d*d]; 245 | int8_t * rhs_mat = new int8_t[mv ? d : d*d]; 246 | int32_t * res_golden = new int32_t[mv ? d : d*d]; 247 | int32_t * res_chk = new int32_t[mv ? d : d*d]; 248 | GEMMContext ctx = allocGEMMContext( 249 | d, d, mv ? 1 : d, 1, 1, true, true 250 | ); 251 | generateRandomVector_Bipolar(d*d, lhs_mat); 252 | generateRandomVector_Bipolar(mv ? d : d*d, rhs_mat); 253 | ctx.lhs.importRegular(lhs_mat); 254 | ctx.rhs.importRegular(rhs_mat); 255 | gemmBitSerial(ctx); 256 | naive_int_gemm(lhs_mat, rhs_mat, res_golden, d, d, mv ? 1 : d); 257 | //printmatrix(lhs_mat, d, d); 258 | //printmatrix(rhs_mat, d, d); 259 | //printmatrix(res_golden, d, d); 260 | //printmatrix(ctx.res, d, d); 261 | int res = memcmp(res_golden, ctx.res, sizeof(int32_t)*d*(mv ? 1 : d)); 262 | if(res == 0) { 263 | ok++; 264 | } else { 265 | nok++; 266 | } 267 | numConfigs++; 268 | delete [] lhs_mat; 269 | delete [] rhs_mat; 270 | delete [] res_golden; 271 | delete [] res_chk; 272 | deallocGEMMContext(ctx); 273 | } 274 | } 275 | cout << "Bipolar times bipolar tests: " << ok << " OK, " << nok << " NOK" << endl; 276 | return ok == numConfigs; 277 | } 278 | 279 | int main(int argc, char const *argv[]) { 280 | srand(time(NULL)); 281 | bool all_ok = true; 282 | all_ok &= test_conversions(); 283 | all_ok &= test_rowwise_sum(); 284 | all_ok &= test_mnist(); 285 | all_ok &= test_matrix_matrix(); 286 | all_ok &= test_bipolar_times_regular(); 287 | all_ok &= test_bipolar_times_bipolar(); 288 | 289 | if(all_ok) { 290 | cout << "All tests completed successfully" << endl; 291 | } else { 292 | cout << "Some tests failed" << endl; 293 | } 294 | return 0; 295 | } 296 | -------------------------------------------------------------------------------- /arch-generic.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // generic (non-architecture-specific) implementations of gemmBitserial 3 | // and other related functions 4 | 5 | // Compute the row-wise sum of a bit-serial matrix 6 | static void sumRows_generic(BitSerialMatrix m, int32_t * row_sums) { 7 | const uint64_t nc = m.wordsPerRow(); 8 | 9 | for(uint64_t r = 0; r < m.nrows; r++) { 10 | int32_t row_acc = 0; 11 | if(m.isBipolar()) { 12 | uint64_t * rowptr = m.rowptr(0, r); 13 | for(uint64_t c = 0; c < nc; c++) { 14 | row_acc += __builtin_popcountll(rowptr[c]); 15 | } 16 | // account for -1s in the sum. how does this work? let p be the number of 17 | // +1 bits, and n be the number of -1 bits. we know that there are only 18 | // p+n bits in total, and we want to compute the sum as p-n. 
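// since every position holds either a +1 or a -1 bit, p + n = ncols and the popcount row_acc equals p;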
rewriting 19 | // -n in terms of the number of columns (bits), we get: 20 | row_sums[r] = 2 * row_acc - m.ncols; 21 | } else { 22 | for(uint64_t b = 0; b < m.nbits; b++) { 23 | uint64_t * rowptr = m.rowptr(b, r); 24 | int32_t bit_acc = 0; 25 | for(uint64_t c = 0; c < nc; c++) { 26 | bit_acc += __builtin_popcountll(rowptr[c]); 27 | } 28 | bit_acc = bit_acc << b; 29 | if(m.issigned && b == m.nbits - 1) { 30 | bit_acc = -bit_acc; 31 | } 32 | row_acc += bit_acc; 33 | } 34 | row_sums[r] = row_acc; 35 | } 36 | } 37 | } 38 | 39 | static void prepareAccumulators_generic(GEMMContext ctx) { 40 | // when bits = 1 and signed = true, we assume a matrix is bipolar, not using 41 | //{-1, 0} but instead {-1, +1} values. 42 | bool lhsBipolar = (ctx.lhs.nbits == 1) && ctx.lhs.issigned; 43 | bool rhsBipolar = (ctx.rhs.nbits == 1) && ctx.rhs.issigned; 44 | 45 | if(lhsBipolar ^ rhsBipolar) { 46 | BitSerialMatrix bipolarM = lhsBipolar ? ctx.lhs : ctx.rhs; 47 | BitSerialMatrix regularM = lhsBipolar ? ctx.rhs : ctx.lhs; 48 | // if only one matrix is bipolar, we'll need to do something special. 49 | // despite the bipolar matrix, we'll compute the result using {0,1} 50 | // (regular unsigned 1-bit) matrices as follows: 51 | // let x be a column vector, W a bipolar matrix, and B a binary matrix which 52 | // is identical to W except all -1s are represented as 0. 53 | // note that each element We in W can be rewritten as 2*Be-1 54 | // by initializing the result vector to the negative of sum of all elements 55 | // in x, we get the same result using B instead of W. 56 | // compute columnwise sum of the regular matrix with bit serial 57 | // TODO should this buffer be part of the GEMMContext? 58 | int32_t * rowwise_sum = new int32_t[regularM.nrows]; 59 | sumRows_generic(regularM, rowwise_sum); 60 | // initialize result matrix accumulators from sum 61 | for(auto res_row = 0; res_row < ctx.rhs.nrows; res_row++) { 62 | for(auto res_col = 0; res_col < ctx.lhs.nrows; res_col++) { 63 | if(lhsBipolar) { 64 | ctx.res[res_row * ctx.lhs.nrows + res_col] = -rowwise_sum[res_row]; 65 | 66 | } else { 67 | ctx.res[res_row * ctx.lhs.nrows + res_col] = -rowwise_sum[res_col]; 68 | } 69 | } 70 | } 71 | delete [] rowwise_sum; 72 | } else { 73 | // just initialize all result matrix accumulators to zero 74 | memset(ctx.res, 0, ctx.lhs.nrows*ctx.rhs.nrows*sizeof(int32_t)); 75 | } 76 | } 77 | 78 | static GEMMContext allocGEMMContext_generic( 79 | uint64_t lhsRows, uint64_t depth, uint64_t rhsRows, 80 | uint64_t lhsBits, uint64_t rhsBits, 81 | bool lhsSigned, bool rhsSigned 82 | ) { 83 | const uint64_t regblock_lhs = 2; 84 | const uint64_t regblock_d = 1; 85 | const uint64_t regblock_rhs = 2; 86 | const uint64_t cacheBits = 32*1024*8; 87 | 88 | if(rhsRows == 1) { 89 | // matrix-vector only needs depth alignment 90 | return allocGEMMContext_base( 91 | lhsRows, depth, rhsRows, lhsBits, rhsBits, lhsSigned, rhsSigned, 92 | 1, 4, 1, cacheBits 93 | ); 94 | } else { 95 | return allocGEMMContext_base( 96 | lhsRows, depth, rhsRows, lhsBits, rhsBits, lhsSigned, rhsSigned, 97 | regblock_lhs, regblock_d, regblock_rhs, cacheBits 98 | ); 99 | } 100 | }; 101 | 102 | 103 | /* Multiply a lhs_block x rhs_block chunk of the given matrices, starting at 104 | (bA, bBT) using 2x1x2 register tiling. For internal use. 
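   Each step of the inner loop loads two rows of A and two rows of BT and keeps the four AND-popcount partial dot products (a 2x2 accumulator tile) in registers, then scales them by alpha and adds them into CT.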
105 | */ 106 | inline void gemmBinary_generic_chunk_tile2x1x2( 107 | uint64_t * A, uint64_t * BT, int32_t * CT, 108 | int32_t alpha, 109 | uint64_t rowsA, uint64_t depth_words, uint64_t rowsBT, 110 | uint64_t bA, uint64_t bBT, 111 | uint64_t lhs_block, uint64_t rhs_block, 112 | uint64_t rowsA_orig, uint64_t rowsBT_orig) { 113 | const uint64_t Atile = 2, DepthTile = 1, BTtile = 2; 114 | const size_t num_acc = Atile*BTtile; 115 | 116 | for(uint64_t rBT = bBT; rBT < bBT + rhs_block; rBT += BTtile) { 117 | uint64_t * BTptr = &BT[rBT * depth_words]; 118 | for(uint64_t rA = bA; rA < bA + lhs_block; rA += Atile) { 119 | uint64_t * Aptr = &A[rA * depth_words]; 120 | int32_t acc[num_acc] = {0}; 121 | for(uint64_t d = 0; d < depth_words; d += DepthTile) { 122 | const uint64_t a0 = Aptr[d], a1 = Aptr[d + depth_words]; 123 | const uint64_t b0 = BTptr[d], b1 = BTptr[d + depth_words]; 124 | acc[0] += __builtin_popcountll(a0 & b0); 125 | acc[1] += __builtin_popcountll(a0 & b1); 126 | acc[2] += __builtin_popcountll(a1 & b0); 127 | acc[3] += __builtin_popcountll(a1 & b1); 128 | } 129 | for(uint64_t at = 0; at < Atile; at++) { 130 | for(uint64_t bt = 0; bt < BTtile; bt++) { 131 | if(((rBT + bt) < rowsBT_orig) && ((rA + at) < rowsA_orig)) { 132 | CT[(rBT + bt) * rowsA_orig + (rA + at)] += acc[at * BTtile + bt] * alpha; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | } 139 | 140 | /* CT = A * BT using cache blocking and 2x1x2 register blocking where possible. 141 | For internal use. 142 | */ 143 | static void gemmBinary_generic_L1_tile2x1x2( 144 | uint64_t * A, uint64_t * BT, int32_t * CT, int32_t alpha, 145 | uint64_t rowsA, uint64_t depth_words, uint64_t rowsBT, 146 | uint64_t rowsA_orig, uint64_t rowsBT_orig, 147 | uint64_t lhsBlock, uint64_t rhsBlock 148 | ) { 149 | const uint64_t Atile = 2, DepthTile = 1, BTtile = 2; 150 | assert(rowsBT % rhsBlock == 0); 151 | assert(rowsA % lhsBlock == 0); 152 | assert(lhsBlock % Atile == 0); 153 | assert(rhsBlock % BTtile == 0); 154 | 155 | for(uint64_t bBT = 0; bBT < rowsBT; bBT += rhsBlock) { 156 | for(uint64_t bA = 0; bA < rowsA; bA += lhsBlock) { 157 | gemmBinary_generic_chunk_tile2x1x2( 158 | A, BT, CT, alpha, rowsA, depth_words, rowsBT, bA, bBT, 159 | lhsBlock, rhsBlock, rowsA_orig, rowsBT_orig 160 | ); 161 | } 162 | } 163 | } 164 | 165 | /* Bit-serial GEMM via a series of calls to gemmBinary. 166 | Note that rhs must be given in transposed form, and the result is also 167 | produced transposed. 168 | */ 169 | static void gemmBitSerial_generic_usingBinary(GEMMContext ctx) { 170 | // ensure that matrix shapes are compatible 171 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 172 | const uint64_t lhsbits = ctx.lhs.nbits; 173 | const uint64_t rhsbits = ctx.rhs.nbits; 174 | 175 | prepareAccumulators_generic(ctx); 176 | // call binary GEMM for each bit position 177 | // note that bipolars don't count as negative, we do those with {0, 1} as a 178 | // special case 179 | for(uint64_t lbit = 0; lbit < lhsbits; lbit++) { 180 | bool neg_lhs = ctx.lhs.issigned && !ctx.lhs.isBipolar() && (lbit == lhsbits-1); 181 | for(uint64_t rbit = 0; rbit < rhsbits; rbit++) { 182 | bool neg_rhs = ctx.rhs.issigned && !ctx.rhs.isBipolar() && (rbit == rhsbits-1); 183 | bool neg = neg_rhs ^ neg_lhs; 184 | int32_t alpha = neg ? -(1 << (lbit+rbit)) : (1 << (lbit+rbit)); 185 | alpha = ctx.isBipolarTimesRegular() ? 
2*alpha : alpha; 186 | gemmBinary_generic_L1_tile2x1x2( 187 | ctx.lhs.bitplaneptr(lbit), ctx.rhs.bitplaneptr(rbit), ctx.res, alpha, 188 | ctx.lhs.nrows_a, ctx.lhs.wordsPerRow(), ctx.rhs.nrows_a, 189 | ctx.lhs.nrows, ctx.rhs.nrows, ctx.lhsBlock, ctx.rhsBlock 190 | ); 191 | } 192 | } 193 | } 194 | 195 | 196 | /* Standalone bit-serial GEMM. Note that rhs must be given in transposed 197 | form, and the result is also produced transposed. 198 | */ 199 | static void gemmBitSerial_generic_naive(GEMMContext ctx) { 200 | // ensure that matrix shapes are compatible 201 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 202 | const uint64_t lhsbits = ctx.lhs.nbits; 203 | const uint64_t rhsbits = ctx.rhs.nbits; 204 | const uint64_t out_rows = ctx.lhs.nrows; 205 | const uint64_t out_cols = ctx.rhs.nrows; 206 | const uint64_t depth = ctx.lhs.wordsPerRow(); 207 | prepareAccumulators_generic(ctx); 208 | for(uint64_t i = 0; i < out_cols; i++) { 209 | for(uint64_t j = 0; j < out_rows; j++) { 210 | int32_t rowres = 0; 211 | for(uint64_t lbit = 0; lbit < lhsbits; lbit++) { 212 | bool neg_lhs = ctx.lhs.issigned && !ctx.lhs.isBipolar() && (lbit == lhsbits-1); 213 | for(uint64_t rbit = 0; rbit < rhsbits; rbit++) { 214 | bool neg_rhs = ctx.rhs.issigned && !ctx.rhs.isBipolar() && (rbit == rhsbits-1); 215 | uint64_t * ldata = ctx.lhs.rowptr(lbit, j); 216 | uint64_t * rdata = ctx.rhs.rowptr(rbit, i); 217 | uint64_t andcard = 0; 218 | // AND-popcount-accumulate over row pair 219 | for(uint64_t k = 0; k < depth; k++) { 220 | andcard += __builtin_popcountll(ldata[k] & rdata[k]); 221 | } 222 | // scale 223 | uint64_t bpreg_scale = ctx.isBipolarTimesRegular() ? 1 : 0; 224 | andcard = andcard << (lbit + rbit + bpreg_scale); 225 | // negate if needed 226 | rowres += (neg_lhs ^ neg_rhs) ? -andcard : andcard; 227 | } 228 | } 229 | ctx.res[i * ctx.lhs.nrows + j] += rowres; 230 | } 231 | } 232 | } 233 | 234 | // Special case: bipolar times bipolar matrix multiplication. These use 235 | // XNOR-popcount instead of AND-popcount, and also need an additional correction 236 | // step to account for zeroes being treated as -1 bits 237 | 238 | // naive implementation for bipolar GEMM 239 | static void gemmBipolar_generic_naive(GEMMContext ctx) { 240 | // ensure that matrix shapes are compatible 241 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 242 | assert(ctx.lhs.isBipolar() && ctx.rhs.isBipolar()); 243 | const uint64_t out_rows = ctx.lhs.nrows; 244 | const uint64_t out_cols = ctx.rhs.nrows; 245 | const uint64_t depth = ctx.lhs.wordsPerRow(); 246 | prepareAccumulators_generic(ctx); 247 | for(uint64_t i = 0; i < out_cols; i++) { 248 | for(uint64_t j = 0; j < out_rows; j++) { 249 | int32_t rowres = 0; 250 | uint64_t * ldata = ctx.lhs.rowptr(0, j); 251 | uint64_t * rdata = ctx.rhs.rowptr(0, i); 252 | // XNOR-popcount-accumulate over row pair. note that we do XOR-popcount 253 | // to save one instruction (no need to invert the XOR result). this is 254 | // accounted for in the correction afterwards. 255 | for(uint64_t k = 0; k < depth; k++) { 256 | rowres += __builtin_popcountll(ldata[k] ^ rdata[k]); 257 | } 258 | // correction for sum of 1 and -1 bits 259 | ctx.res[i * ctx.lhs.nrows + j] += -2 * rowres + ctx.lhs.ncols; 260 | } 261 | } 262 | } 263 | 264 | /* Standalone bit-serial GEMV (matrix-vector). Note that rhs must be given in transposed 265 | form, and the result is also produced transposed. 
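   Here rhs holds a single (transposed) row; the depth loop is unrolled by four 64-bit words, which relies on the extra depth alignment that allocGEMMContext_generic requests for the matrix-vector case.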
266 | */ 267 | static void gemvBitSerial_generic(GEMMContext ctx) { 268 | // ensure that matrix shapes are compatible 269 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 270 | const uint64_t lhsbits = ctx.lhs.nbits; 271 | const uint64_t rhsbits = ctx.rhs.nbits; 272 | const uint64_t out_rows = ctx.lhs.nrows; 273 | const uint64_t depth = ctx.lhs.wordsPerRow(); 274 | uint64_t bpreg_scale = ctx.isBipolarTimesRegular() ? 1 : 0; 275 | prepareAccumulators_generic(ctx); 276 | for(uint64_t j = 0; j < out_rows; j++) { 277 | int32_t rowres = 0; 278 | for(uint64_t lbit = 0; lbit < lhsbits; lbit++) { 279 | bool neg_lhs = ctx.lhs.issigned && !ctx.lhs.isBipolar() && (lbit == lhsbits-1); 280 | for(uint64_t rbit = 0; rbit < rhsbits; rbit++) { 281 | bool neg_rhs = ctx.rhs.issigned && !ctx.rhs.isBipolar() && (rbit == rhsbits-1); 282 | uint64_t * ldata = ctx.lhs.rowptr(lbit, j); 283 | uint64_t * rdata = ctx.rhs.rowptr(rbit, 0); 284 | uint64_t andcard = 0; 285 | // AND-popcount-accumulate over row pair 286 | for(uint64_t k = 0; k < depth; k+=4) { 287 | andcard += __builtin_popcountll(ldata[k] & rdata[k]); 288 | andcard += __builtin_popcountll(ldata[k+1] & rdata[k+1]); 289 | andcard += __builtin_popcountll(ldata[k+2] & rdata[k+2]); 290 | andcard += __builtin_popcountll(ldata[k+3] & rdata[k+3]); 291 | } 292 | // scale 293 | andcard = andcard << (lbit + rbit + bpreg_scale); 294 | // negate if needed 295 | rowres += (neg_lhs ^ neg_rhs) ? -andcard : andcard; 296 | } 297 | } 298 | ctx.res[j] += rowres; 299 | } 300 | } 301 | 302 | // Special case: bipolar times bipolar matrix vector multiplication. These use 303 | // XNOR-popcount instead of AND-popcount, and also need an additional correction 304 | // step to account for zeroes being treated as -1 bits 305 | 306 | static void gemvBipolar_generic(GEMMContext ctx) { 307 | // ensure that matrix shapes are compatible 308 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 309 | assert(ctx.lhs.isBipolar() && ctx.rhs.isBipolar()); 310 | const uint64_t out_rows = ctx.lhs.nrows; 311 | const uint64_t depth = ctx.lhs.wordsPerRow(); 312 | prepareAccumulators_generic(ctx); 313 | for(uint64_t j = 0; j < out_rows; j++) { 314 | int32_t rowres = 0; 315 | uint64_t * ldata = ctx.lhs.rowptr(0, j); 316 | uint64_t * rdata = ctx.rhs.rowptr(0, 0); 317 | // XNOR-popcount-accumulate over row pair. note that we do XOR-popcount 318 | // to save one instruction (no need to invert the XOR result). this is 319 | // accounted for in the correction afterwards. 
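// for bipolar rows the dot product is (#equal bit pairs) - (#differing bit pairs) = ncols - 2*popcount(lhs ^ rhs), which is exactly the correction applied after the loop.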
320 | for(uint64_t k = 0; k < depth; k+=4) { 321 | rowres += __builtin_popcountll(ldata[k] ^ rdata[k]); 322 | rowres += __builtin_popcountll(ldata[k+1] ^ rdata[k+1]); 323 | rowres += __builtin_popcountll(ldata[k+2] ^ rdata[k+2]); 324 | rowres += __builtin_popcountll(ldata[k+3] ^ rdata[k+3]); 325 | } 326 | // correction for sum of 1 and -1 bits 327 | ctx.res[j] += -2 * rowres + ctx.lhs.ncols; 328 | } 329 | } 330 | 331 | static void gemmBitSerial_generic(GEMMContext ctx) { 332 | if(ctx.isMatrixVector()) { 333 | if(ctx.isBipolarTimesBipolar()) { 334 | gemvBipolar_generic(ctx); 335 | } else { 336 | gemvBitSerial_generic(ctx); 337 | } 338 | } else { 339 | if(ctx.isBipolarTimesBipolar()) { 340 | gemmBipolar_generic_naive(ctx); 341 | } else { 342 | gemmBitSerial_generic_usingBinary(ctx); 343 | } 344 | } 345 | } 346 | -------------------------------------------------------------------------------- /arch-neon.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // architecture-specific implementations of primitives for ARM NEON 4 | // the tiling strategies are named as: 5 | // _tile 6 | 7 | static inline uint64_t popcount_neon(uint64_t * rowptr, uint64_t numElems) { 8 | uint64_t ret = 0; 9 | const uint64_t DepthTile = 2; 10 | uint8x16_t acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 11 | uint64x2_t acc2_neon = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); 12 | for(uint64_t c = 0; c < numElems; c += DepthTile) { 13 | uint8x16_t a0 = vld1q_u8((uint8_t *) &rowptr[c]); 14 | acc_neon = vaddq_u8(acc_neon, vcntq_u8(a0)); 15 | if((c & 7L) == 7L) { 16 | // hsum over 8-bit accumulators when end or overflow 17 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 18 | acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 19 | } 20 | } 21 | // move into regular accumulators 22 | uint64_t tmp[2]; 23 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 24 | vst1q_u64(tmp, acc2_neon); 25 | ret = (tmp[0] + tmp[1]); 26 | return ret; 27 | } 28 | 29 | static inline uint64_t xor_popcount_neon(uint64_t * rowptrA, uint64_t * rowptrB, uint64_t numElems) { 30 | uint64_t ret = 0; 31 | const uint64_t DepthTile = 2; 32 | uint8x16_t acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 33 | uint64x2_t acc2_neon = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); 34 | for(uint64_t c = 0; c < numElems; c += DepthTile) { 35 | uint8x16_t a0 = vld1q_u8((uint8_t *) &rowptrA[c]); 36 | uint8x16_t b0 = vld1q_u8((uint8_t *) &rowptrB[c]); 37 | acc_neon = vaddq_u8(acc_neon, vcntq_u8(veorq_u8(a0, b0))); 38 | if((c & 7L) == 7L) { 39 | // hsum over 8-bit accumulators when end or overflow 40 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 41 | acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 42 | } 43 | } 44 | // move into regular accumulators 45 | uint64_t tmp[2]; 46 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 47 | vst1q_u64(tmp, acc2_neon); 48 | ret = (tmp[0] + tmp[1]); 49 | return ret; 50 | } 51 | 52 | static inline uint64_t and_popcount_neon(uint64_t * rowptrA, uint64_t * rowptrB, uint64_t numElems) { 53 | uint64_t ret = 0; 54 | const uint64_t DepthTile = 2; 55 | uint8x16_t acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 56 | uint64x2_t acc2_neon = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); 57 | for(uint64_t c = 0; c < numElems; c += DepthTile) { 58 | uint8x16_t a0 = vld1q_u8((uint8_t *) &rowptrA[c]); 59 | uint8x16_t b0 = 
vld1q_u8((uint8_t *) &rowptrB[c]); 60 | acc_neon = vaddq_u8(acc_neon, vcntq_u8(vandq_u8(a0, b0))); 61 | if((c & 7L) == 7L) { 62 | // hsum over 8-bit accumulators when end or overflow 63 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 64 | acc_neon = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 65 | } 66 | } 67 | // move into regular accumulators 68 | uint64_t tmp[2]; 69 | acc2_neon = vaddq_u64(acc2_neon, vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon)))); 70 | vst1q_u64(tmp, acc2_neon); 71 | ret = (tmp[0] + tmp[1]); 72 | return ret; 73 | } 74 | 75 | // Compute the row-wise sum of a bit-serial matrix 76 | static void sumRows_neon(BitSerialMatrix m, int32_t * row_sums) { 77 | const uint64_t nc = m.wordsPerRow(); 78 | 79 | for(uint64_t r = 0; r < m.nrows; r++) { 80 | int32_t row_acc = 0; 81 | if(m.isBipolar()) { 82 | uint64_t * rowptr = m.rowptr(0, r); 83 | row_acc += popcount_neon(rowptr, nc); 84 | // account for -1s in the sum. how does this work? let p be the number of 85 | // +1 bits, and n be the number of -1 bits. we know that there are only 86 | // p+n bits in total, and we want to compute the sum as p-n. rewriting 87 | // -n in terms of the number of columns (bits), we get: 88 | row_sums[r] = 2 * row_acc - m.ncols; 89 | } else { 90 | for(uint64_t b = 0; b < m.nbits; b++) { 91 | uint64_t * rowptr = m.rowptr(b, r); 92 | int32_t bit_acc = (int32_t) popcount_neon(rowptr, nc); 93 | // scale for weight and handle sign bit 94 | bit_acc = bit_acc << b; 95 | if(m.issigned && b == m.nbits - 1) { 96 | bit_acc = -bit_acc; 97 | } 98 | row_acc += bit_acc; 99 | } 100 | row_sums[r] = row_acc; 101 | } 102 | } 103 | } 104 | 105 | static void prepareAccumulators_neon(GEMMContext ctx) { 106 | // when bits = 1 and signed = true, we assume a matrix is bipolar, not using 107 | //{-1, 0} but instead {-1, +1} values. 108 | bool lhsBipolar = (ctx.lhs.nbits == 1) && ctx.lhs.issigned; 109 | bool rhsBipolar = (ctx.rhs.nbits == 1) && ctx.rhs.issigned; 110 | 111 | if(lhsBipolar ^ rhsBipolar) { 112 | BitSerialMatrix bipolarM = lhsBipolar ? ctx.lhs : ctx.rhs; 113 | BitSerialMatrix regularM = lhsBipolar ? ctx.rhs : ctx.lhs; 114 | // if only one matrix is bipolar, we'll need to do something special. 115 | // despite the bipolar matrix, we'll compute the result using {0,1} 116 | // (regular unsigned 1-bit) matrices as follows: 117 | // let x be a column vector, W a bipolar matrix, and B a binary matrix which 118 | // is identical to W except all -1s are represented as 0. 119 | // note that each element We in W can be rewritten as 2*Be-1 120 | // by initializing the result vector to the negative of sum of all elements 121 | // in x, we get the same result using B instead of W. 122 | // compute columnwise sum of the regular matrix with bit serial 123 | // TODO should this buffer be part of the GEMMContext? 
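// elementwise, We * xe = (2*Be - 1) * xe, so each output equals 2*(binary dot product) - sum(x); the -sum(x) term is what we place in the accumulators here, and the factor of 2 is folded into alpha by the caller.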
124 | int32_t * rowwise_sum = new int32_t[regularM.nrows]; 125 | sumRows_neon(regularM, rowwise_sum); 126 | // initialize result matrix accumulators from sum 127 | for(auto res_row = 0; res_row < ctx.rhs.nrows; res_row++) { 128 | for(auto res_col = 0; res_col < ctx.lhs.nrows; res_col++) { 129 | if(lhsBipolar) { 130 | ctx.res[res_row * ctx.lhs.nrows + res_col] = -rowwise_sum[res_row]; 131 | 132 | } else { 133 | ctx.res[res_row * ctx.lhs.nrows + res_col] = -rowwise_sum[res_col]; 134 | } 135 | } 136 | } 137 | delete [] rowwise_sum; 138 | } else { 139 | // just initialize all result matrix accumulators to zero 140 | memset(ctx.res, 0, ctx.lhs.nrows*ctx.rhs.nrows*sizeof(int32_t)); 141 | } 142 | } 143 | 144 | static GEMMContext allocGEMMContext_neon( 145 | uint64_t lhsRows, uint64_t depth, uint64_t rhsRows, 146 | uint64_t lhsBits, uint64_t rhsBits, 147 | bool lhsSigned, bool rhsSigned 148 | ) { 149 | const uint64_t regblock_lhs = 4; 150 | const uint64_t regblock_d = 2; 151 | const uint64_t regblock_rhs = 2; 152 | const uint64_t cacheBits = 16*1024*8; 153 | 154 | if(rhsRows == 1) { 155 | // matrix-vector only needs depth alignment 156 | return allocGEMMContext_base( 157 | lhsRows, depth, rhsRows, lhsBits, rhsBits, lhsSigned, rhsSigned, 158 | 1, regblock_d, 1, cacheBits 159 | ); 160 | } else { 161 | return allocGEMMContext_base( 162 | lhsRows, depth, rhsRows, lhsBits, rhsBits, lhsSigned, rhsSigned, 163 | regblock_lhs, regblock_d, regblock_rhs, cacheBits 164 | ); 165 | } 166 | }; 167 | 168 | /* CT = A * BT using cache blocking and 2x1x2 register blocking where possible. 169 | For internal use. 170 | */ 171 | static void gemmBinary_neon_L1_tile4x2x2( 172 | uint64_t * A, uint64_t * BT, int32_t * CT, int32_t alpha, 173 | uint64_t rowsA, uint64_t depth_words, uint64_t rowsBT, 174 | uint64_t rowsA_orig, uint64_t rowsBT_orig, 175 | uint64_t lhsBlock, uint64_t rhsBlock) { 176 | const uint64_t Atile = 4, DepthTile = 2, BTtile = 2; 177 | assert(rowsBT % rhsBlock == 0); 178 | assert(rowsA % lhsBlock == 0); 179 | assert(lhsBlock % Atile == 0); 180 | assert(rhsBlock % BTtile == 0); 181 | 182 | for(uint64_t bBT = 0; bBT < rowsBT; bBT += rhsBlock) { 183 | for(uint64_t bA = 0; bA < rowsA; bA += lhsBlock) { 184 | const size_t num_acc = Atile*BTtile; 185 | // start of cache block 186 | for(uint64_t rBT = bBT; rBT < bBT + rhsBlock; rBT += BTtile) { 187 | uint64_t * BTptr = &BT[rBT * depth_words]; 188 | for(uint64_t rA = bA; rA < bA + lhsBlock; rA += Atile) { 189 | uint64_t * Aptr = &A[rA * depth_words]; 190 | uint64_t acc[num_acc] = {0}; 191 | uint8x16_t acc_neon[num_acc]; 192 | // TODO could keep things in 16-bit accumulators for perf. 
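// acc_neon holds per-byte popcount partial sums; they are folded into the 64-bit lanes of acc2_neon by the vpaddlq horizontal-sum steps below.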
193 | uint64x2_t acc2_neon[num_acc]; 194 | // initialize accumulators to zero 195 | for(size_t init = 0; init < num_acc; init++) { 196 | acc_neon[init] = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 197 | acc2_neon[init] = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); 198 | } 199 | 200 | for(uint64_t d = 0; d < depth_words; d += DepthTile) { 201 | uint8x16_t a0, b0, a1, b1, a2, a3; 202 | 203 | a0 = vld1q_u8((uint8_t *) &Aptr[d + 0*depth_words]); 204 | a1 = vld1q_u8((uint8_t *) &Aptr[d + 1*depth_words]); 205 | a2 = vld1q_u8((uint8_t *) &Aptr[d + 2*depth_words]); 206 | a3 = vld1q_u8((uint8_t *) &Aptr[d + 3*depth_words]); 207 | b0 = vld1q_u8((uint8_t *) &BTptr[d]); 208 | b1 = vld1q_u8((uint8_t *) &BTptr[d + depth_words]); 209 | 210 | acc_neon[0] = vaddq_u8(acc_neon[0], vcntq_u8(vandq_u8(a0, b0))); 211 | acc_neon[1] = vaddq_u8(acc_neon[1], vcntq_u8(vandq_u8(a0, b1))); 212 | acc_neon[2] = vaddq_u8(acc_neon[2], vcntq_u8(vandq_u8(a1, b0))); 213 | acc_neon[3] = vaddq_u8(acc_neon[3], vcntq_u8(vandq_u8(a1, b1))); 214 | acc_neon[4] = vaddq_u8(acc_neon[4], vcntq_u8(vandq_u8(a2, b0))); 215 | acc_neon[5] = vaddq_u8(acc_neon[5], vcntq_u8(vandq_u8(a2, b1))); 216 | acc_neon[6] = vaddq_u8(acc_neon[6], vcntq_u8(vandq_u8(a3, b0))); 217 | acc_neon[7] = vaddq_u8(acc_neon[7], vcntq_u8(vandq_u8(a3, b1))); 218 | 219 | if((d & 7L) == 7L) { 220 | // hsum over 8-bit accumulators when end or overflow 221 | for(size_t init = 0; init < num_acc; init++) { 222 | acc2_neon[init] = vaddq_u64(acc2_neon[init], vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon[init])))); 223 | acc_neon[init] = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 224 | } 225 | } 226 | } 227 | /* move into regular accumulators */ 228 | for(size_t init = 0; init < num_acc; init++) { 229 | uint64_t tmp[2]; 230 | acc2_neon[init] = vaddq_u64(acc2_neon[init], vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc_neon[init])))); 231 | vst1q_u64(tmp, acc2_neon[init]); 232 | acc[init] = tmp[0] + tmp[1]; 233 | } 234 | for(uint64_t at = 0; at < Atile; at++) { 235 | for(uint64_t bt = 0; bt < BTtile; bt++) { 236 | if(((rBT + bt) < rowsBT_orig) && ((rA + at) < rowsA_orig)) { 237 | CT[(rBT + bt) * rowsA_orig + (rA + at)] += acc[at * BTtile + bt] * alpha; 238 | } 239 | } 240 | } 241 | } 242 | } 243 | // end of cache block 244 | } 245 | } 246 | } 247 | 248 | /* Bit-serial GEMM via a series of calls to gemmBinary. 249 | Note that rhs must be given in transposed form, and the result is also 250 | produced transposed. 251 | */ 252 | static void gemmBitSerial_neon_usingBinary(GEMMContext ctx) { 253 | // ensure that matrix shapes are compatible 254 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 255 | const uint64_t lhsbits = ctx.lhs.nbits; 256 | const uint64_t rhsbits = ctx.rhs.nbits; 257 | prepareAccumulators_neon(ctx); 258 | // call binary GEMM for each bit position 259 | for(uint64_t lbit = 0; lbit < lhsbits; lbit++) { 260 | bool neg_lhs = ctx.lhs.issigned && !ctx.lhs.isBipolar() && (lbit == lhsbits-1); 261 | for(uint64_t rbit = 0; rbit < rhsbits; rbit++) { 262 | bool neg_rhs = ctx.rhs.issigned && !ctx.rhs.isBipolar() && (rbit == rhsbits-1); 263 | bool neg = neg_rhs ^ neg_lhs; 264 | int32_t alpha = neg ? -(1 << (lbit+rbit)) : (1 << (lbit+rbit)); 265 | alpha = ctx.isBipolarTimesRegular() ? 
2*alpha : alpha; 266 | gemmBinary_neon_L1_tile4x2x2( 267 | ctx.lhs.bitplaneptr(lbit), ctx.rhs.bitplaneptr(rbit), ctx.res, alpha, 268 | ctx.lhs.nrows_a, ctx.lhs.wordsPerRow(), ctx.rhs.nrows_a, 269 | ctx.lhs.nrows, ctx.rhs.nrows, ctx.lhsBlock, ctx.rhsBlock 270 | ); 271 | } 272 | } 273 | } 274 | 275 | // naive implementation for bipolar GEMM 276 | static void gemmBipolar_neon_naive(GEMMContext ctx) { 277 | // ensure that matrix shapes are compatible 278 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 279 | assert(ctx.lhs.isBipolar() && ctx.rhs.isBipolar()); 280 | const uint64_t out_rows = ctx.lhs.nrows; 281 | const uint64_t out_cols = ctx.rhs.nrows; 282 | const uint64_t depth = ctx.lhs.wordsPerRow(); 283 | prepareAccumulators_generic(ctx); 284 | for(uint64_t i = 0; i < out_cols; i++) { 285 | for(uint64_t j = 0; j < out_rows; j++) { 286 | uint64_t * ldata = ctx.lhs.rowptr(0, j); 287 | uint64_t * rdata = ctx.rhs.rowptr(0, i); 288 | // XNOR-popcount-accumulate over row pair. note that we do XOR-popcount 289 | // to save one instruction (no need to invert the XOR result). this is 290 | // accounted for in the correction afterwards. 291 | int32_t rowres = (int32_t) xor_popcount_neon(ldata, rdata, depth); 292 | // correction for sum of 1 and -1 bits 293 | ctx.res[i * ctx.lhs.nrows + j] += -2 * rowres + ctx.lhs.ncols; 294 | } 295 | } 296 | } 297 | 298 | // neon bipolar matrix times vector (GEMV) 299 | static void gemvBipolar_neon(GEMMContext ctx) { 300 | // ensure that matrix shapes are compatible 301 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 302 | assert(ctx.lhs.isBipolar() && ctx.rhs.isBipolar()); 303 | const uint64_t out_rows = ctx.lhs.nrows; 304 | const uint64_t depth = ctx.lhs.wordsPerRow(); 305 | prepareAccumulators_generic(ctx); 306 | for(uint64_t j = 0; j < out_rows; j++) { 307 | uint64_t * ldata = ctx.lhs.rowptr(0, j); 308 | uint64_t * rdata = ctx.rhs.rowptr(0, 0); 309 | // XNOR-popcount-accumulate over row pair. note that we do XOR-popcount 310 | // to save one instruction (no need to invert the XOR result). this is 311 | // accounted for in the correction afterwards. 312 | int32_t rowres = (int32_t) xor_popcount_neon(ldata, rdata, depth); 313 | // correction for sum of 1 and -1 bits 314 | ctx.res[j] += -2 * rowres + ctx.lhs.ncols; 315 | } 316 | } 317 | 318 | // neon bit serial matrix times vector (GEMV) 319 | static void gemvBitSerial_neon(GEMMContext ctx) { 320 | // ensure that matrix shapes are compatible 321 | assert(ctx.lhs.ncols == ctx.rhs.ncols); 322 | const uint64_t lhsbits = ctx.lhs.nbits; 323 | const uint64_t rhsbits = ctx.rhs.nbits; 324 | const uint64_t out_rows = ctx.lhs.nrows; 325 | const uint64_t depth = ctx.lhs.wordsPerRow(); 326 | uint64_t bpreg_scale = ctx.isBipolarTimesRegular() ? 1 : 0; 327 | prepareAccumulators_generic(ctx); 328 | for(uint64_t j = 0; j < out_rows; j++) { 329 | int32_t rowres = 0; 330 | for(uint64_t lbit = 0; lbit < lhsbits; lbit++) { 331 | bool neg_lhs = ctx.lhs.issigned && !ctx.lhs.isBipolar() && (lbit == lhsbits-1); 332 | for(uint64_t rbit = 0; rbit < rhsbits; rbit++) { 333 | bool neg_rhs = ctx.rhs.issigned && !ctx.rhs.isBipolar() && (rbit == rhsbits-1); 334 | uint64_t * ldata = ctx.lhs.rowptr(lbit, j); 335 | uint64_t * rdata = ctx.rhs.rowptr(rbit, 0); 336 | uint64_t andcard = (int32_t) and_popcount_neon(ldata, rdata, depth); 337 | // scale 338 | andcard = andcard << (lbit + rbit + bpreg_scale); 339 | // negate if needed 340 | rowres += (neg_lhs ^ neg_rhs) ? 
-andcard : andcard; 341 | } 342 | } 343 | ctx.res[j] += rowres; 344 | } 345 | } 346 | 347 | static void gemmBitSerial_neon(GEMMContext ctx) { 348 | if(ctx.isMatrixVector()) { 349 | if(ctx.isBipolarTimesBipolar()) { 350 | gemvBipolar_neon(ctx); 351 | } else { 352 | gemvBitSerial_neon(ctx); 353 | } 354 | } else { 355 | if(ctx.isBipolarTimesBipolar()) { 356 | gemmBipolar_neon_naive(ctx); 357 | } else { 358 | gemmBitSerial_neon_usingBinary(ctx); 359 | } 360 | } 361 | } 362 | -------------------------------------------------------------------------------- /gemmbitserial.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cassert> 3 | #include <cstdint> 4 | #include <cstring> 5 | #include <cmath> 6 | #include <iostream> 7 | 8 | namespace gemmbitserial { 9 | 10 | // Utility function to round "in" up to the next multiple of "af" 11 | inline uint64_t alignTo(uint64_t in, uint64_t af) { 12 | if(in % af != 0) { 13 | return in + (af - (in % af)); 14 | } else { 15 | return in; 16 | } 17 | } 18 | 19 | class BitSerialMatrix { 20 | public: 21 | // static member functions for working with BitSerialMatrix 22 | 23 | /* Allocate buffer space for a BitSerialMatrix */ 24 | static BitSerialMatrix alloc(uint64_t nbits, uint64_t nrows, uint64_t ncols, bool issigned, uint64_t rowalign = 1, uint64_t colalign = 64) { 25 | BitSerialMatrix bsm; 26 | bsm.nbits = nbits; 27 | bsm.nrows = nrows; 28 | bsm.ncols = ncols; 29 | bsm.nrows_a = alignTo(nrows, rowalign); 30 | bsm.ncols_a = alignTo(ncols, colalign); 31 | bsm.issigned = issigned; 32 | uint64_t wordsPerBitplane = bsm.wordsPerBitplane(); 33 | bsm.data = new uint64_t[nbits * wordsPerBitplane]; 34 | return bsm; 35 | } 36 | 37 | /* Deallocate buffers for a BitSerialMatrix */ 38 | static void dealloc(BitSerialMatrix bsm) { 39 | delete [] bsm.data; 40 | } 41 | 42 | public: 43 | // actual member variables and functions of BitSerialMatrix instances 44 | bool issigned; // whether highest order bit pos is negative 45 | uint64_t nbits; // bits of precision 46 | uint64_t nrows; // number of real (actual) rows 47 | uint64_t ncols; // number of real (actual) columns 48 | uint64_t nrows_a; // number of allocated rows 49 | uint64_t ncols_a; // number of allocated columns 50 | uint64_t * data; // data buffer, layout [nbits][nrows_a][ncols_a/64] 51 | 52 | // print key statistics about BitSerialMatrix to stdout 53 | void printSummary() { 54 | std::cout << "BitSerialMatrix" << std::endl; 55 | std::cout << "Bits of precision: " << nbits << " signed: " << issigned << std::endl; 56 | std::cout << "Actual size: " << nrows << " x " << ncols << std::endl; 57 | std::cout << "Allocated size: " << nrows_a << " x " << ncols_a << std::endl; 58 | } 59 | 60 | void printHex() { 61 | for(int i = 0; i < nbits; i++) { 62 | std::cout << "Bit " << i << ":" << std::endl; 63 | for(int j = 0; j < nrows_a; j++) { 64 | for(int k = 0; k < ncols_a/64; k++) { 65 | std::cout << std::hex << word(i, j, k*64) << " " << std::dec; 66 | } 67 | std::cout << std::endl; 68 | } 69 | std::cout << std::endl; 70 | } 71 | } 72 | 73 | // return whether the matrix contains bipolar binary {-1, +1} values 74 | inline bool isBipolar() const { 75 | return nbits == 1 && issigned; 76 | } 77 | 78 | // number of storage words needed for each row 79 | inline uint64_t wordsPerRow() const { 80 | const uint64_t bitsPerWord = sizeof(uint64_t) * 8; 81 | return ncols_a / bitsPerWord; 82 | } 83 | 84 | // number of storage words needed for each bitplane (bit matrix) 85 | inline uint64_t wordsPerBitplane() const { 86 | return nrows_a * 
wordsPerRow(); 87 | } 88 | 89 | // get given bit. true if set, false if unset. 90 | inline bool get(uint64_t bit, uint64_t row, uint64_t col) { 91 | return ((word(bit, row, col) >> bitpos(col)) & 1L) == 1; 92 | } 93 | 94 | // set all bits to zero 95 | inline void clearAll() { 96 | memset(data, 0, nbits * wordsPerBitplane() * sizeof(uint64_t)); 97 | } 98 | 99 | // set given bit to one 100 | inline void set(uint64_t bit, uint64_t row, uint64_t col) { 101 | word(bit, row, col) |= (1L << bitpos(col)); 102 | } 103 | 104 | // set given bit to zero 105 | inline void unset(uint64_t bit, uint64_t row, uint64_t col) { 106 | word(bit, row, col) &= ~(1L << bitpos(col)); 107 | } 108 | 109 | // access the container word for a given bit 110 | inline uint64_t & word(uint64_t bit, uint64_t row, uint64_t col) { 111 | // right shift by log2(bits per word) to get word index 112 | uint64_t colw = col >> 6; 113 | return data[bit * wordsPerBitplane() + row * wordsPerRow() + colw]; 114 | } 115 | 116 | // get a pointer to a particular row 117 | inline uint64_t * rowptr(uint64_t bit, uint64_t row) { 118 | return &data[bit * wordsPerBitplane() + row * wordsPerRow()]; 119 | } 120 | 121 | // get a pointer to a particular bit plane 122 | inline uint64_t * bitplaneptr(uint64_t bit) { 123 | return &data[bit * wordsPerBitplane()]; 124 | } 125 | 126 | uint64_t bitpos(uint64_t col) { 127 | // return modulo 64 of col by using a bitmask 128 | return col & ((1 << 6) - 1); 129 | } 130 | 131 | /* 132 | Imports a regular matrix into this BitSerialMatrix. This is a slow, "naive" 133 | implementation. 134 | */ 135 | template <typename T> 136 | void importRegular_naive(T * matrix, bool readColMajor=false) { 137 | this->clearAll(); 138 | for(uint64_t r = 0; r < this->nrows; r++) { 139 | for(uint64_t c = 0; c < this->ncols; c++) { 140 | T currentElem = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 141 | if(this->isBipolar()) { 142 | // use bipolar binary encoding: -1 and +1 only (sign) 143 | if(currentElem > 0) { 144 | this->set(0, r, c); 145 | } 146 | } else { 147 | // use two's complement 148 | uint8_t currentElem_uint8 = 0; 149 | if(this->issigned && currentElem < 0) { 150 | // convert to two's complement for this bitwidth 151 | currentElem_uint8 += (uint8_t)(1 << (this->nbits - 1)); 152 | currentElem_uint8 += (uint8_t)(currentElem + (1 << (this->nbits - 1))); 153 | } else { 154 | currentElem_uint8 = (uint8_t) currentElem; 155 | } 156 | for(uint64_t b = 0; b < this->nbits; b++) { 157 | if(currentElem_uint8 & (1 << b)) { 158 | this->set(b, r, c); 159 | } 160 | } 161 | } 162 | } 163 | } 164 | } 165 | 166 | /* 167 | Map given element of datatype T to uint8_t based on chosen quantization 168 | */ 169 | template <typename T> 170 | inline uint8_t quantize(T currentElem) { 171 | uint8_t ret = 0; 172 | if(this->isBipolar()) { 173 | // use bipolar binary encoding: -1 and +1 only (sign) 174 | ret = currentElem > 0 ? 1 : 0; 175 | } else { 176 | // use two's complement 177 | if(this->issigned && currentElem < 0) { 178 | // convert to two's complement for this bitwidth 179 | ret += (uint8_t)(1 << (this->nbits - 1)); 180 | ret += (uint8_t)(currentElem + (1 << (this->nbits - 1))); 181 | } else { 182 | ret = (uint8_t) currentElem; 183 | } 184 | } 185 | return ret; 186 | } 187 | 188 | /* 189 | Import four bytes packed into a single uint32_t into row r, starting with 190 | column c. Intended for internal use. 
191 | */ 192 | inline void import32As4x8(const uint32_t igroup, const uint64_t r, const uint64_t c) { 193 | // leftshift to align actual msb with leftmost bit position 194 | uint32_t group = igroup << (8 - this->nbits); 195 | // pack each bit position using Wojciech Mula's movmask approach: 196 | // http://0x80.pl/articles/scalar-sse-movmask.html 197 | for(uint64_t b = this->nbits; b-- > 0; ) { 198 | const uint32_t input = group & 0x80808080; 199 | const uint32_t mult = 0x02040810; 200 | const uint64_t result = (uint64_t)input * mult; 201 | const uint8_t res8 = (uint8_t)((result >> 32)); 202 | // put lowermost 4 bits of res8 into appropriate data buf pos 203 | this->word(b, r, c) |= (uint64_t)(res8 & 0x0f) << this->bitpos(c); 204 | // left shift for next bit group 205 | group = group << 1; 206 | } 207 | } 208 | 209 | /* Imports a regular matrix into this BitSerialMatrix, using bit twiddling 210 | tricks to go faster. 211 | */ 212 | template <typename T> 213 | void importRegular(T * matrix, bool readColMajor=false) { 214 | this->clearAll(); 215 | const uint64_t cols_d4 = this->ncols - (this->ncols % 4); 216 | const uint64_t cols_rem = (this->ncols % 4); 217 | for(uint64_t r = 0; r < this->nrows; r++) { 218 | // handle conversion of 4-column chunks 219 | for(uint64_t c = 0; c < cols_d4; c+= 4) { 220 | // fetch four elements from row 221 | T e0 = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 222 | T e1 = readColMajor ? matrix[(c+1) * this->nrows + r] : matrix[r * this->ncols + (c+1)]; 223 | T e2 = readColMajor ? matrix[(c+2) * this->nrows + r] : matrix[r * this->ncols + (c+2)]; 224 | T e3 = readColMajor ? matrix[(c+3) * this->nrows + r] : matrix[r * this->ncols + (c+3)]; 225 | // cast all to uint8_t 226 | uint8_t b0 = this->quantize(e0); 227 | uint8_t b1 = this->quantize(e1); 228 | uint8_t b2 = this->quantize(e2); 229 | uint8_t b3 = this->quantize(e3); 230 | // pack into uint32_t and call import function 231 | uint32_t group = (b3 << 24) | (b2 << 16) | (b1 << 8) | b0; 232 | import32As4x8(group, r, c); 233 | } 234 | // fallback to naive to handle remainder of columns 235 | for(uint64_t c = cols_d4; c < this->ncols; c++) { 236 | T e0 = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 237 | uint8_t b0 = this->quantize(e0); 238 | for(uint64_t b = 0; b < this->nbits; b++) { 239 | if(b0 & (1 << b)) { 240 | this->set(b, r, c); 241 | } 242 | } 243 | } 244 | } 245 | } 246 | 247 | /* Specialized variant of importRegular for uint8_t, which needs no conversion. 248 | */ 249 | void importRegular(uint8_t * matrix, bool readColMajor=false) { 250 | this->clearAll(); 251 | const uint64_t cols_d4 = this->ncols - (this->ncols % 4); 252 | const uint64_t cols_rem = (this->ncols % 4); 253 | for(uint64_t r = 0; r < this->nrows; r++) { 254 | // handle conversion of 4-column chunks 255 | for(uint64_t c = 0; c < cols_d4; c+= 4) { 256 | // fetch four elements from row 257 | uint8_t b0 = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 258 | uint8_t b1 = readColMajor ? matrix[(c+1) * this->nrows + r] : matrix[r * this->ncols + (c+1)]; 259 | uint8_t b2 = readColMajor ? matrix[(c+2) * this->nrows + r] : matrix[r * this->ncols + (c+2)]; 260 | uint8_t b3 = readColMajor ? 
matrix[(c+3) * this->nrows + r] : matrix[r * this->ncols + (c+3)]; 261 | // pack into uint32_t and call import function 262 | uint32_t group = (b3 << 24) | (b2 << 16) | (b1 << 8) | b0; 263 | import32As4x8(group, r, c); 264 | } 265 | // fallback to naive to handle remainder of columns 266 | for(uint64_t c = cols_d4; c < this->ncols; c++) { 267 | uint8_t b0 = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 268 | for(uint64_t b = 0; b < this->nbits; b++) { 269 | if(b0 & (1 << b)) { 270 | this->set(b, r, c); 271 | } 272 | } 273 | } 274 | } 275 | } 276 | 277 | /* Imports a regular matrix after applying threshold quantization into this BitSerialMatrix. 278 | * The threshold array is assumed to have the shape thresholds[nThres][nrows], 279 | * and is assumed to be sorted s.t. the largest thresholds have the largest index. 280 | */ 281 | template <typename T> 282 | void importRegularAndQuantize(T * matrix, T * thresholds, int nThres, bool readColMajor=false) { 283 | assert(!this->issigned); // threshold qnt. only makes sense for unsigned 284 | this->clearAll(); 285 | for(uint64_t r = 0; r < this->nrows; r++) { 286 | for(uint64_t c = 0; c < this->ncols; c++) { 287 | T currentElem = readColMajor ? matrix[c * this->nrows + r] : matrix[r * this->ncols + c]; 288 | // quantize this element by finding the index of the largest crossed 289 | // threshold 290 | for(int t = 0; t < nThres; t++) { 291 | if(currentElem <= thresholds[t * this->nrows + r]) { 292 | currentElem = t; 293 | break; 294 | } else if(t == nThres - 1) { 295 | // all thresholds crossed, set to largest quantization level 296 | currentElem = t + 1; 297 | } 298 | } 299 | // now convert to bit serial form 300 | uint8_t currentElem_uint8 = (uint8_t) currentElem; 301 | for(uint64_t b = 0; b < this->nbits; b++) { 302 | if(currentElem_uint8 & (1 << b)) { 303 | this->set(b, r, c); 304 | } 305 | } 306 | } 307 | } 308 | } 309 | 310 | /* Convert this BitSerialMatrix back to a regular matrix. 311 | */ 312 | template <typename T> 313 | void exportRegular(T * matrix) { 314 | for(uint64_t r = 0; r < this->nrows; r++) { 315 | for(uint64_t c = 0; c < this->ncols; c++) { 316 | if(this->isBipolar()) { 317 | matrix[r * this->ncols + c] = (T) this->get(0, r, c) ? +1 : -1; 318 | } else { 319 | T currentElem = 0; 320 | for(uint64_t b = 0; b < this->nbits; b++) { 321 | if(this->get(b, r, c)) { 322 | if((b == this->nbits-1) && this->issigned) { 323 | currentElem -= 1 << b; 324 | } else { 325 | currentElem += 1 << b; 326 | } 327 | } 328 | } 329 | matrix[r * this->ncols + c] = (T) currentElem; 330 | } 331 | } 332 | } 333 | } 334 | }; 335 | 336 | /* Utility function to find block size under the following assumptions: 337 | - size of lhs block + rhs block + result block <= cacheBits 338 | - no blocking along depth (i.e. only entire rows of dBits bits) 339 | - lhsMult and rhsMult determine the ratio for lhs and rhs rows in cache 340 | - returned lhsRows and rhsRows are divisible by lhsMult and rhsMult, respectively 341 | - each result elem takes bitsPerRes bits 342 | */ 343 | static void computeBlockSize(float lhsMult, float rhsMult, float cacheBits, float dBits, uint64_t & lhsBlock, uint64_t & rhsBlock) { 344 | float a = sizeof(int32_t) * lhsMult * rhsMult; 345 | float b = dBits*(lhsMult + rhsMult); 346 | float c = -cacheBits; 347 | float discr = sqrt(b*b - 4 * a * c); 348 | assert(discr > 0); 349 | int64_t x0 = floor((-b + discr) / (2*a)); 350 | int64_t x1 = floor((-b - discr) / (2*a)); 351 | int64_t x = x0 > x1 ? 
x0 : x1; 352 | if(x > 0) { 353 | lhsBlock = lhsMult * x; 354 | rhsBlock = rhsMult * x; 355 | } else { 356 | // some of the assumptions failed, return default block size 357 | lhsBlock = lhsMult; 358 | rhsBlock = rhsMult; 359 | } 360 | }; 361 | 362 | // rather naive, iterative search for a better block size 363 | // how could this be improved? 364 | static uint64_t finetuneBlockSize(uint64_t rows, uint64_t bs_max, uint64_t bs_div) { 365 | uint64_t best_cand = bs_max; 366 | uint64_t min_penalty = alignTo(rows, best_cand) - rows; 367 | for(uint64_t ccand = bs_max; ccand > bs_div; ccand = ccand - bs_div ) { 368 | if(ccand % bs_div == 0) { 369 | uint64_t penalty = alignTo(rows, ccand) - rows; 370 | if(penalty < min_penalty) { 371 | best_cand = ccand; 372 | min_penalty = penalty; 373 | } 374 | } 375 | } 376 | return best_cand; 377 | } 378 | 379 | class GEMMContext { 380 | public: 381 | BitSerialMatrix lhs, rhs; 382 | uint64_t lhsBlock, rhsBlock; 383 | int32_t * res; 384 | 385 | void printSummary() { 386 | std::cout << "GEMMContext" << std::endl; 387 | std::cout << "LHS: "; 388 | lhs.printSummary(); 389 | std::cout << "Block size: " << lhsBlock << std::endl; 390 | std::cout << "RHS: "; 391 | rhs.printSummary(); 392 | std::cout << "Block size: " << rhsBlock << std::endl; 393 | float actual_ops = 2*lhs.nrows*lhs.ncols*rhs.nrows; 394 | float alloc_ops = 2*lhs.nrows_a*lhs.ncols_a*rhs.nrows_a; 395 | std::cout << "Actual ops: " << actual_ops << std::endl; 396 | std::cout << "Allocated ops: " << alloc_ops << std::endl; 397 | std::cout << "Actual op percentage: " << 100*actual_ops/alloc_ops << std::endl; 398 | } 399 | 400 | inline bool isBipolarTimesRegular() const { 401 | return (lhs.isBipolar() && !rhs.isBipolar()) || (!lhs.isBipolar() && rhs.isBipolar()); 402 | } 403 | 404 | inline bool isBipolarTimesBipolar() const { 405 | return (lhs.isBipolar() && rhs.isBipolar()); 406 | } 407 | 408 | inline bool isMatrixVector() const { 409 | return rhs.nrows == 1; 410 | } 411 | }; 412 | 413 | // Base functionality for allocating a GEMM context. Do not use directly, 414 | // use the platform-provided allocGEMMContext instead. 
415 | static GEMMContext allocGEMMContext_base( 416 | const uint64_t lhsRows, const uint64_t depth, const uint64_t rhsRows, 417 | const uint64_t lhsBits, const uint64_t rhsBits, const bool lhsSigned, 418 | const bool rhsSigned, const uint64_t regblock_lhs, const uint64_t regblock_d, 419 | const uint64_t regblock_rhs, const uint64_t cacheBits 420 | ) { 421 | GEMMContext ret; 422 | uint64_t depth_al = alignTo(depth, regblock_d*64); 423 | // use cache blocking; compute sizes 424 | computeBlockSize( 425 | regblock_lhs, regblock_rhs, cacheBits, depth_al, 426 | ret.lhsBlock, ret.rhsBlock 427 | ); 428 | if(ret.lhsBlock > lhsRows || ret.rhsBlock > rhsRows) { 429 | // use register blocking only 430 | ret.lhsBlock = alignTo(lhsRows, regblock_lhs); 431 | ret.rhsBlock = alignTo(rhsRows, regblock_rhs); 432 | } else { 433 | // see if there is too much wasted compute for current block sizes 434 | if((alignTo(lhsRows, ret.lhsBlock) - lhsRows) > 0.1*lhsRows) { 435 | ret.lhsBlock = finetuneBlockSize(lhsRows, ret.lhsBlock, regblock_lhs); 436 | } 437 | if((alignTo(rhsRows, ret.rhsBlock) - rhsRows) > 0.1*rhsRows) { 438 | ret.rhsBlock = finetuneBlockSize(rhsRows, ret.rhsBlock, regblock_rhs); 439 | } 440 | } 441 | // allocate aligned bit serial matrices 442 | ret.lhs = BitSerialMatrix::alloc( 443 | lhsBits, lhsRows, depth, lhsSigned, ret.lhsBlock, regblock_d*64 444 | ); 445 | ret.rhs = BitSerialMatrix::alloc( 446 | rhsBits, rhsRows, depth, rhsSigned, ret.rhsBlock, regblock_d*64 447 | ); 448 | // allocate result matrix. note that it is not aligned -- the 449 | // elements corresponding to alignment parts won't materialize. 450 | ret.res = new int32_t[lhsRows * rhsRows]; 451 | return ret; 452 | }; 453 | 454 | static void deallocGEMMContext(GEMMContext ctx) { 455 | delete [] ctx.res; 456 | BitSerialMatrix::dealloc(ctx.lhs); 457 | BitSerialMatrix::dealloc(ctx.rhs); 458 | }; 459 | 460 | 461 | 462 | // generic implementations using regular & and __builtin_popcountll 463 | #include "arch-generic.hpp" 464 | 465 | // select the implementations to be used based on defines 466 | #ifdef GEMMBITSERIAL_USE_ARM_NEON 467 | #warning "Compiling with ARM NEON" 468 | #include <arm_neon.h> 469 | #include "arch-neon.hpp" 470 | // ARM NEON-specific implementations 471 | #define gemmBitSerial gemmBitSerial_neon 472 | #define allocGEMMContext allocGEMMContext_neon 473 | #define sumRows sumRows_neon 474 | #else 475 | #warning "Compiling using generic popcount" 476 | #define gemmBitSerial gemmBitSerial_generic 477 | #define allocGEMMContext allocGEMMContext_generic 478 | #define sumRows sumRows_generic 479 | #endif 480 | 481 | } 482 | --------------------------------------------------------------------------------
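A minimal end-to-end usage sketch follows, assembled from the API shown in gemmbitserial.hpp above: it allocates a GEMMContext, imports a small pair of unsigned 2-bit matrices, runs the multiplication, and reads back the transposed result. The seven-argument call to allocGEMMContext is an assumption: the platform wrappers (allocGEMMContext_generic / allocGEMMContext_neon) live in arch-generic.hpp / arch-neon.hpp and are presumed here to forward the leading parameters of allocGEMMContext_base. The matrix values are made-up illustration data.

#include <cstdint>
#include <iostream>
#include <vector>
#include "gemmbitserial.hpp"

using namespace gemmbitserial;

int main() {
  // A is 2x4, B is 3x4; B acts as the transposed right-hand side,
  // so the logical product is A (2x4) times B^T (4x3) = C (2x3).
  const uint64_t lhsRows = 2, depth = 4, rhsRows = 3;
  std::vector<uint8_t> A = {1, 0, 3, 2,   2, 2, 1, 0};
  std::vector<uint8_t> B = {1, 1, 0, 2,   0, 3, 1, 1,   2, 0, 2, 3};
  // assumed wrapper signature: lhs rows, shared depth, rhs rows,
  // lhs bits, rhs bits, lhs signedness, rhs signedness
  GEMMContext ctx = allocGEMMContext(lhsRows, depth, rhsRows, 2, 2, false, false);
  // convert both operands into bit-serial form
  ctx.lhs.importRegular(A.data());
  ctx.rhs.importRegular(B.data());
  // run the bit-serial matrix multiplication
  gemmBitSerial(ctx);
  // result buffer is transposed: res[rhsRow * lhsRows + lhsRow] holds C[lhsRow][rhsRow]
  for(uint64_t j = 0; j < lhsRows; j++) {
    for(uint64_t i = 0; i < rhsRows; i++) {
      std::cout << ctx.res[i * lhsRows + j] << " ";
    }
    std::cout << std::endl;
  }
  deallocGEMMContext(ctx);
  return 0;
}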