├── README ├── RoI_pooling ├── __init__.py ├── _ext │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-35.pyc │ └── roi_pooling │ │ ├── __init__.py │ │ ├── __pycache__ │ │ └── __init__.cpython-35.pyc │ │ └── _roi_pooling.so ├── build.py ├── roi_pool.py └── src │ ├── cuda │ ├── roi_pooling_kernel.cu │ ├── roi_pooling_kernel.cu.o │ └── roi_pooling_kernel.h │ ├── roi_pooling.c │ ├── roi_pooling.h │ ├── roi_pooling_cuda.c │ └── roi_pooling_cuda.h ├── roi_cupy.py ├── roi_module.py ├── speed.py └── speed1.py /README: -------------------------------------------------------------------------------- 1 | use_cuda: True, has_backward: True 2 | method0: 0.0344547653198242, batch_size: 1, size: 50, num_rois: 300 3 | method1: 0.1322056961059570, batch_size: 1, size: 50, num_rois: 300 4 | method2: 0.1307379817962646, batch_size: 1, size: 50, num_rois: 300 5 | method3: 0.2016681671142578, batch_size: 1, size: 50, num_rois: 300 6 | 7 | # Reference: https://github.com/blackyang/compare_roi_pooling_speed 8 | # see also: http://www.cnblogs.com/king-lps/p/9026798.html 9 | -------------------------------------------------------------------------------- /RoI_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_pool import * 2 | -------------------------------------------------------------------------------- /RoI_pooling/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SirLPS/roi_pooling/1b2b3f953ed10d101e5f4283c56eeebf90b22ef5/RoI_pooling/_ext/__init__.py -------------------------------------------------------------------------------- /RoI_pooling/_ext/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SirLPS/roi_pooling/1b2b3f953ed10d101e5f4283c56eeebf90b22ef5/RoI_pooling/_ext/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /RoI_pooling/_ext/roi_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._roi_pooling import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | locals[symbol] = _wrap_function(fn, _ffi) 10 | __all__.append(symbol) 11 | 12 | _import_symbols(locals()) 13 | -------------------------------------------------------------------------------- /RoI_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SirLPS/roi_pooling/1b2b3f953ed10d101e5f4283c56eeebf90b22ef5/RoI_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /RoI_pooling/_ext/roi_pooling/_roi_pooling.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SirLPS/roi_pooling/1b2b3f953ed10d101e5f4283c56eeebf90b22ef5/RoI_pooling/_ext/roi_pooling/_roi_pooling.so -------------------------------------------------------------------------------- /RoI_pooling/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | 6 | sources =
['src/roi_pooling.c'] 7 | headers = ['src/roi_pooling.h'] 8 | defines = [] 9 | with_cuda = False 10 | 11 | if torch.cuda.is_available(): 12 | print('Including CUDA code.') 13 | sources += ['src/roi_pooling_cuda.c'] 14 | headers += ['src/roi_pooling_cuda.h'] 15 | defines += [('WITH_CUDA', None)] 16 | with_cuda = True 17 | 18 | this_file = os.path.dirname(os.path.realpath(__file__)) 19 | print(this_file) 20 | extra_objects = ['src/cuda/roi_pooling_kernel.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi = create_extension( 24 | '_ext.roi_pooling', 25 | headers=headers, 26 | sources=sources, 27 | define_macros=defines, 28 | relative_to=__file__, 29 | with_cuda=with_cuda, 30 | extra_objects=extra_objects 31 | ) 32 | 33 | if __name__ == '__main__': 34 | ffi.build() 35 | -------------------------------------------------------------------------------- /RoI_pooling/roi_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ._ext import roi_pooling 4 | 5 | 6 | class RoIPoolFunction(Function): 7 | def __init__(self, pooled_height, pooled_width, spatial_scale): 8 | self.pooled_width = int(pooled_width) 9 | self.pooled_height = int(pooled_height) 10 | self.spatial_scale = float(spatial_scale) 11 | self.output = None 12 | self.argmax = None 13 | self.rois = None 14 | self.feature_size = None 15 | 16 | def forward(self, features, rois): 17 | batch_size, num_channels, data_height, data_width = features.size() 18 | num_rois = rois.size()[0] 19 | output = torch.zeros(num_rois, num_channels, self.pooled_height, self.pooled_width) 20 | argmax = torch.IntTensor(num_rois, num_channels, self.pooled_height, self.pooled_width).zero_() 21 | 22 | if not features.is_cuda: 23 | _features = features.permute(0, 2, 3, 1) 24 | roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale, 25 | _features, rois, output) 26 | # output = output.cuda() 27 | else: 28 | output = output.cuda() 29 | argmax = argmax.cuda() 30 | roi_pooling.roi_pooling_forward_cuda(self.pooled_height, self.pooled_width, self.spatial_scale, 31 | features, rois, output, argmax) 32 | self.output = output 33 | self.argmax = argmax 34 | self.rois = rois 35 | self.feature_size = features.size() 36 | 37 | return output 38 | 39 | def backward(self, grad_output): 40 | assert(self.feature_size is not None and grad_output.is_cuda) 41 | 42 | batch_size, num_channels, data_height, data_width = self.feature_size 43 | 44 | grad_input = torch.zeros(batch_size, num_channels, data_height, data_width).cuda() 45 | roi_pooling.roi_pooling_backward_cuda(self.pooled_height, self.pooled_width, self.spatial_scale, 46 | grad_output, self.rois, grad_input, self.argmax) 47 | 48 | # print grad_input 49 | 50 | return grad_input, None 51 | 52 | 53 | class RoIPool(torch.nn.Module): 54 | def __init__(self, pooled_height, pooled_width, spatial_scale): 55 | super(RoIPool, self).__init__() 56 | 57 | self.pooled_width = int(pooled_width) 58 | self.pooled_height = int(pooled_height) 59 | self.spatial_scale = float(spatial_scale) 60 | 61 | def forward(self, features, rois): 62 | return RoIPoolFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois) 63 | -------------------------------------------------------------------------------- /RoI_pooling/src/cuda/roi_pooling_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | 
extern "C" { 3 | #endif 4 | 5 | #include <stdio.h> 6 | #include <math.h> 7 | #include <float.h> 8 | #include "roi_pooling_kernel.h" 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 12 | i += blockDim.x * gridDim.x) 13 | 14 | 15 | __global__ void ROIPoolForward(const int nthreads, const float* bottom_data, 16 | const float spatial_scale, const int height, const int width, 17 | const int channels, const int pooled_height, const int pooled_width, 18 | const float* bottom_rois, float* top_data, int* argmax_data) 19 | { 20 | CUDA_1D_KERNEL_LOOP(index, nthreads) 21 | { 22 | // (n, c, ph, pw) is an element in the pooled output 23 | int n = index; 24 | int pw = n % pooled_width; 25 | n /= pooled_width; 26 | int ph = n % pooled_height; 27 | n /= pooled_height; 28 | int c = n % channels; 29 | n /= channels; 30 | 31 | bottom_rois += n * 5; 32 | int roi_batch_ind = bottom_rois[0]; 33 | int roi_start_w = round(bottom_rois[1] * spatial_scale); 34 | int roi_start_h = round(bottom_rois[2] * spatial_scale); 35 | int roi_end_w = round(bottom_rois[3] * spatial_scale); 36 | int roi_end_h = round(bottom_rois[4] * spatial_scale); 37 | 38 | // Force malformed ROIs to be 1x1 39 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 40 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 41 | float bin_size_h = (float)(roi_height) / (float)(pooled_height); 42 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 43 | 44 | int hstart = (int)(floor((float)(ph) * bin_size_h)); 45 | int wstart = (int)(floor((float)(pw) * bin_size_w)); 46 | int hend = (int)(ceil((float)(ph + 1) * bin_size_h)); 47 | int wend = (int)(ceil((float)(pw + 1) * bin_size_w)); 48 | 49 | // Add roi offsets and clip to input boundaries 50 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), height); 51 | hend = fminf(fmaxf(hend + roi_start_h, 0), height); 52 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), width); 53 | wend = fminf(fmaxf(wend + roi_start_w, 0), width); 54 | bool is_empty = (hend <= hstart) || (wend <= wstart); 55 | 56 | // Define an empty pooling region to be zero 57 | float maxval = is_empty ?
0 : -FLT_MAX; 58 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd 59 | int maxidx = -1; 60 | bottom_data += roi_batch_ind * channels * height * width; 61 | for (int h = hstart; h < hend; ++h) { 62 | for (int w = wstart; w < wend; ++w) { 63 | // int bottom_index = (h * width + w) * channels + c; 64 | int bottom_index = (c * height + h) * width + w; 65 | if (bottom_data[bottom_index] > maxval) { 66 | maxval = bottom_data[bottom_index]; 67 | maxidx = bottom_index; 68 | } 69 | } 70 | } 71 | top_data[index] = maxval; 72 | if (argmax_data != NULL) 73 | argmax_data[index] = maxidx; 74 | } 75 | } 76 | 77 | 78 | int ROIPoolForwardLaucher( 79 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 80 | const int width, const int channels, const int pooled_height, 81 | const int pooled_width, const float* bottom_rois, 82 | float* top_data, int* argmax_data, cudaStream_t stream) 83 | { 84 | const int kThreadsPerBlock = 1024; 85 | const int output_size = num_rois * pooled_height * pooled_width * channels; 86 | cudaError_t err; 87 | 88 | 89 | ROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 90 | output_size, bottom_data, spatial_scale, height, width, channels, pooled_height, 91 | pooled_width, bottom_rois, top_data, argmax_data); 92 | 93 | err = cudaGetLastError(); 94 | if(cudaSuccess != err) 95 | { 96 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 97 | exit( -1 ); 98 | } 99 | 100 | return 1; 101 | } 102 | 103 | 104 | __global__ void ROIPoolBackward(const int nthreads, const float* top_diff, 105 | const int* argmax_data, const int num_rois, const float spatial_scale, 106 | const int height, const int width, const int channels, 107 | const int pooled_height, const int pooled_width, float* bottom_diff, 108 | const float* bottom_rois) { 109 | CUDA_1D_KERNEL_LOOP(index, nthreads) 110 | { 111 | 112 | // (n, c, ph, pw) is an element in the pooled output 113 | int n = index; 114 | int w = n % width; 115 | n /= width; 116 | int h = n % height; 117 | n /= height; 118 | int c = n % channels; 119 | n /= channels; 120 | 121 | float gradient = 0; 122 | // Accumulate gradient over all ROIs that pooled this element 123 | for (int roi_n = 0; roi_n < num_rois; ++roi_n) 124 | { 125 | const float* offset_bottom_rois = bottom_rois + roi_n * 5; 126 | int roi_batch_ind = offset_bottom_rois[0]; 127 | // Skip if ROI's batch index doesn't match n 128 | if (n != roi_batch_ind) { 129 | continue; 130 | } 131 | 132 | int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); 133 | int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); 134 | int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); 135 | int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); 136 | 137 | // Skip if ROI doesn't include (h, w) 138 | const bool in_roi = (w >= roi_start_w && w <= roi_end_w && 139 | h >= roi_start_h && h <= roi_end_h); 140 | if (!in_roi) { 141 | continue; 142 | } 143 | 144 | int offset = roi_n * pooled_height * pooled_width * channels; 145 | const float* offset_top_diff = top_diff + offset; 146 | const int* offset_argmax_data = argmax_data + offset; 147 | 148 | // Compute feasible set of pooled units that could have pooled 149 | // this bottom unit 150 | 151 | // Force malformed ROIs to be 1x1 152 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 153 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 154 | 155 | float bin_size_h = (float)(roi_height) / 
(float)(pooled_height); 156 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 157 | 158 | int phstart = floor((float)(h - roi_start_h) / bin_size_h); 159 | int phend = ceil((float)(h - roi_start_h + 1) / bin_size_h); 160 | int pwstart = floor((float)(w - roi_start_w) / bin_size_w); 161 | int pwend = ceil((float)(w - roi_start_w + 1) / bin_size_w); 162 | 163 | phstart = fminf(fmaxf(phstart, 0), pooled_height); 164 | phend = fminf(fmaxf(phend, 0), pooled_height); 165 | pwstart = fminf(fmaxf(pwstart, 0), pooled_width); 166 | pwend = fminf(fmaxf(pwend, 0), pooled_width); 167 | 168 | for (int ph = phstart; ph < phend; ++ph) { 169 | for (int pw = pwstart; pw < pwend; ++pw) { 170 | if (offset_argmax_data[(c * pooled_height + ph) * pooled_width + pw] == index) 171 | { 172 | gradient += offset_top_diff[(c * pooled_height + ph) * pooled_width + pw]; 173 | } 174 | } 175 | } 176 | } 177 | bottom_diff[index] = gradient; 178 | } 179 | } 180 | 181 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 182 | const int height, const int width, const int channels, const int pooled_height, 183 | const int pooled_width, const float* bottom_rois, 184 | float* bottom_diff, const int* argmax_data, cudaStream_t stream) 185 | { 186 | const int kThreadsPerBlock = 1024; 187 | const int output_size = batch_size * height * width * channels; 188 | cudaError_t err; 189 | 190 | ROIPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 191 | output_size, top_diff, argmax_data, num_rois, spatial_scale, height, width, channels, pooled_height, 192 | pooled_width, bottom_diff, bottom_rois); 193 | 194 | err = cudaGetLastError(); 195 | if(cudaSuccess != err) 196 | { 197 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 198 | exit( -1 ); 199 | } 200 | 201 | return 1; 202 | } 203 | 204 | 205 | #ifdef __cplusplus 206 | } 207 | #endif 208 | 209 | 210 | -------------------------------------------------------------------------------- /RoI_pooling/src/cuda/roi_pooling_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SirLPS/roi_pooling/1b2b3f953ed10d101e5f4283c56eeebf90b22ef5/RoI_pooling/src/cuda/roi_pooling_kernel.cu.o -------------------------------------------------------------------------------- /RoI_pooling/src/cuda/roi_pooling_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_POOLING_KERNEL 2 | #define _ROI_POOLING_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | int ROIPoolForwardLaucher( 9 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 10 | const int width, const int channels, const int pooled_height, 11 | const int pooled_width, const float* bottom_rois, 12 | float* top_data, int* argmax_data, cudaStream_t stream); 13 | 14 | 15 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 16 | const int height, const int width, const int channels, const int pooled_height, 17 | const int pooled_width, const float* bottom_rois, 18 | float* bottom_diff, const int* argmax_data, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif 25 | 26 | -------------------------------------------------------------------------------- /RoI_pooling/src/roi_pooling.c: 
-------------------------------------------------------------------------------- 1 | #include <TH/TH.h> 2 | #include <math.h> 3 | 4 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, 5 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output) 6 | { 7 | // Grab the input tensor 8 | float * data_flat = THFloatTensor_data(features); 9 | float * rois_flat = THFloatTensor_data(rois); 10 | 11 | float * output_flat = THFloatTensor_data(output); 12 | 13 | // Number of ROIs 14 | int num_rois = THFloatTensor_size(rois, 0); 15 | int size_rois = THFloatTensor_size(rois, 1); 16 | // batch size 17 | int batch_size = THFloatTensor_size(features, 0); 18 | if(batch_size != 1) 19 | { 20 | return 0; 21 | } 22 | // data height 23 | int data_height = THFloatTensor_size(features, 1); 24 | // data width 25 | int data_width = THFloatTensor_size(features, 2); 26 | // Number of channels 27 | int num_channels = THFloatTensor_size(features, 3); 28 | 29 | // Set all elements of the output tensor to -1 (used as the initial max value). 30 | THFloatStorage_fill(THFloatTensor_storage(output), -1); 31 | 32 | // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R 33 | int index_roi = 0; 34 | int index_output = 0; 35 | int n; 36 | for (n = 0; n < num_rois; ++n) 37 | { 38 | int roi_batch_ind = rois_flat[index_roi + 0]; 39 | int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale); 40 | int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale); 41 | int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale); 42 | int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale); 43 | // CHECK_GE(roi_batch_ind, 0); 44 | // CHECK_LT(roi_batch_ind, batch_size); 45 | 46 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 47 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 48 | float bin_size_h = (float)(roi_height) / (float)(pooled_height); 49 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 50 | 51 | int index_data = roi_batch_ind * data_height * data_width * num_channels; 52 | const int output_area = pooled_width * pooled_height; 53 | 54 | int c, ph, pw; 55 | for (ph = 0; ph < pooled_height; ++ph) 56 | { 57 | for (pw = 0; pw < pooled_width; ++pw) 58 | { 59 | int hstart = (floor((float)(ph) * bin_size_h)); 60 | int wstart = (floor((float)(pw) * bin_size_w)); 61 | int hend = (ceil((float)(ph + 1) * bin_size_h)); 62 | int wend = (ceil((float)(pw + 1) * bin_size_w)); 63 | 64 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height); 65 | hend = fminf(fmaxf(hend + roi_start_h, 0), data_height); 66 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width); 67 | wend = fminf(fmaxf(wend + roi_start_w, 0), data_width); 68 | 69 | const int pool_index = index_output + (ph * pooled_width + pw); 70 | int is_empty = (hend <= hstart) || (wend <= wstart); 71 | if (is_empty) 72 | { 73 | for (c = 0; c < num_channels * output_area; c += output_area) 74 | { 75 | output_flat[pool_index + c] = 0; 76 | } 77 | } 78 | else 79 | { 80 | int h, w, c; 81 | for (h = hstart; h < hend; ++h) 82 | { 83 | for (w = wstart; w < wend; ++w) 84 | { 85 | for (c = 0; c < num_channels; ++c) 86 | { 87 | const int index = (h * data_width + w) * num_channels + c; 88 | if (data_flat[index_data + index] > output_flat[pool_index + c * output_area]) 89 | { 90 | output_flat[pool_index + c * output_area] = data_flat[index_data + index]; 91 | } 92 | } 93 | } 94 | } 95 | } 96 | } 97 | } 98 | 99 | // Increment ROI index 100 | index_roi += size_rois; 101 | index_output += pooled_height * pooled_width * num_channels; 102
| } 103 | return 1; 104 | } -------------------------------------------------------------------------------- /RoI_pooling/src/roi_pooling.h: -------------------------------------------------------------------------------- 1 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, 2 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output); -------------------------------------------------------------------------------- /RoI_pooling/src/roi_pooling_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | #include <math.h> 3 | #include "cuda/roi_pooling_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale, 8 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax) 9 | { 10 | // Grab the input tensor 11 | float * data_flat = THCudaTensor_data(state, features); 12 | float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | float * output_flat = THCudaTensor_data(state, output); 15 | int * argmax_flat = THCudaIntTensor_data(state, argmax); 16 | 17 | // Number of ROIs 18 | int num_rois = THCudaTensor_size(state, rois, 0); 19 | int size_rois = THCudaTensor_size(state, rois, 1); 20 | if (size_rois != 5) 21 | { 22 | return 0; 23 | } 24 | 25 | // batch size 26 | int batch_size = THCudaTensor_size(state, features, 0); 27 | if (batch_size != 1) 28 | { 29 | return 0; 30 | } 31 | // data height 32 | int data_height = THCudaTensor_size(state, features, 2); 33 | // data width 34 | int data_width = THCudaTensor_size(state, features, 3); 35 | // Number of channels 36 | int num_channels = THCudaTensor_size(state, features, 1); 37 | 38 | cudaStream_t stream = THCState_getCurrentStream(state); 39 | 40 | ROIPoolForwardLaucher( 41 | data_flat, spatial_scale, num_rois, data_height, 42 | data_width, num_channels, pooled_height, 43 | pooled_width, rois_flat, 44 | output_flat, argmax_flat, stream); 45 | 46 | return 1; 47 | } 48 | 49 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, 50 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax) 51 | { 52 | // Grab the input tensor 53 | float * top_grad_flat = THCudaTensor_data(state, top_grad); 54 | float * rois_flat = THCudaTensor_data(state, rois); 55 | 56 | float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 57 | int * argmax_flat = THCudaIntTensor_data(state, argmax); 58 | 59 | // Number of ROIs 60 | int num_rois = THCudaTensor_size(state, rois, 0); 61 | int size_rois = THCudaTensor_size(state, rois, 1); 62 | if (size_rois != 5) 63 | { 64 | return 0; 65 | } 66 | 67 | // batch size 68 | int batch_size = THCudaTensor_size(state, bottom_grad, 0); 69 | if (batch_size != 1) 70 | { 71 | return 0; 72 | } 73 | // data height 74 | int data_height = THCudaTensor_size(state, bottom_grad, 2); 75 | // data width 76 | int data_width = THCudaTensor_size(state, bottom_grad, 3); 77 | // Number of channels 78 | int num_channels = THCudaTensor_size(state, bottom_grad, 1); 79 | 80 | cudaStream_t stream = THCState_getCurrentStream(state); 81 | ROIPoolBackwardLaucher( 82 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 83 | data_width, num_channels, pooled_height, 84 | pooled_width, rois_flat, 85 | bottom_grad_flat, argmax_flat, stream); 86 | 87 | return 1; 88 | } --------------------------------------------------------------------------------
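Note that both CUDA entry points above fail soft: they return 0 instead of raising an error when rois is not of shape (num_rois, 5) or when the feature/gradient batch size is not 1. A minimal sketch of the RoI layout these functions expect is shown below; the box values are made up for illustration, are given in input-image coordinates, and are multiplied by spatial_scale inside the kernels.

import torch

# Each row is [batch_index, x1, y1, x2, y2]; with batch_size == 1 the index is always 0.
rois = torch.FloatTensor([[0, 16, 16, 160, 160],
                          [0, 32, 48, 320, 240]]).cuda()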
/RoI_pooling/src/roi_pooling_cuda.h: -------------------------------------------------------------------------------- 1 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale, 2 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax); 3 | 4 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, 5 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax); -------------------------------------------------------------------------------- /roi_cupy.py: -------------------------------------------------------------------------------- 1 | kernel_forward = ''' 2 | extern "C" 3 | __global__ void roi_forward(const float* const bottom_data,const float* const bottom_rois, 4 | float* top_data, int* argmax_data, 5 | const double spatial_scale,const int channels,const int height, 6 | const int width, const int pooled_height, 7 | const int pooled_width,const int NN 8 | ){ 9 | 10 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | if(idx>=NN) 12 | return; 13 | const int pw = idx % pooled_width; 14 | const int ph = (idx / pooled_width) % pooled_height; 15 | const int c = (idx / pooled_width / pooled_height) % channels; 16 | int num = idx / pooled_width / pooled_height / channels; 17 | const int roi_batch_ind = bottom_rois[num * 5 + 0]; 18 | const int roi_start_w = round(bottom_rois[num * 5 + 1] * spatial_scale); 19 | const int roi_start_h = round(bottom_rois[num * 5 + 2] * spatial_scale); 20 | const int roi_end_w = round(bottom_rois[num * 5 + 3] * spatial_scale); 21 | const int roi_end_h = round(bottom_rois[num * 5 + 4] * spatial_scale); 22 | // Force malformed ROIs to be 1x1 23 | const int roi_width = max(roi_end_w - roi_start_w + 1, 1); 24 | const int roi_height = max(roi_end_h - roi_start_h + 1, 1); 25 | const float bin_size_h = static_cast<float>(roi_height) 26 | / static_cast<float>(pooled_height); 27 | const float bin_size_w = static_cast<float>(roi_width) 28 | / static_cast<float>(pooled_width); 29 | 30 | int hstart = static_cast<int>(floor(static_cast<float>(ph) 31 | * bin_size_h)); 32 | int wstart = static_cast<int>(floor(static_cast<float>(pw) 33 | * bin_size_w)); 34 | int hend = static_cast<int>(ceil(static_cast<float>(ph + 1) 35 | * bin_size_h)); 36 | int wend = static_cast<int>(ceil(static_cast<float>(pw + 1) 37 | * bin_size_w)); 38 | 39 | // Add roi offsets and clip to input boundaries 40 | hstart = min(max(hstart + roi_start_h, 0), height); 41 | hend = min(max(hend + roi_start_h, 0), height); 42 | wstart = min(max(wstart + roi_start_w, 0), width); 43 | wend = min(max(wend + roi_start_w, 0), width); 44 | bool is_empty = (hend <= hstart) || (wend <= wstart); 45 | 46 | // Define an empty pooling region to be zero 47 | float maxval = is_empty ?
0 : -1E+37; 48 | // If nothing is pooled, argmax=-1 causes nothing to be backprop'd 49 | int maxidx = -1; 50 | const int data_offset = (roi_batch_ind * channels + c) * height * width; 51 | for (int h = hstart; h < hend; ++h) { 52 | for (int w = wstart; w < wend; ++w) { 53 | int bottom_index = h * width + w; 54 | if (bottom_data[data_offset + bottom_index] > maxval) { 55 | maxval = bottom_data[data_offset + bottom_index]; 56 | maxidx = bottom_index; 57 | } 58 | } 59 | } 60 | top_data[idx]=maxval; 61 | argmax_data[idx]=maxidx; 62 | } 63 | ''' 64 | kernel_backward = ''' 65 | extern "C" 66 | __global__ void roi_backward(const float* const top_diff, 67 | const int* const argmax_data,const float* const bottom_rois, 68 | float* bottom_diff, const int num_rois, 69 | const double spatial_scale, int channels, 70 | int height, int width, int pooled_height, 71 | int pooled_width,const int NN) 72 | { 73 | 74 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 75 | //// Important: >= instead of > 76 | if(idx>=NN) 77 | return; 78 | int w = idx % width; 79 | int h = (idx / width) % height; 80 | int c = (idx/ (width * height)) % channels; 81 | int num = idx / (width * height * channels); 82 | 83 | float gradient = 0; 84 | // Accumulate gradient over all ROIs that pooled this element 85 | for (int roi_n = 0; roi_n < num_rois; ++roi_n) { 86 | // Skip if ROI's batch index doesn't match num 87 | if (num != static_cast<int>(bottom_rois[roi_n * 5])) { 88 | continue; 89 | } 90 | 91 | int roi_start_w = round(bottom_rois[roi_n * 5 + 1] 92 | * spatial_scale); 93 | int roi_start_h = round(bottom_rois[roi_n * 5 + 2] 94 | * spatial_scale); 95 | int roi_end_w = round(bottom_rois[roi_n * 5 + 3] 96 | * spatial_scale); 97 | int roi_end_h = round(bottom_rois[roi_n * 5 + 4] 98 | * spatial_scale); 99 | 100 | // Skip if ROI doesn't include (h, w) 101 | const bool in_roi = (w >= roi_start_w && w <= roi_end_w && 102 | h >= roi_start_h && h <= roi_end_h); 103 | if (!in_roi) { 104 | continue; 105 | } 106 | 107 | int offset = (roi_n * channels + c) * pooled_height 108 | * pooled_width; 109 | 110 | // Compute feasible set of pooled units that could have pooled 111 | // this bottom unit 112 | 113 | // Force malformed ROIs to be 1x1 114 | int roi_width = max(roi_end_w - roi_start_w + 1, 1); 115 | int roi_height = max(roi_end_h - roi_start_h + 1, 1); 116 | 117 | float bin_size_h = static_cast<float>(roi_height) 118 | / static_cast<float>(pooled_height); 119 | float bin_size_w = static_cast<float>(roi_width) 120 | / static_cast<float>(pooled_width); 121 | 122 | int phstart = floor(static_cast<float>(h - roi_start_h) 123 | / bin_size_h); 124 | int phend = ceil(static_cast<float>(h - roi_start_h + 1) 125 | / bin_size_h); 126 | int pwstart = floor(static_cast<float>(w - roi_start_w) 127 | / bin_size_w); 128 | int pwend = ceil(static_cast<float>(w - roi_start_w + 1) 129 | / bin_size_w); 130 | 131 | phstart = min(max(phstart, 0), pooled_height); 132 | phend = min(max(phend, 0), pooled_height); 133 | pwstart = min(max(pwstart, 0), pooled_width); 134 | pwend = min(max(pwend, 0), pooled_width); 135 | for (int ph = phstart; ph < phend; ++ph) { 136 | for (int pw = pwstart; pw < pwend; ++pw) { 137 | int index_ = ph * pooled_width + pw + offset; 138 | if (argmax_data[index_] == (h * width + w)) { 139 | gradient += top_diff[index_]; 140 | } 141 | } 142 | } 143 | } 144 | bottom_diff[idx] = gradient; 145 | } 146 | ''' 147 | -------------------------------------------------------------------------------- /roi_module.py: -------------------------------------------------------------------------------- 1 | from
collections import namedtuple 2 | from string import Template 3 | 4 | import cupy, torch 5 | import cupy as cp 6 | import torch as t 7 | from torch.autograd import Function 8 | 9 | from roi_cupy import kernel_backward, kernel_forward 10 | 11 | Stream = namedtuple('Stream', ['ptr']) 12 | 13 | 14 | @cupy.util.memoize(for_each_device=True) 15 | def load_kernel(kernel_name, code, **kwargs): 16 | cp.cuda.runtime.free(0) 17 | code = Template(code).substitute(**kwargs) 18 | kernel_code = cupy.cuda.compile_with_cache(code) 19 | return kernel_code.get_function(kernel_name) 20 | 21 | 22 | CUDA_NUM_THREADS = 1024 23 | 24 | 25 | def GET_BLOCKS(N, K=CUDA_NUM_THREADS): 26 | return (N + K - 1) // K 27 | 28 | 29 | class RoI(Function): 30 | """ 31 | NOTE:only CUDA-compatible 32 | """ 33 | 34 | def __init__(self, outh, outw, spatial_scale): 35 | self.forward_fn = load_kernel('roi_forward', kernel_forward) 36 | self.backward_fn = load_kernel('roi_backward', kernel_backward) 37 | self.outh, self.outw, self.spatial_scale = outh, outw, spatial_scale 38 | 39 | def forward(self, x, rois): 40 | # NOTE: MAKE SURE input is contiguous too 41 | x = x.contiguous() 42 | rois = rois.contiguous() 43 | self.in_size = B, C, H, W = x.size() 44 | self.N = N = rois.size(0) 45 | output = t.zeros(N, C, self.outh, self.outw).cuda() 46 | self.argmax_data = t.zeros(N, C, self.outh, self.outw).int().cuda() 47 | self.rois = rois 48 | args = [x.data_ptr(), rois.data_ptr(), 49 | output.data_ptr(), 50 | self.argmax_data.data_ptr(), 51 | self.spatial_scale, C, H, W, 52 | self.outh, self.outw, 53 | output.numel()] 54 | stream = Stream(ptr=torch.cuda.current_stream().cuda_stream) 55 | self.forward_fn(args=args, 56 | block=(CUDA_NUM_THREADS, 1, 1), 57 | grid=(GET_BLOCKS(output.numel()), 1, 1), 58 | stream=stream) 59 | return output 60 | 61 | def backward(self, grad_output): 62 | ##NOTE: IMPORTANT CONTIGUOUS 63 | # TODO: input 64 | grad_output = grad_output.contiguous() 65 | B, C, H, W = self.in_size 66 | grad_input = t.zeros(self.in_size).cuda() 67 | stream = Stream(ptr=torch.cuda.current_stream().cuda_stream) 68 | args = [grad_output.data_ptr(), 69 | self.argmax_data.data_ptr(), 70 | self.rois.data_ptr(), 71 | grad_input.data_ptr(), 72 | self.N, self.spatial_scale, C, H, W, self.outh, self.outw, 73 | grad_input.numel()] 74 | self.backward_fn(args=args, 75 | block=(CUDA_NUM_THREADS, 1, 1), 76 | grid=(GET_BLOCKS(grad_input.numel()), 1, 1), 77 | stream=stream 78 | ) 79 | return grad_input, None 80 | 81 | 82 | class RoIPooling2D(t.nn.Module): 83 | 84 | def __init__(self, outh, outw, spatial_scale): 85 | super(RoIPooling2D, self).__init__() 86 | self.RoI = RoI(outh, outw, spatial_scale) 87 | 88 | def forward(self, x, rois): 89 | return self.RoI(x, rois) 90 | 91 | 92 | def test_roi_module(): 93 | ## fake data### 94 | B, N, C, H, W, PH, PW = 2, 8, 4, 32, 32, 7, 7 95 | 96 | bottom_data = t.randn(B, C, H, W).cuda() 97 | bottom_rois = t.randn(N, 5) 98 | bottom_rois[:int(N / 2), 0] = 0 99 | bottom_rois[int(N / 2):, 0] = 1 100 | bottom_rois[:, 1:] = (t.rand(N, 4) * 100).float() 101 | bottom_rois = bottom_rois.cuda() 102 | spatial_scale = 1. 
/ 16 103 | outh, outw = PH, PW 104 | 105 | # pytorch version 106 | module = RoIPooling2D(outh, outw, spatial_scale) 107 | x = t.autograd.Variable(bottom_data, requires_grad=True) 108 | rois = t.autograd.Variable(bottom_rois) 109 | output = module(x, rois) 110 | output.sum().backward() 111 | 112 | def t2c(variable): 113 | npa = variable.data.cpu().numpy() 114 | return cp.array(npa) 115 | 116 | def test_eq(variable, array, info): 117 | cc = cp.asnumpy(array) 118 | neq = (cc != variable.data.cpu().numpy()) 119 | assert neq.sum() == 0, 'test failed: %s' % info 120 | 121 | # chainer version,if you're going to run this 122 | # pip install chainer 123 | import chainer.functions as F 124 | from chainer import Variable 125 | x_cn = Variable(t2c(x)) 126 | 127 | o_cn = F.roi_pooling_2d(x_cn, t2c(rois), outh, outw, spatial_scale) 128 | test_eq(output, o_cn.array, 'forward') 129 | F.sum(o_cn).backward() 130 | test_eq(x.grad, x_cn.grad, 'backward') 131 | print('test pass') 132 | -------------------------------------------------------------------------------- /speed.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as tVariable 5 | from RoI_pooling import RoIPoolFunction 6 | import chainer.functions as FF 7 | import cupy as cp 8 | from chainer import Variable as cVariable 9 | from roi_module import RoIPooling2D 10 | 11 | 12 | def tV2cV(variable): # torch Variable to Chainer Variable 13 | npa = variable.data.cpu().numpy() 14 | 15 | return cVariable(cp.array(npa)) 16 | 17 | 18 | def roi_pooling0(input, rois, size=(7, 7), spatial_scale=1.0): # cffi version 19 | F = RoIPoolFunction(size[0], size[1], spatial_scale) 20 | output = F(input, rois) 21 | if has_backward: 22 | F.backward(output.data.clone()) 23 | return output 24 | 25 | 26 | def roi_pooling1(input, rois, size=(7, 7), spatial_scale=1.0): # cupy version 27 | 28 | module = RoIPooling2D(7,7, spatial_scale=1.0) 29 | output = module(input, rois) 30 | if has_backward: 31 | output.sum().backward() 32 | 33 | return output 34 | 35 | 36 | def roi_pooling2(input, rois, size=(7, 7), spatial_scale=1.0): # chainer version 37 | input, rois = tV2cV(input), tV2cV(rois) 38 | 39 | output = FF.roi_pooling_2d(input, rois, 7, 7, spatial_scale=1.0) 40 | if has_backward: 41 | FF.sum(output).backward() 42 | return output 43 | 44 | 45 | def roi_pooling3(input, rois, size=(7, 7), spatial_scale=1.0): # pytorch version use for loop !!! 
46 | assert rois.dim() == 2 47 | assert rois.size(1) == 5 48 | output = [] 49 | rois = rois.data.float() 50 | num_rois = rois.size(0) 51 | 52 | rois[:, 1:].mul_(spatial_scale) 53 | rois = rois.long() 54 | for i in range(num_rois): 55 | roi = rois[i] 56 | im_idx = roi[0] 57 | im = input.narrow(0, im_idx, 1)[..., roi[2]:(roi[4]+1), roi[1]:(roi[3]+1)] 58 | output.append(F.adaptive_max_pool2d(im, size)) 59 | 60 | output = torch.cat(output, 0) 61 | if has_backward: 62 | # output.backward(output.data.clone()) 63 | output.sum().backward() 64 | return output 65 | 66 | 67 | if __name__ == '__main__': 68 | # batch_size, img_size, num_rois 69 | config = [[1, 50, 300], [8, 8, 100], 70 | [64, 64, 100], [64, 64, 1000], 71 | [256, 256, 100], [256, 256, 1000]] 72 | T = 50 73 | cuda = True 74 | has_backward = True 75 | 76 | print('use_cuda: {}, has_backward: {}'.format(cuda, has_backward)) 77 | for i in range(len(config)): 78 | x = torch.rand((config[i][0], 512, config[i][1], config[i][1])) 79 | rois = torch.rand((config[i][2], 5)) 80 | rois[:, 0] = rois[:, 0] * config[i][0] 81 | rois[:, 1:] = rois[:, 1:] * config[i][1] 82 | for j in range(config[i][2]): 83 | max_, min_ = max(rois[j, 1], rois[j, 3]), min(rois[j, 1], rois[j, 3]) 84 | rois[j, 1], rois[j, 3] = min_, max_ 85 | max_, min_ = max(rois[j, 2], rois[j, 4]), min(rois[j, 2], rois[j, 4]) 86 | rois[j, 2], rois[j, 4] = min_, max_ 87 | rois = torch.floor(rois) 88 | x = tVariable(x, requires_grad=True) 89 | rois = tVariable(rois, requires_grad=False) 90 | 91 | if cuda: 92 | x = x.cuda() 93 | rois = rois.cuda() 94 | 95 | for f, foo in enumerate([roi_pooling0, roi_pooling1, roi_pooling2, roi_pooling3]): 96 | start = time.time() 97 | for t in range(T): 98 | output = foo(x, rois) 99 | print('method{}: {}, batch_size: {}, size: {}, num_rois: {}'.format(f, (time.time() - start) / T, 100 | config[i][0], 101 | config[i][1], 102 | config[i][2])) 103 | print('\n') 104 | -------------------------------------------------------------------------------- /speed1.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from roi_pooling import RoIPoolFunction 6 | 7 | 8 | def roi_pooling1(input, rois, size=(7, 7), spatial_scale=1.0): 9 | F = RoIPoolFunction(size[0], size[1], spatial_scale) 10 | output = F(input, rois) 11 | if has_backward: 12 | F.backward(output.data.clone()) 13 | return output 14 | 15 | 16 | def roi_pooling2(input, rois, size=(7, 7), spatial_scale=1.0): 17 | assert rois.dim() == 2 18 | assert rois.size(1) == 5 19 | output = [] 20 | rois = rois.data.float() 21 | num_rois = rois.size(0) 22 | 23 | rois[:, 1:].mul_(spatial_scale) 24 | rois = rois.long() 25 | for i in range(num_rois): 26 | roi = rois[i] 27 | im_idx = roi[0] 28 | im = input.narrow(0, im_idx, 1)[..., roi[2]:(roi[4]+1), roi[1]:(roi[3]+1)] 29 | output.append(F.adaptive_max_pool2d(im, size)) 30 | 31 | output = torch.cat(output, 0) 32 | if has_backward: 33 | output.backward(output.data.clone()) 34 | return output 35 | 36 | 37 | if __name__ == '__main__': 38 | # batch_size, img_size, num_rois 39 | config = [[8, 8, 10], [8, 8, 100], 40 | [64, 64, 100], [64, 64, 1000], 41 | [256, 256, 100], [256, 256, 1000]] 42 | T = 50 43 | cuda = True 44 | has_backward = True 45 | 46 | print('use_cuda: {}, has_backward: {}'.format(cuda, has_backward)) 47 | for i in range(len(config)): 48 | x = torch.rand((config[i][0], 3, config[i][1], config[i][1])) 49 | rois = 
torch.rand((config[i][2], 5)) 50 | rois[:, 0] = rois[:, 0] * config[i][0] 51 | rois[:, 1:] = rois[:, 1:] * config[i][1] 52 | for j in range(config[i][2]): 53 | max_, min_ = max(rois[j, 1], rois[j, 3]), min(rois[j, 1], rois[j, 3]) 54 | rois[j, 1], rois[j, 3] = min_, max_ 55 | max_, min_ = max(rois[j, 2], rois[j, 4]), min(rois[j, 2], rois[j, 4]) 56 | rois[j, 2], rois[j, 4] = min_, max_ 57 | rois = torch.floor(rois) 58 | x = Variable(x, requires_grad=True) 59 | rois = Variable(rois, requires_grad=False) 60 | 61 | if cuda: 62 | x = x.cuda() 63 | rois = rois.cuda() 64 | 65 | for f, foo in enumerate([roi_pooling1, roi_pooling2]): 66 | start = time.time() 67 | for t in range(T): 68 | output = foo(x, rois) 69 | print('method{}: {}, batch_size: {}, size: {}, num_rois: {}'.format(f, (time.time() - start) / T, 70 | config[i][0], 71 | config[i][1], 72 | config[i][2])) 73 | --------------------------------------------------------------------------------
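For context, a minimal usage sketch tying the pieces together in the spirit of speed.py. It assumes the cffi extension has already been built by running python build.py inside RoI_pooling/ (producing _ext/roi_pooling/_roi_pooling.so) and that a CUDA device is available; the 7x7 output size and 1/16 spatial scale follow the rest of the repo, and all shapes and box coordinates are illustrative.

import torch
from torch.autograd import Variable

from RoI_pooling import RoIPoolFunction   # cffi/CUDA extension (method0 in speed.py)
from roi_module import RoIPooling2D       # cupy kernel (method1 in speed.py)

# One image in the batch, 512-channel feature map; each RoI is [batch_index, x1, y1, x2, y2].
features = Variable(torch.randn(1, 512, 50, 50).cuda(), requires_grad=True)
rois = Variable(torch.FloatTensor([[0, 16, 16, 160, 160],
                                   [0, 32, 48, 320, 240]]).cuda())

# cffi/CUDA kernel: a fresh Function instance per call, as in speed.py.
pooled = RoIPoolFunction(7, 7, 1.0 / 16)(features, rois)      # shape (2, 512, 7, 7)

# cupy kernel wrapped in an nn.Module.
pooled_cupy = RoIPooling2D(7, 7, 1.0 / 16)(features, rois)    # shape (2, 512, 7, 7)
pooled_cupy.sum().backward()                                  # gradients reach features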