├── transform.lua ├── .gitignore ├── .gitmodules ├── init.lua ├── README.md ├── lib ├── qgemm.h └── qgemm.cpp ├── generic ├── Threshold.c ├── SpatialMaxPooling.c └── SpatialConvolutionMM.c ├── rocks └── nn8-scm-1.rockspec ├── test-precision.lua ├── LICENSE ├── init.c ├── test-speed.lua ├── CMakeLists.txt └── THNN.lua /transform.lua: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lib/gemmlowp"] 2 | path = lib/gemmlowp 3 | url = https://github.com/google/gemmlowp.git 4 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | require('torch') 2 | require('nn') 3 | require('THNN') 4 | 5 | -- temporary support for byte copy for nn.Module class 6 | function nn.Module:byte() 7 | return self:type('torch.ByteTensor') 8 | end 9 | 10 | 11 | return nn 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Low precision (8-bit) Torch nn library 2 | 3 | This experimental work uses [Google's low precision GEMM](https://github.com/google/gemmlowp) 4 | and only supports a few modules. 5 | 6 | ### Install 7 | 8 | ```bash 9 | git clone https://github.com/jhjin/nn8 --recursive 10 | cd nn8 11 | luarocks make rocks/nn8-scm-1.rockspec 12 | ``` 13 | 14 | ### Test 15 | 16 | ```bash 17 | th test-precision.lua # small model 18 | th test-speed.lua # large model 19 | ``` 20 | -------------------------------------------------------------------------------- /lib/qgemm.h: -------------------------------------------------------------------------------- 1 | #ifndef QGEMM_H 2 | #define QGEMM_H 3 | 4 | 5 | #ifdef __cplusplus 6 | #include <cstdint> 7 | #include <tuple> 8 | 9 | #include "gemmlowp/public/gemmlowp.h" 10 | 11 | extern "C" { 12 | #endif 13 | 14 | void THByteBlas_gemm8(uint8_t* c, uint8_t* c_bias, 15 | const uint8_t* a, const uint8_t* b, 16 | const int m, const int n, const int k, 17 | const int a_offset, const int b_offset, 18 | const int c_offset, const int c_mult, const int c_shift, 19 | const int use_relu); 20 | 21 | #ifdef __cplusplus 22 | } 23 | #endif 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /generic/Threshold.c: -------------------------------------------------------------------------------- 1 | void THNN_ByteThreshold_updateOutput( 2 | THNNState *state, 3 | THByteTensor *input, 4 | THByteTensor *output, 5 | uint8_t threshold, 6 | uint8_t val, 7 | bool inplace) 8 | { 9 | if (inplace) { 10 | TH_TENSOR_APPLY(uint8_t, input, 11 | if (*input_data <= threshold) 12 | *input_data = val; 13 | ); 14 | THByteTensor_set(output, input); 15 | } else { 16 | THByteTensor_resizeAs(output, input); 17 | TH_TENSOR_APPLY2(uint8_t, output, uint8_t, input, 18 | *output_data = (*input_data > threshold) ?
*input_data : val; 19 | ); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /rocks/nn8-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "nn8" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/jhjin/nn8.git", 6 | } 7 | 8 | description = { 9 | summary = "Low Precision Neural Network package for Torch", 10 | detailed = [[ 11 | ]], 12 | homepage = "https://github.com/jhjin/nn8", 13 | license = "MIT" 14 | } 15 | 16 | dependencies = { 17 | "torch >= 7.0", 18 | } 19 | 20 | build = { 21 | type = "command", 22 | build_command = [[ 23 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) 24 | ]], 25 | install_command = "cd build && $(MAKE) install" 26 | } 27 | -------------------------------------------------------------------------------- /test-precision.lua: -------------------------------------------------------------------------------- 1 | require('nn8') 2 | torch.manualSeed(2) 3 | torch.setdefaulttensortype('torch.FloatTensor') 4 | torch.setnumthreads(4) 5 | print('==> #threads: ', torch.getnumthreads()) 6 | 7 | 8 | local batchSize = 128 9 | local iC = 3 10 | local iH = 224 11 | local iW = iH 12 | 13 | 14 | local model32 = nn.Sequential() 15 | model32:add(nn.SpatialConvolutionMM(iC, 8, 3, 3)) 16 | model32:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 17 | model32:add(nn.Threshold(10, 10)) 18 | model32:get(1).weight:random(3):add(-1) 19 | model32:get(1).bias:random(3):add(-1) 20 | 21 | 22 | local model8 = nn.Sequential() 23 | for i = 1, #model32 do 24 | model8:add(model32:get(i):clone()) 25 | end 26 | model8:byte() 27 | 28 | 29 | local x32 = torch.FloatTensor(batchSize,iC,iH,iW):random(3):add(-1) 30 | local x8 = x32:byte() 31 | 32 | 33 | local t32 = torch.Timer() 34 | local y32 = model32:forward(x32) 35 | print('==> 32-bit: ', t32:time().real) 36 | 37 | 38 | local t8 = torch.Timer() 39 | local y8 = model8:forward(x8) 40 | print('==> 8-bit: ', t8:time().real) 41 | 42 | 43 | local diff = (y32 - y8:float()):abs() 44 | print('==> diff [max]: ', diff:max()) 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Jonghoon Jin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /init.c: -------------------------------------------------------------------------------- 1 | #include "luaT.h" 2 | 3 | #include <TH.h> 4 | #include <stdint.h> 5 | #include <stdbool.h> 6 | #ifdef _OPENMP 7 | #include <omp.h> 8 | #endif 9 | 10 | typedef void THNNState; 11 | 12 | TH_API void THNN_ByteThreshold_updateOutput( 13 | THNNState *state, 14 | THByteTensor *input, 15 | THByteTensor *output, 16 | uint8_t threshold, 17 | uint8_t val, 18 | bool inplace); 19 | 20 | TH_API void THNN_ByteSpatialConvolutionMM_updateOutput( 21 | THNNState *state, 22 | THByteTensor *input, 23 | THByteTensor *output, 24 | THByteTensor *weight, 25 | THByteTensor *bias, 26 | THByteTensor *finput, 27 | THByteTensor *fgradInput, 28 | int kW, int kH, 29 | int dW, int dH, 30 | int padW, int padH); 31 | 32 | TH_API void THNN_ByteSpatialMaxPooling_updateOutput( 33 | THNNState *state, 34 | THByteTensor *input, 35 | THByteTensor *output, 36 | THByteTensor *indices, 37 | int kW, int kH, 38 | int dW, int dH, 39 | int padW, int padH, 40 | bool ceil_mode); 41 | 42 | #include "lib/qgemm.h" 43 | 44 | #include "generic/Threshold.c" 45 | #include "generic/SpatialConvolutionMM.c" 46 | #include "generic/SpatialMaxPooling.c" 47 | -------------------------------------------------------------------------------- /test-speed.lua: -------------------------------------------------------------------------------- 1 | require('nn8') 2 | torch.manualSeed(2) 3 | torch.setdefaulttensortype('torch.FloatTensor') 4 | torch.setnumthreads(4) 5 | print('==> #threads: ', torch.getnumthreads()) 6 | 7 | 8 | local batchSize = 64 9 | local iC = 3 10 | local iH = 224 11 | local iW = iH 12 | 13 | 14 | -- alex krizhevsky one weird trick (http://arxiv.org/abs/1404.5997) 15 | local model32 = nn.Sequential() 16 | model32:add(nn.SpatialConvolutionMM(3,64,11,11,4,4,2,2)) 17 | model32:add(nn.Threshold(10, 10)) 18 | model32:add(nn.SpatialMaxPooling(3,3,2,2)) 19 | model32:add(nn.SpatialConvolutionMM(64,192,5,5,1,1,2,2)) 20 | model32:add(nn.Threshold(10, 10)) 21 | model32:add(nn.SpatialMaxPooling(3,3,2,2)) 22 | model32:add(nn.SpatialConvolutionMM(192,384,3,3,1,1,1,1)) 23 | model32:add(nn.Threshold(10, 10)) 24 | model32:add(nn.SpatialConvolutionMM(384,256,3,3,1,1,1,1)) 25 | model32:add(nn.Threshold(10, 10)) 26 | model32:add(nn.SpatialConvolutionMM(256,256,3,3,1,1,1,1)) 27 | model32:add(nn.Threshold(10, 10)) 28 | model32:add(nn.SpatialMaxPooling(3,3,2,2)) 29 | model32:add(nn.View(256*6*6)) 30 | 31 | --[[ modules not supported yet 32 | model32:add(nn.Linear(256*6*6, 4096)) 33 | model32:add(nn.Threshold(10, 10)) 34 | model32:add(nn.Linear(4096, 4096)) 35 | model32:add(nn.Threshold(10, 10)) 36 | model32:add(nn.Linear(4096, 1000)) 37 | model32:add(nn.SoftMax()) 38 | ]] 39 | 40 | for _, v in pairs({1,4,7,9,11}) do 41 | model32:get(v).weight:random(3):add(-1) 42 | model32:get(v).bias:random(3):add(-1) 43 | end 44 | 45 | 46 | local model8 = nn.Sequential() 47 | for i = 1, #model32 do 48 | model8:add(model32:get(i):clone()) 49 | end 50 | model8:byte() 51 | 52 | 53 | local x32 = torch.FloatTensor(batchSize,iC,iH,iW):random(3):add(-1) 54 | local x8 = x32:byte() 55 | 56 | 57 | local t32 = torch.Timer() 58 | local y32 = model32:forward(x32) 59 |
print('==> 32-bit: ', t32:time().real) 60 | 61 | 62 | local t8 = torch.Timer() 63 | local y8 = model8:forward(x8) 64 | print('==> 8-bit: ', t8:time().real) 65 | -------------------------------------------------------------------------------- /lib/qgemm.cpp: -------------------------------------------------------------------------------- 1 | #include "qgemm.h" 2 | 3 | void THByteBlas_gemm8(uint8_t* c, uint8_t* c_bias, 4 | const uint8_t* a, const uint8_t* b, 5 | const int m, const int n, const int k, 6 | const int a_offset, const int b_offset, 7 | const int c_offset, const int c_mult, const int c_shift, 8 | const int use_relu) 9 | { 10 | // quantize-down, unclamped (but scaled) int32's 11 | gemmlowp::OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage; 12 | quantize_down_stage.result_offset = c_offset; 13 | quantize_down_stage.result_mult_int = c_mult; 14 | quantize_down_stage.result_shift = c_shift; 15 | 16 | // clamp-and-cast-to-uint8 17 | gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; 18 | 19 | // clamp min/max bounds 20 | gemmlowp::OutputStageClamp clamp_stage{0, 255}; 21 | if (use_relu) clamp_stage.min = 128; 22 | 23 | typedef gemmlowp::VectorMap<const uint8_t, gemmlowp::VectorShape::Col> ColVectorMap; 24 | ColVectorMap col_vector_map(c_bias, m); 25 | gemmlowp::OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage; 26 | col_bias_addition_stage.bias_vector = col_vector_map; 27 | 28 | // set pipeline after gemm 29 | auto bias_clamp_quantize_cast_pipeline = 30 | std::make_tuple(col_bias_addition_stage, 31 | clamp_stage, 32 | quantize_down_stage, 33 | saturating_cast_stage); 34 | 35 | // init gemmlowp context and storage 36 | gemmlowp::GemmContext context; 37 | const gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor> a_(a, m, k, k); 38 | const gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor> b_(b, k, n, n); 39 | gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::RowMajor> c_(c, m, n, n); 40 | 41 | // gemm and output pipeline 42 | gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, gemmlowp::DefaultL8R8BitDepthParams>( 43 | &context, a_, b_, &c_, a_offset, b_offset, bias_clamp_quantize_cast_pipeline); 44 | } 45 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 2 | CMAKE_POLICY(VERSION 2.6) 3 | 4 | FIND_PACKAGE(Torch REQUIRED) 5 | 6 | # gemmlowp requires c++11 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") 8 | if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm.*") 9 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") 10 | endif() 11 | # Flags 12 | # When using MSVC 13 | IF(MSVC) 14 | # we want to respect the standard, and we are bored of those **** . 15 | ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) 16 | ENDIF(MSVC) 17 | 18 | # OpenMP support?
19 | SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") 20 | IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) 21 | EXEC_PROGRAM (uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) 22 | STRING (REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) 23 | MESSAGE (STATUS "MAC OS Darwin Version: ${DARWIN_VERSION}") 24 | IF (DARWIN_VERSION GREATER 9) 25 | SET(APPLE_OPENMP_SUCKS 1) 26 | ENDIF (DARWIN_VERSION GREATER 9) 27 | EXECUTE_PROCESS (COMMAND ${CMAKE_C_COMPILER} -dumpversion 28 | OUTPUT_VARIABLE GCC_VERSION) 29 | IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2) 30 | MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)") 31 | MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP") 32 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas") 33 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE) 34 | ENDIF () 35 | ENDIF () 36 | 37 | IF (WITH_OPENMP) 38 | FIND_PACKAGE(OpenMP) 39 | IF(OPENMP_FOUND) 40 | MESSAGE(STATUS "Compiling with OpenMP support") 41 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 42 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 43 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 44 | ENDIF(OPENMP_FOUND) 45 | ENDIF (WITH_OPENMP) 46 | 47 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 48 | 49 | SET(src init.c lib/qgemm.cpp lib/gemmlowp/eight_bit_int_gemm/eight_bit_int_gemm.cc) 50 | 51 | FILE(GLOB luasrc *.lua) 52 | 53 | ADD_TORCH_PACKAGE(nn8 "${src}" "${luasrc}") 54 | 55 | TARGET_LINK_LIBRARIES(nn8 luaT TH) 56 | -------------------------------------------------------------------------------- /THNN.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | 3 | local THNN = {} 4 | 5 | local generic_THNN_h = [[ 6 | TH_API void THNN_(Threshold_updateOutput)( 7 | THNNState *state, 8 | THTensor *input, 9 | THTensor *output, 10 | real threshold, 11 | real val, 12 | bool inplace); 13 | 14 | TH_API void THNN_(SpatialConvolutionMM_updateOutput)( 15 | THNNState *state, 16 | THTensor *input, 17 | THTensor *output, 18 | THTensor *weight, 19 | THTensor *bias, 20 | THTensor *finput, 21 | THTensor *fgradInput, 22 | int kW, int kH, 23 | int dW, int dH, 24 | int padW, int padH); 25 | 26 | TH_API void THNN_(SpatialMaxPooling_updateOutput)( 27 | THNNState *state, 28 | THTensor *input, 29 | THTensor *output, 30 | THTensor *indices, 31 | int kW, int kH, 32 | int dW, int dH, 33 | int padW, int padH, 34 | bool ceil_mode); 35 | ]] 36 | 37 | -- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h 38 | local base_declarations = [[ 39 | typedef void THNNState; 40 | 41 | typedef struct { 42 | unsigned long the_initial_seed; 43 | int left; 44 | int seeded; 45 | unsigned long next; 46 | unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */ 47 | double normal_x; 48 | double normal_y; 49 | double normal_rho; 50 | int normal_is_valid; 51 | } THGenerator; 52 | ]] 53 | 54 | -- polyfill for LUA 5.1 55 | if not package.searchpath then 56 | local sep = package.config:sub(1,1) 57 | function package.searchpath(mod, path) 58 | mod = mod:gsub('%.', sep) 59 | for m in path:gmatch('[^;]+') do 60 | local nm = m:gsub('?', mod) 61 | local f = io.open(nm, 'r') 62 | if f then 63 | f:close() 64 | return nm 65 | end 66 | end 67 | end 68 | end 69 | 70 | -- load libTHNN 71 | THNN.C = ffi.load(package.searchpath('libnn8', package.cpath)) 72 | 73 | ffi.cdef(base_declarations) 74 | 75 | -- 
expand macros, allowing use of the original lines from lib/THNN/generic/THNN.h 76 | local preprocessed = string.gsub(generic_THNN_h, 'TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1') 77 | 78 | local replacements = 79 | { 80 | { 81 | ['TYPE'] = 'Byte', 82 | ['real'] = 'uint8_t', 83 | ['THTensor'] = 'THByteTensor', 84 | ['THIndexTensor'] = 'THLongTensor', 85 | ['THIntegerTensor'] = 'THIntTensor', 86 | ['THIndex_t'] = 'long', 87 | ['THInteger_t'] = 'int' 88 | }, 89 | } 90 | 91 | for i=1,#replacements do 92 | local r = replacements[i] 93 | local s = preprocessed 94 | for k,v in pairs(r) do 95 | s = string.gsub(s, k, v) 96 | end 97 | ffi.cdef(s) 98 | end 99 | 100 | THNN.NULL = ffi.NULL or nil 101 | 102 | function THNN.getState() 103 | return ffi.NULL or nil 104 | end 105 | 106 | function THNN.optionalTensor(t) 107 | return t and t:cdata() or THNN.NULL 108 | end 109 | 110 | local function extract_function_names(s) 111 | local t = {} 112 | for n in string.gmatch(s, 'TH_API void THNN_%(([%a%d_]+)%)') do 113 | t[#t+1] = n 114 | end 115 | return t 116 | end 117 | 118 | function THNN.bind(lib, base_names, type_name, state_getter) 119 | local ftable = {} 120 | local prefix = 'THNN_' .. type_name 121 | for i,n in ipairs(base_names) do 122 | -- use pcall since some libs might not support all functions (e.g. cunn) 123 | local ok,v = pcall(function() return lib[prefix .. n] end) 124 | if ok then 125 | ftable[n] = function(...) v(state_getter(), ...) end -- implicitly add state 126 | else 127 | print('not found: ' .. prefix .. n .. v) 128 | end 129 | end 130 | return ftable 131 | end 132 | 133 | -- build function table 134 | local function_names = extract_function_names(generic_THNN_h) 135 | 136 | THNN.kernels = {} 137 | THNN.kernels['torch.ByteTensor'] = THNN.bind(THNN.C, function_names, 'Byte', THNN.getState) 138 | 139 | torch.getmetatable('torch.ByteTensor').THNN = THNN.kernels['torch.ByteTensor'] 140 | 141 | function THNN.runKernel(f, type, ...) 142 | local ftable = THNN.kernels[type] 143 | if not ftable then 144 | error('Unsupported tensor type: '..type) 145 | end 146 | local f = ftable[f] 147 | if not f then 148 | error(string.format("Function '%s' not found for tensor type '%s'.", f, type)) 149 | end 150 | f(...)
151 | end 152 | 153 | return THNN 154 | -------------------------------------------------------------------------------- /generic/SpatialMaxPooling.c: -------------------------------------------------------------------------------- 1 | static void THNN_ByteSpatialMaxPooling_updateOutput_frame(uint8_t *input_p, uint8_t *output_p, 2 | uint8_t *ind_p, 3 | long nslices, 4 | long iwidth, long iheight, 5 | long owidth, long oheight, 6 | int kW, int kH, int dW, int dH, 7 | int padW, int padH) 8 | { 9 | long k; 10 | #pragma omp parallel for private(k) 11 | for (k = 0; k < nslices; k++) 12 | { 13 | /* loop over output */ 14 | long i, j; 15 | uint8_t *ip = input_p + k*iwidth*iheight; 16 | for(i = 0; i < oheight; i++) 17 | { 18 | for(j = 0; j < owidth; j++) 19 | { 20 | long hstart = i * dH - padH; 21 | long wstart = j * dW - padW; 22 | long hend = fminf(hstart + kH, iheight); 23 | long wend = fminf(wstart + kW, iwidth); 24 | hstart = fmaxf(hstart, 0); 25 | wstart = fmaxf(wstart, 0); 26 | 27 | /* local pointers */ 28 | uint8_t *op = output_p + k*owidth*oheight + i*owidth + j; 29 | uint8_t *indp = ind_p + k*owidth*oheight + i*owidth + j; 30 | 31 | /* compute local max: */ 32 | long maxindex = -1; 33 | uint8_t maxval = 0; 34 | long tcntr = 0; 35 | long x,y; 36 | for(y = hstart; y < hend; y++) 37 | { 38 | for(x = wstart; x < wend; x++) 39 | { 40 | tcntr = y*iwidth + x; 41 | uint8_t val = *(ip + tcntr); 42 | if (val > maxval) 43 | { 44 | maxval = val; 45 | maxindex = tcntr; 46 | } 47 | } 48 | } 49 | 50 | /* set output to local max */ 51 | *op = maxval; 52 | 53 | /* store location of max */ 54 | *indp = maxindex + 1; 55 | } 56 | } 57 | } 58 | } 59 | 60 | void THNN_ByteSpatialMaxPooling_updateOutput( 61 | THNNState *state, 62 | THByteTensor *input, 63 | THByteTensor *output, 64 | THByteTensor *indices, 65 | int kW, 66 | int kH, 67 | int dW, 68 | int dH, 69 | int padW, 70 | int padH, 71 | bool ceil_mode) 72 | { 73 | int dimw = 2; 74 | int dimh = 1; 75 | long nbatch = 1; 76 | long nslices; 77 | long iheight; 78 | long iwidth; 79 | long oheight; 80 | long owidth; 81 | uint8_t *input_data; 82 | uint8_t *output_data; 83 | uint8_t *indices_data; 84 | 85 | 86 | THArgCheck(input->nDimension == 3 || input->nDimension == 4 , 2, "3D or 4D (batch mode) tensor expected"); 87 | 88 | if (input->nDimension == 4) 89 | { 90 | nbatch = input->size[0]; 91 | dimw++; 92 | dimh++; 93 | } 94 | THArgCheck(input->size[dimw] >= kW - padW && input->size[dimh] >= kH - padH, 2, "input image smaller than kernel size"); 95 | 96 | THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size"); 97 | 98 | /* sizes */ 99 | nslices = input->size[dimh-1]; 100 | iheight = input->size[dimh]; 101 | iwidth = input->size[dimw]; 102 | if (ceil_mode) 103 | { 104 | oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1; 105 | owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1; 106 | } 107 | else 108 | { 109 | oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1; 110 | owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1; 111 | } 112 | 113 | if (padW || padH) 114 | { 115 | // ensure that the last pooling starts inside the image 116 | if ((oheight - 1)*dH >= iheight + padH) 117 | --oheight; 118 | if ((owidth - 1)*dW >= iwidth + padW) 119 | --owidth; 120 | } 121 | 122 | /* get contiguous input */ 123 | input = THByteTensor_newContiguous(input); 124 | 125 | /* resize output */ 126 | if (input->nDimension == 3) 127 | { 128 | THByteTensor_resize3d(output, nslices, oheight, owidth); 129 
| /* indices will contain the locations for each output point */ 130 | THByteTensor_resize3d(indices, nslices, oheight, owidth); 131 | 132 | input_data = THByteTensor_data(input); 133 | output_data = THByteTensor_data(output); 134 | indices_data = THByteTensor_data(indices); 135 | 136 | THNN_ByteSpatialMaxPooling_updateOutput_frame(input_data, output_data, 137 | indices_data, 138 | nslices, 139 | iwidth, iheight, 140 | owidth, oheight, 141 | kW, kH, dW, dH, 142 | padW, padH); 143 | } 144 | else 145 | { 146 | long p; 147 | 148 | THByteTensor_resize4d(output, nbatch, nslices, oheight, owidth); 149 | /* indices will contain the locations for each output point */ 150 | THByteTensor_resize4d(indices, nbatch, nslices, oheight, owidth); 151 | 152 | input_data = THByteTensor_data(input); 153 | output_data = THByteTensor_data(output); 154 | indices_data = THByteTensor_data(indices); 155 | 156 | #pragma omp parallel for private(p) 157 | for (p = 0; p < nbatch; p++) 158 | { 159 | THNN_ByteSpatialMaxPooling_updateOutput_frame(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight, 160 | indices_data+p*nslices*owidth*oheight, 161 | nslices, 162 | iwidth, iheight, 163 | owidth, oheight, 164 | kW, kH, dW, dH, 165 | padW, padH); 166 | } 167 | } 168 | 169 | /* cleanup */ 170 | THByteTensor_free(input); 171 | } 172 | -------------------------------------------------------------------------------- /generic/SpatialConvolutionMM.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/SpatialConvolutionMM.c" 3 | 4 | void THNN_Byteunfolded_copy(THByteTensor *finput, THByteTensor *input, 5 | int kW, int kH, 6 | int dW, int dH, 7 | int padW, int padH, 8 | int nInputPlane, 9 | int inputWidth, int inputHeight, 10 | int outputWidth, int outputHeight) 11 | { 12 | long k; 13 | uint8_t *input_data = THByteTensor_data(input); 14 | uint8_t *finput_data = THByteTensor_data(finput); 15 | 16 | #pragma omp parallel for private(k) 17 | for(k = 0; k < nInputPlane*kH*kW; k++) { 18 | long nip = k / (kH*kW); 19 | long rest = k % (kH*kW); 20 | long kh = rest / kW; 21 | long kw = rest % kW; 22 | long x,y; 23 | long ix,iy; 24 | uint8_t *dst = finput_data + nip*(kH*kW*outputHeight*outputWidth) + kh*(kW*outputHeight*outputWidth) + kw*(outputHeight*outputWidth); 25 | uint8_t *src = input_data + nip*(inputHeight*inputWidth); 26 | if (padW > 0 || padH > 0) { 27 | long lpad,rpad; 28 | for(y = 0; y < outputHeight; y++) { 29 | iy = (y*dH - padH + kh); 30 | if (iy < 0 || iy >= inputHeight) { 31 | memset(dst+y*outputWidth, 0, sizeof(uint8_t)*outputWidth); 32 | } else { 33 | if (dW==1){ 34 | ix = (0 - padW + kw); 35 | lpad = fmaxf(0,padW-kw); 36 | rpad = fmaxf(0,padW-(kW-kw-1)); 37 | if (outputWidth-rpad-lpad <= 0) { 38 | memset(dst+(y*outputWidth), 0, sizeof(uint8_t)*outputWidth); 39 | } else { 40 | if (lpad > 0) memset(dst+y*outputWidth, 0, sizeof(uint8_t)*lpad); 41 | memcpy(dst+(y*outputWidth+lpad), src+(iy*inputWidth+ix+lpad), sizeof(uint8_t)*(outputWidth-rpad-lpad)); 42 | if (rpad > 0) memset(dst+y*outputWidth + outputWidth - rpad, 0, sizeof(uint8_t)*rpad); 43 | } 44 | } 45 | else{ 46 | for (x=0; x<outputWidth; x++){ 47 | ix = (x*dW - padW + kw); 48 | if (ix < 0 || ix >= inputWidth) 49 | memset(dst+(y*outputWidth+x), 0, sizeof(uint8_t)*1); 50 | else 51 | memcpy(dst+(y*outputWidth+x), src+(iy*inputWidth+ix), sizeof(uint8_t)*(1)); 52 | } 53 | } 54 | } 55 | } 56 | } else { 57 | for(y = 0; y < outputHeight; y++) { 58 | iy = (y*dH + kh); 59 | ix = (0 + kw); 60 | if (dW == 1) 61 |
memcpy(dst+(y*outputWidth), src+(iy*inputWidth+ix), sizeof(uint8_t)*outputWidth); 62 | else{ 63 | for (x=0; x<outputWidth; x++) 64 | memcpy(dst+(y*outputWidth+x), src+(iy*inputWidth+ix+x*dW), sizeof(uint8_t)*(1)); 65 | } 66 | } 67 | } 68 | } 69 | } 70 | 71 | void THNN_ByteSpatialConvolutionMM_updateOutput_frame(THByteTensor *input, THByteTensor *output, 72 | THByteTensor *weight, THByteTensor *bias, 73 | THByteTensor *finput, 74 | int kW, int kH, int dW, int dH, 75 | int padW, int padH, 76 | long nInputPlane, long inputWidth, long inputHeight, 77 | long nOutputPlane, long outputWidth, long outputHeight) 78 | { 79 | THNN_Byteunfolded_copy(finput, input, kW, kH, dW, dH, padW, padH, 80 | nInputPlane, inputWidth, inputHeight, 81 | outputWidth, outputHeight); 82 | 83 | THByteTensor *output2d = THByteTensor_newWithStorage2d(output->storage, output->storageOffset, 84 | nOutputPlane, -1, 85 | outputHeight*outputWidth, -1); 86 | 87 | // gemmlowp considers weight/bias altogether 88 | THByteBlas_gemm8(THByteTensor_data(output2d), 89 | THByteTensor_data(bias), 90 | THByteTensor_data(weight), 91 | THByteTensor_data(finput), 92 | weight->size[0], finput->size[1], weight->size[1], 93 | 0, 0, 0, 1, 0, 0); 94 | 95 | THByteTensor_free(output2d); 96 | } 97 | 98 | void THNN_ByteSpatialConvolutionMM_updateOutput( 99 | THNNState *state, 100 | THByteTensor *input, 101 | THByteTensor *output, 102 | THByteTensor *weight, 103 | THByteTensor *bias, 104 | THByteTensor *finput, 105 | THByteTensor *fgradInput, 106 | int kW, 107 | int kH, 108 | int dW, 109 | int dH, 110 | int padW, 111 | int padH) 112 | { 113 | int dimf = 0; 114 | int dimw = 2; 115 | int dimh = 1; 116 | 117 | long nInputPlane; 118 | long inputWidth; 119 | long inputHeight; 120 | long nOutputPlane; 121 | long outputWidth; 122 | long outputHeight; 123 | 124 | THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D(batch mode) tensor expected"); 125 | 126 | 127 | if (input->nDimension == 4) { 128 | dimf++; 129 | dimw++; 130 | dimh++; 131 | } 132 | 133 | nInputPlane = input->size[dimf]; 134 | inputWidth = input->size[dimw]; 135 | inputHeight = input->size[dimh]; 136 | nOutputPlane = weight->size[0]; 137 | outputWidth = (inputWidth + 2*padW - kW) / dW + 1; 138 | outputHeight = (inputHeight + 2*padH - kH) / dH + 1; 139 | 140 | if (outputWidth < 1 || outputHeight < 1) 141 | THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small", 142 | nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth); 143 | 144 | if(input->nDimension == 3) 145 | { 146 | THByteTensor_resize2d(finput, kW*kH*nInputPlane, outputHeight*outputWidth); 147 | THByteTensor_resize3d(output, nOutputPlane, outputHeight, outputWidth); 148 | 149 | THNN_ByteSpatialConvolutionMM_updateOutput_frame(input, output, weight, bias, finput, 150 | kW, kH, dW, dH, padW, padH, 151 | nInputPlane, inputWidth, inputHeight, 152 | nOutputPlane, outputWidth, outputHeight); 153 | } 154 | else 155 | { 156 | long T = input->size[0]; 157 | long t; 158 | 159 | THByteTensor_resize3d(finput, T, kW*kH*nInputPlane, outputHeight*outputWidth); 160 | THByteTensor_resize4d(output, T, nOutputPlane, outputHeight, outputWidth); 161 | 162 | #pragma omp parallel for private(t) 163 | for(t = 0; t < T; t++) 164 | { 165 | THByteTensor *input_t = THByteTensor_newSelect(input, 0, t); 166 | THByteTensor *output_t = THByteTensor_newSelect(output, 0, t); 167 | THByteTensor *finput_t = THByteTensor_newSelect(finput, 0, t); 168 | 169 | THNN_ByteSpatialConvolutionMM_updateOutput_frame(input_t, output_t, weight, bias, finput_t, 170 | kW, kH, dW, dH, padW, padH, 171 | nInputPlane, inputWidth, inputHeight, 172 | nOutputPlane, outputWidth, outputHeight); 173 | 174 | THByteTensor_free(input_t); 175 | THByteTensor_free(output_t); 176 | THByteTensor_free(finput_t); 177 | } 178 | } 179 | } 180 | 181 | #endif 182 | --------------------------------------------------------------------------------
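The THNN.lua binding shown above attaches the byte kernels to the `torch.ByteTensor` metatable, so they can also be invoked directly, without going through an `nn` module. The sketch below illustrates that call path for the threshold kernel; the tensor shape and the threshold/replacement values are made up for the example and are not taken from this repository.

```lua
require('nn8')

-- a byte input and an empty byte output; the kernel resizes the output itself
local input  = torch.ByteTensor(4, 8, 8):random(1, 255)
local output = torch.ByteTensor()

-- direct call into the bound C kernel; THNN.bind prepends the (NULL) state argument
input.THNN.Threshold_updateOutput(
  input:cdata(), output:cdata(),
  10,    -- threshold
  10,    -- replacement value for elements <= threshold
  false  -- not in-place
)

print(output:size())
```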