├── README.md
├── png
│   ├── roi_pool.png
│   └── roialign.png
├── roialign
│   ├── CMakeLists.txt
│   ├── main.cpp
│   ├── roi_align.py
│   ├── roi_align_binding.cpp
│   ├── roi_align_cpu.cpp
│   ├── roi_align_cuda.cpp
│   ├── roi_align_kernel.cu
│   ├── setup.py
│   └── temp.h
└── roipool
    ├── CMakeLists.txt
    ├── main.cpp
    ├── roi_pool.py
    ├── roi_pool_binding.cpp
    ├── roi_pool_cpu.cpp
    ├── roi_pool_cuda.cpp
    ├── roi_pool_kernel.cu
    ├── setup.py
    └── temp.h
/README.md:
--------------------------------------------------------------------------------
1 | # RoI-op-pytorch
2 | C++ extension of RoIPool & RoIAlign (both CPU and GPU) in PyTorch. The code is converted from the [caffe2](https://github.com/pytorch/pytorch/tree/master/caffe2/operators) operators. (Requires PyTorch 0.4.0.)
3 | 
4 | **Warning:** You may need to change `AT_CHECK` to `AT_ASSERT` (version 0.4 uses `AT_ASSERT`, while the latest version uses `AT_CHECK`).
5 | 
6 | **Note:**
7 | 
8 | 1. `roi_xxx_cpu.cpp` & `roi_xxx_binding.cpp`: contain the CPU version of the forward and backward operations. (Note: `roi_xxx_binding.cpp` is only for pybind; you can also merge this code into `roi_xxx_cpu.cpp`.)
9 | 2. `roi_xxx_kernel.cu` & `roi_xxx_cuda.cpp`: contain the CUDA version of the forward and backward operations.
10 | 3. `main.cpp` & `temp.h` & `CMakeLists.txt`: help you debug the C++ code directly, rather than having to run `python setup.py install` to debug. (Note: only the CPU version is supported ~ I don't know how to debug `.cu` code :persevere:)
11 | 4. `setup.py`: run `python setup.py install` to install this operation as a package. (You can find the package in your Python site-packages directory.)
12 | 5. `roi_xxx.py`: wraps the `.cpp` code into PyTorch's `Function` & `Module`; it also contains a small demo for testing.
13 | 
14 | **Install**
15 | 
16 | ```shell
17 | cd roixxx # roipool or roialign
18 | python setup.py install
19 | ```
20 | 
21 | ## RoI Pooling
22 | 
23 | The "strategy" of RoI Pooling in this implementation is illustrated in the picture below (:joy: not a great picture):
24 | 
25 | ![roi_pool](png/roi_pool.png)
26 | 
27 | Note (interpret the grid as points rather than blocks):
28 | 
29 | 1. scale=0.5
30 | 2. the dotted line marks the range of the "selected area" (integer form, in `[left, right)` and `[top, bottom)`)
31 | 
32 | ## RoI Align
33 | 
34 | ![roialign](png/roialign.png)
35 | 
36 | Note: left `sample=1`, right `sample=2`
37 | 
38 | 
39 | 
40 | There are several good resources explaining these two operations:
41 | 
42 | - [Region of interest pooling explained](https://blog.deepsense.ai/region-of-interest-pooling-explained/)
43 | - [ROI Align --- chinese](http://blog.leanote.com/post/afanti.deng@gmail.com/b5f4f526490b)
44 | - [ROI Align --- youtube](https://www.youtube.com/watch?v=XGi-Mz3do2s)
45 | 
46 | ### Reference
47 | 
48 | 1. [caffe2 operator](https://github.com/pytorch/pytorch/tree/a2a28c0ef1d9a433972fe72fa5b0b9b850ccfcaf/caffe2/operators): most of the code comes from here.
49 | 2. [extension-cpp: tutorial](https://github.com/pytorch/extension-cpp)
50 | 3. 
[detectorch](https://github.com/ignacio-rocco/detectorch) 51 | 52 | -------------------------------------------------------------------------------- /png/roi_pool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/RoIAlign-RoIPool-pytorch/c21e0c81b6eb5a00c74540568cb9e8a6d8a5455e/png/roi_pool.png -------------------------------------------------------------------------------- /png/roialign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/RoIAlign-RoIPool-pytorch/c21e0c81b6eb5a00c74540568cb9e8a6d8a5455e/png/roialign.png -------------------------------------------------------------------------------- /roialign/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | project(roialign) 3 | 4 | set(CMAKE_CXX_STANDARD 11) 5 | 6 | # change to your pytorch path 7 | set(ATEN_DIR "/your_pytorch_path/pytorch/build/lib.linux-x86_64-3.6/torch/lib") 8 | include_directories(${ATEN_DIR}/include) 9 | 10 | set(SOURCE_FILES main.cpp roi_align_cpu.cpp temp.h) 11 | 12 | add_executable(roialign ${SOURCE_FILES}) 13 | 14 | target_link_libraries(roialign ${ATEN_DIR}/libATen.so) 15 | -------------------------------------------------------------------------------- /roialign/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "temp.h" 4 | 5 | using namespace std; 6 | using namespace at; 7 | 8 | 9 | int main() { 10 | // different sample in this example has the same output, you can 11 | // test it on randn. 12 | auto feat = CPU(kFloat).arange(64).view({1, 1, 8, 8}); 13 | // auto feat = CPU(kFloat).randn({1, 1, 8, 8}).view({1, 1, 8, 8}); 14 | cout << feat << endl; 15 | float roi_data[] = {0, 1.6, 1.6, 9.2, 11.0}; 16 | auto roi = CPU(kFloat).tensorFromBlob(roi_data, {1, 5}); 17 | int64_t pool_h = 2, pool_w = 2, sample=1; 18 | double scale = 0.5; 19 | auto output = roi_align_forward_cpu(feat, roi, pool_h, pool_w, scale, sample); 20 | cout << output << endl; 21 | // auto output2 = roi_align_forward_cpu(feat, roi, pool_h, pool_w, scale, 2); 22 | // cout << output2 << endl; 23 | } -------------------------------------------------------------------------------- /roialign/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from torch.autograd import Function 3 | import roi_align_cpu 4 | import roi_align_cuda 5 | 6 | 7 | class ROIAlignFunction(Function): 8 | @staticmethod 9 | def forward(ctx, feat, rois, pool_h, pool_w, scale, sampling): 10 | ctx.rois = rois 11 | ctx.feat_size = feat.size() 12 | ctx.pool_h = pool_h 13 | ctx.pool_w = pool_w 14 | ctx.scale = scale 15 | ctx.sampling = sampling # sampling number in bin 16 | if feat.is_cuda: 17 | output = roi_align_cuda.forward_cuda(feat, rois, pool_h, pool_w, scale, sampling) 18 | else: 19 | output = roi_align_cpu.forward_cpu(feat, rois, pool_h, pool_w, scale, sampling) 20 | return output 21 | 22 | @staticmethod 23 | def backward(ctx, grad_out): 24 | rois = ctx.rois 25 | feat_size = ctx.feat_size 26 | pool_h = ctx.pool_h 27 | pool_w = ctx.pool_w 28 | scale = ctx.scale 29 | sampling = ctx.sampling 30 | grad_out = grad_out.contiguous() if not grad_out.is_contiguous() else grad_out 31 | if grad_out.is_cuda: 32 | grad_in = roi_align_cuda.backward_cuda(rois, grad_out, feat_size[0], 
feat_size[1], feat_size[2], 33 | feat_size[3], pool_h, pool_w, scale, sampling) 34 | else: 35 | grad_in = roi_align_cpu.backward_cpu(rois, grad_out, feat_size[0], feat_size[1], feat_size[2], feat_size[3], 36 | pool_h, pool_w, scale, sampling) 37 | # Note: the backward return number is corresponding to the forward parameters number 38 | return grad_in, None, None, None, None, None 39 | 40 | 41 | class ROIAlign(Module): 42 | def __init__(self, pool_h, pool_w, scale, sampling=0): 43 | super(ROIAlign, self).__init__() 44 | self.pool_h, self.pool_w = int(pool_h), int(pool_w) 45 | self.scale = float(scale) 46 | self.sampling = int(sampling) 47 | 48 | # feat: BxCxHxW, rois: Kx5 (batch_idx, xmin, ymin, xmax, ymax) without normalize 49 | def forward(self, feat, rois): 50 | output = ROIAlignFunction.apply(feat, rois, self.pool_h, self.pool_w, self.scale, self.sampling) 51 | return output 52 | 53 | 54 | if __name__ == '__main__': 55 | import torch 56 | 57 | print('------------test on cpu------------') 58 | roi_align = ROIAlign(2, 2, 0.5, 1) 59 | feat = torch.arange(64).view(1, 1, 8, 8) 60 | # Note: first element is batch_idx 61 | rois = torch.Tensor([0, 1.6, 1.6, 9.2, 11.0]).view(-1, 5) 62 | feat.requires_grad = True 63 | out = roi_align(feat, rois) 64 | print(out) 65 | out.sum().backward() 66 | print(feat.grad) 67 | 68 | if torch.cuda.is_available(): 69 | print('------------test on gpu------------') 70 | feat = feat.detach().cuda() 71 | rois = rois.cuda() 72 | feat.requires_grad = True 73 | out = roi_align(feat, rois) 74 | print(out) 75 | temp = out.sum() 76 | temp.backward() 77 | print(feat.grad) 78 | else: 79 | print('You device have not a GPU') 80 | -------------------------------------------------------------------------------- /roialign/roi_align_binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "roi_align_cpu.cpp" 3 | 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("forward_cpu", &roi_align_forward_cpu, "roi_align_forward_cpu"); 7 | m.def("backward_cpu", &roi_align_backward_cpu, "roi_align_backward_cpu"); 8 | } -------------------------------------------------------------------------------- /roialign/roi_align_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | using std::vector; 4 | 5 | template 6 | struct PreCalc { 7 | // left_top, right_top, left_bottom, right_bottom 8 | int pos1, pos2, pos3, pos4; 9 | T w1, w2, w3, w4; 10 | }; 11 | 12 | template 13 | inline void add(const T &val, T *address) { 14 | *address += val; 15 | } 16 | 17 | /* -----------------------------begin for forward--------------------------------- */ 18 | template 19 | void pre_calc_for_bilinear(const int h, const int w, const int pool_h, const int pool_w, int b_grid_h, int b_grid_w, 20 | T start_y, T start_x, T b_size_h, T b_size_w, vector> &pre_calc) { 21 | int idx = 0; 22 | for (int ph = 0; ph < pool_h; ++ph) { 23 | for (int pw = 0; pw < pool_w; ++pw) { 24 | for (int iy = 0; iy < b_grid_h; ++iy) { 25 | const T yy = start_y + ph * b_size_h + static_cast(iy + 0.5f) * b_size_h / static_cast(b_grid_h); 26 | for (int ix = 0; ix < b_grid_w; ++ix) { 27 | const T xx = 28 | start_x + pw * b_size_w + static_cast(ix + 0.5f) * b_size_w / static_cast(b_grid_w); 29 | T x = xx, y = yy; 30 | // situation 1: out of range 31 | if (y < -1.0 || y > h || x < -1.0 || x > w) { 32 | PreCalc pc{0, 0, 0, 0, 0, 0, 0, 0}; 33 | pre_calc[idx] = pc; 34 | idx += 1; 35 | continue; 36 | } 37 | // not exceed 1.0 38 
| y = y <= 0 ? 0 : (y >= h - 1 ? h - 1 : y); 39 | x = x <= 0 ? 0 : (x >= w - 1 ? w - 1 : x); 40 | int y_low = (int) y; 41 | int x_low = (int) x; 42 | int y_high = y_low >= h - 1 ? y_low : y_low + 1; 43 | int x_high = x_low >= w - 1 ? x_low : x_low + 1; 44 | T ly = y - y_low, lx = x - x_low; 45 | T hy = 1.0 - ly, hx = 1.0 - lx; 46 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 47 | // in the feature map's position and correspond weights 48 | PreCalc pc; 49 | pc.pos1 = y_low * w + x_low; 50 | pc.pos2 = y_low * w + x_high; 51 | pc.pos3 = y_high * w + x_low; 52 | pc.pos4 = y_high * w + x_high; 53 | pc.w1 = w1, pc.w2 = w2, pc.w3 = w3, pc.w4 = w4; 54 | pre_calc[idx] = pc; 55 | idx += 1; 56 | } // b_grid_w 57 | } // b_grid_h 58 | } // pool_w 59 | } // pool_h 60 | } 61 | 62 | 63 | template 64 | void roi_align_forward(const T *feat, const T *rois, const vector &feat_size, 65 | const vector &rois_size, const T &scale, const int ratio, T *out) { 66 | const int n_rois = rois_size[0], col_rois = rois_size[1], pool_h = rois_size[2], pool_w = rois_size[3]; 67 | const int channel = feat_size[1], h = feat_size[2], w = feat_size[3]; 68 | // #pragma omp parallel for 69 | for (int n = 0; n < n_rois; ++n) { 70 | int idx_n = n * channel * pool_h * pool_w; 71 | // rois data 72 | const T *offset_rois = rois + col_rois * n; 73 | int roi_batch_idx = 0; 74 | if (col_rois == 5) { 75 | roi_batch_idx = offset_rois[0]; 76 | ++offset_rois; 77 | } 78 | // Do not using rounding; this implementation detail is critical 79 | T start_x = offset_rois[0] * scale; 80 | T start_y = offset_rois[1] * scale; 81 | T end_x = offset_rois[2] * scale; 82 | T end_y = offset_rois[3] * scale; 83 | 84 | // Force malformed ROIs to be 1x1 85 | T roi_w = std::max(end_x - start_x, (T) 1.); 86 | T roi_h = std::max(end_y - start_y, (T) 1.); 87 | T bin_size_w = roi_w / static_cast(pool_w); 88 | T bin_size_h = roi_h / static_cast(pool_h); 89 | 90 | // We use roi_bin_grid to sample the grid and mimic integral 91 | int bin_grid_h = (ratio > 0) ? ratio : std::ceil(roi_h / pool_h); 92 | int bin_grid_w = (ratio > 0) ? 
ratio : std::ceil(roi_w / pool_w); 93 | // We do average (integral) pooling inside a bin 94 | const T count = bin_grid_h * bin_grid_w; 95 | // get each bin's corresponding position and weights 96 | std::vector> pre_calc(count * pool_h * pool_w); 97 | pre_calc_for_bilinear(h, w, pool_h, pool_w, bin_grid_h, bin_grid_w, start_y, start_x, bin_size_h, bin_size_w, 98 | pre_calc); 99 | // map to feature map 100 | for (int c = 0; c < channel; ++c) { 101 | int idx_nc = idx_n + c * pool_w * pool_h; 102 | const T *offset_feat = feat + (roi_batch_idx * channel + c) * h * w; 103 | int pre_calc_idx = 0; 104 | for (int ph = 0; ph < pool_h; ++ph) { 105 | for (int pw = 0; pw < pool_w; ++pw) { 106 | int idx = idx_nc + ph * pool_w + pw; 107 | T output_val = 0.; 108 | for (int iy = 0; iy < bin_grid_h; ++iy) { 109 | for (int ix = 0; ix < bin_grid_w; ++ix) { 110 | PreCalc pc = pre_calc[pre_calc_idx]; 111 | output_val += pc.w1 * offset_feat[pc.pos1] + pc.w2 * offset_feat[pc.pos2] + 112 | pc.w3 * offset_feat[pc.pos3] + pc.w4 * offset_feat[pc.pos4]; 113 | pre_calc_idx += 1; 114 | } 115 | } 116 | output_val /= count; 117 | out[idx] = output_val; 118 | } // for pw 119 | } // for ph 120 | } // for c 121 | } // for rois_n 122 | } 123 | 124 | 125 | // input: BxCxHxW; rois: Kx5 126 | at::Tensor roi_align_forward_cpu(const at::Tensor &feat, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 127 | double scale, int64_t sample) { 128 | AT_CHECK(feat.ndimension() == 4, "Feature should be BxCxHxW forms"); 129 | AT_CHECK(feat.is_contiguous(), "Feature should be contiguous"); 130 | AT_CHECK(rois.ndimension() == 2, "ROI Proposals should be Kx5 forms"); 131 | AT_CHECK(rois.size(1) == 5, "ROI proposals should be Kx5 forms"); 132 | AT_CHECK(rois.is_contiguous(), "ROI proposals should be contiguous."); 133 | 134 | const vector rois_size = {rois.size(0), rois.size(1), pool_h, pool_w}; 135 | const vector feat_size = {feat.size(0), feat.size(1), feat.size(2), feat.size(3)}; 136 | 137 | auto output = feat.type().tensor({rois_size[0], feat_size[1], pool_h, pool_w}); 138 | roi_align_forward(feat.data(), rois.data(), feat_size, rois_size, static_cast(scale), sample, 139 | output.data()); 140 | return output; 141 | } 142 | /*------------------------------end of forward-----------------------------*/ 143 | 144 | /*------------------------------begin for backward-----------------------------*/ 145 | template 146 | void bilinear_interpolate_gradient(const int h, const int w, T y, T x, PreCalc &pc) { 147 | if (y < -1.0 || y > h || x < -1.0 || x > w) { 148 | pc = {-1, -1, -1, -1, 0., 0., 0., 0.}; 149 | return; 150 | } 151 | // not exceed 1.0 152 | y = y <= 0 ? 0 : (y >= h - 1 ? h - 1 : y); 153 | x = x <= 0 ? 0 : (x >= w - 1 ? w - 1 : x); 154 | int y_low = (int) y; 155 | int x_low = (int) x; 156 | int y_high = y_low >= h - 1 ? y_low : y_low + 1; 157 | int x_high = x_low >= w - 1 ? 
x_low : x_low + 1; 158 | pc.pos1 = y_low * w + x_low; 159 | pc.pos2 = y_low * w + x_high; 160 | pc.pos3 = y_high * w + x_low; 161 | pc.pos4 = y_high * w + x_high; 162 | T ly = y - y_low, lx = x - x_low; 163 | T hy = 1.0 - ly, hx = 1.0 - lx; 164 | pc.w1 = hy * hx, pc.w2 = hy * lx, pc.w3 = ly * hx, pc.w4 = ly * lx; 165 | } 166 | 167 | 168 | template 169 | void roi_align_backward(int total, const T *rois, T *grad_out, const T &scale, const vector feat_size, 170 | const int pool_h, const int pool_w, const int rois_col, const int sample, T *grad_in) { 171 | // total=nxcxphxpw 172 | auto channel = feat_size[0], h = feat_size[1], w = feat_size[2]; 173 | for (int idx = 0; idx < total; ++idx) { 174 | int pw = idx % pool_w; 175 | int ph = (idx / pool_w) % pool_h; 176 | int c = (idx / pool_h / pool_w) % channel; 177 | int n = idx / pool_h / pool_w / channel; 178 | 179 | const T *offset_rois = rois + n * rois_col; 180 | int roi_batch_idx = 0; 181 | if (rois_col == 5) { 182 | roi_batch_idx = offset_rois[0]; 183 | ++offset_rois; 184 | } 185 | // Do not using rounding; this implementation detail is critical 186 | T start_x = offset_rois[0] * scale; 187 | T start_y = offset_rois[1] * scale; 188 | T end_x = offset_rois[2] * scale; 189 | T end_y = offset_rois[3] * scale; 190 | 191 | // Force malformed ROIs to be 1x1 192 | T roi_w = std::max(end_x - start_x, (T) 1.0); 193 | T roi_h = std::max(end_y - start_y, (T) 1.0); 194 | T b_size_h = roi_h / static_cast(pool_h); 195 | T b_size_w = roi_w / static_cast(pool_w); 196 | 197 | T *offset_grad_in = grad_in + (roi_batch_idx * channel + c) * h * w; 198 | T *offset_grad_out = grad_out + (n * channel + c) * pool_h * pool_w; 199 | T grad_out_this_bin = offset_grad_out[ph * pool_w + pw]; 200 | 201 | // We use roi_bin_grid to sample the grid and mimic integral 202 | int roi_bin_grid_h = (sample > 0) ? sample : std::ceil(roi_h / pool_h); 203 | int roi_bin_grid_w = (sample > 0) ? 
sample : std::ceil(roi_w / pool_w); 204 | // We do average (integral) pooling inside a bin 205 | const int count = roi_bin_grid_h * roi_bin_grid_w; 206 | PreCalc pc; 207 | for (int iy = 0; iy < roi_bin_grid_h; iy++) { 208 | const T y = start_y + ph * b_size_h + 209 | static_cast(iy + .5f) * b_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 210 | for (int ix = 0; ix < roi_bin_grid_w; ix++) { 211 | const T x = start_x + pw * b_size_w + 212 | static_cast(ix + .5f) * b_size_w / static_cast(roi_bin_grid_w); 213 | bilinear_interpolate_gradient(h, w, y, x, pc); 214 | T g1 = grad_out_this_bin * pc.w1 / count; 215 | T g2 = grad_out_this_bin * pc.w2 / count; 216 | T g3 = grad_out_this_bin * pc.w3 / count; 217 | T g4 = grad_out_this_bin * pc.w4 / count; 218 | // update grad_out 219 | if (pc.pos1 >= 0 && pc.pos2 >= 0 && pc.pos3 >= 0 && pc.pos4 >= 0) { 220 | add(g1, offset_grad_in + pc.pos1); 221 | add(g2, offset_grad_in + pc.pos2); 222 | add(g3, offset_grad_in + pc.pos3); 223 | add(g4, offset_grad_in + pc.pos4); 224 | } 225 | } // for ix 226 | } // for iy 227 | } // for 228 | } 229 | 230 | 231 | at::Tensor 232 | roi_align_backward_cpu(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 233 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, double scale, int64_t sample) { 234 | AT_CHECK(rois.ndimension() == 2 && rois.size(1) == 5, "ROI Proposals should be Kx5 forms") 235 | AT_CHECK(rois.is_contiguous(), "ROI proposals should be contiguous.") 236 | auto rois_col = rois.size(1); 237 | auto grad_in = rois.type().tensor({b_size, channel, h, w}); 238 | grad_in.zero_(); 239 | std::cout << grad_in << std::endl; 240 | roi_align_backward(grad_out.numel(), rois.data(), grad_out.data(), static_cast(scale), 241 | {channel, h, w}, pool_h, pool_w, rois_col, sample, grad_in.data()); 242 | return grad_in; 243 | } 244 | -------------------------------------------------------------------------------- /roialign/roi_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | at::Tensor roi_align_forward_cuda(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 5 | double scale, int64_t sampling); 6 | 7 | at::Tensor roi_align_backward_cuda(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 8 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, double scale, 9 | int64_t sampling); 10 | 11 | 12 | // C++ interface 13 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 15 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 16 | 17 | at::Tensor roi_align_forward(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 18 | double scale, int64_t sampling) { 19 | CHECK_INPUT(input); 20 | CHECK_INPUT(rois); 21 | return roi_align_forward_cuda(input, rois, pool_h, pool_w, scale, sampling); 22 | } 23 | 24 | at::Tensor roi_align_backward(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 25 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, double scale, int64_t sampling) { 26 | CHECK_INPUT(grad_out); 27 | CHECK_INPUT(rois); 28 | return roi_align_backward_cuda(rois, grad_out, b_size, channel, h, w, pool_h, pool_w, scale, sampling); 29 | } 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("forward_cuda", &roi_align_forward, 
"roi_align_forward_cuda"); 33 | m.def("backward_cuda", &roi_align_backward, "roi_align_backward_cuda"); 34 | } -------------------------------------------------------------------------------- /roialign/roi_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | template 7 | __device__ __forceinline__ 8 | 9 | T fmin(T a, T b) { 10 | return a > b ? b : a; 11 | } 12 | 13 | template 14 | __device__ __forceinline__ 15 | 16 | T fmax(T a, T b) { 17 | return a < b ? b : a; 18 | } 19 | 20 | template 21 | __device__ __forceinline__ 22 | 23 | T gpu_atomic_add(const T val, T *address) { 24 | return atomicAdd(address, val); 25 | } 26 | 27 | /* ------------------------------begin of the forward--------------------------- */ 28 | template 29 | __device__ T 30 | 31 | bilinear_interpolate(const T *input, const int h, const int w, T y, T x) { 32 | // deal with cases that inverse elements are out of feature map boundary 33 | if (y < -1.0 || y > h || x < -1.0 || x > w) { 34 | return 0; 35 | } 36 | 37 | y = y <= 0 ? 0 : (y >= h - 1 ? h - 1 : y); 38 | x = x <= 0 ? 0 : (x >= w - 1 ? w - 1 : x); 39 | 40 | int y_low = (int) y; 41 | int x_low = (int) x; 42 | int y_high = y_low >= h - 1 ? y_low : y_low + 1; 43 | int x_high = x_low >= w - 1 ? x_low : x_low + 1; 44 | 45 | T ly = y - y_low; 46 | T lx = x - x_low; 47 | T hy = 1. - ly, hx = 1. - lx; 48 | // do bilinear interpolation 49 | T v1 = input[y_low * w + x_low]; 50 | T v2 = input[y_low * w + x_high]; 51 | T v3 = input[y_high * w + x_low]; 52 | T v4 = input[y_high * w + x_high]; 53 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 54 | 55 | T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 56 | return val; 57 | } 58 | 59 | 60 | template 61 | __global__ void 62 | roi_align_forward_kernel(const int total, const T *input, const T *rois, const T scale, const int channel, const int h, 63 | const int w, const int pool_h, const int pool_w, const int sampling, T *output) { 64 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { 65 | int pw = idx % pool_w; 66 | int ph = (idx / pool_w) % pool_h; 67 | int c = (idx / pool_h / pool_w) % channel; 68 | int n = idx / pool_h / pool_w / channel; 69 | 70 | const T *offset_rois = rois + n * 5; 71 | int roi_batch_idx = offset_rois[0]; 72 | 73 | // Do not using rounding; this implementation detail is critical 74 | T start_x = offset_rois[1] * scale; 75 | T start_y = offset_rois[2] * scale; 76 | T end_x = offset_rois[3] * scale; 77 | T end_y = offset_rois[4] * scale; 78 | 79 | // Force malformed ROIs to be 1x1 80 | T roi_w = fmax(end_x - start_x, (T) 1.); 81 | T roi_h = fmax(end_y - start_y, (T) 1.); 82 | T bin_size_h = roi_h / static_cast(pool_h); 83 | T bin_size_w = roi_w / static_cast(pool_w); 84 | 85 | const T *offset_input = input + (roi_batch_idx * channel * c) * h * w; 86 | 87 | // We use roi_bin_grid to sample the grid and mimic integral 88 | int bin_grid_h = sampling > 0 ? sampling : ceilf(roi_h / pool_h); 89 | int bin_grid_w = sampling > 0 ? 
sampling : ceilf(roi_w / pool_w); 90 | // We do average (integral) pooling inside a bin 91 | const T count = bin_grid_h * bin_grid_w; 92 | 93 | T output_val = 0.; 94 | for (int iy = 0; iy < bin_grid_h; ++iy) { 95 | T y = start_y + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(bin_grid_h); 96 | for (int ix = 0; ix < bin_grid_w; ++ix) { 97 | T x = start_x + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / 98 | static_cast(bin_grid_w); 99 | T val = bilinear_interpolate(offset_input, h, w, y, x); 100 | output_val += val; 101 | } 102 | } 103 | output[idx] = output_val /= count; 104 | } 105 | } 106 | 107 | at::Tensor roi_align_forward_cuda(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 108 | double scale, int64_t sampling) { 109 | AT_CHECK(input.ndimension() == 4 && input.is_contiguous(), "Input features should be BxCxHxW and contiguous"); 110 | AT_CHECK(rois.ndimension() == 2 && rois.size(1) == 5, "ROIs should be Kx5 forms"); 111 | AT_CHECK(rois.is_contiguous(), "ROIs should be contiguous"); 112 | 113 | auto rois_num = rois.size(0); 114 | auto channel = input.size(1), h = input.size(2), w = input.size(3); 115 | 116 | auto output = input.type().tensor({rois_num, channel, pool_h, pool_w}); 117 | 118 | int64_t total = output.numel(); 119 | const int threads = 1024; 120 | const int64_t blocks = (total + threads - 1) / threads > 65535 ? 65535 : (total + threads - 1) / threads; 121 | 122 | roi_align_forward_kernel << < blocks, threads >> > (output.numel(), input.data(), rois.data(), 123 | static_cast(scale), channel, h, w, pool_h, pool_w, sampling, output.data()); 124 | 125 | AT_CHECK(cudaGetLastError() == cudaSuccess, "roi_align_forward_kernel failed"); 126 | return output; 127 | } 128 | /* ------------------------------end of the forward--------------------------- */ 129 | 130 | /* ------------------------------begin of the backward--------------------------- */ 131 | template 132 | __device__ void bilinear_interpolate_gradient(const int h, const int w, T y, T x, T &w1, T &w2, T &w3, T &w4, 133 | int &pos1, int &pos2, int &pos3, int &pos4) { 134 | // deal with cases that inverse elements are out of feature map boundary 135 | if (y < -1.0 || y > h || x < -1.0 || x > w) { 136 | w1 = w2 = w3 = w4 = 0.; 137 | pos1 = pos2 = pos3 = pos4 = -1; 138 | return; 139 | } 140 | 141 | y = y <= 0 ? 0 : (y >= h - 1 ? h - 1 : y); 142 | x = x <= 0 ? 0 : (x >= w - 1 ? w - 1 : x); 143 | 144 | int y_low = (int) y; 145 | int x_low = (int) x; 146 | int y_high = y_low >= h - 1 ? y_low : y_low + 1; 147 | int x_high = x_low >= w - 1 ? x_low : x_low + 1; 148 | 149 | pos1 = y_low * w + x_low; 150 | pos2 = y_low * w + x_high; 151 | pos3 = y_high * w + x_low; 152 | pos4 = y_high * w + x_high; 153 | 154 | T ly = y - y_low; 155 | T lx = x - x_low; 156 | T hy = 1. - ly, hx = 1. 
- lx; 157 | 158 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 159 | } 160 | 161 | 162 | template 163 | __global__ void roi_align_backward_kernel(const int total, const T *grad_out, const int rois_num, 164 | const T scale, const int channels, const int h, const int w, 165 | const int pool_h, const int pool_w, const int sampling, T *grad_in, 166 | const T *rois, int rois_col) { 167 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { 168 | // (n, c, ph, pw) is an element in the pooled output 169 | int pw = idx % pool_w; 170 | int ph = (idx / pool_w) % pool_h; 171 | int c = (idx / pool_w / pool_h) % channels; 172 | int n = idx / pool_w / pool_h / channels; 173 | 174 | const T *offset_rois = rois + n * 5; 175 | int roi_batch_idx = offset_rois[0]; 176 | 177 | // Do not using rounding; this implementation detail is critical 178 | T start_x = offset_rois[1] * scale; 179 | T start_y = offset_rois[2] * scale; 180 | T end_x = offset_rois[3] * scale; 181 | T end_y = offset_rois[4] * scale; 182 | 183 | 184 | // Force malformed ROIs to be 1x1 185 | T roi_w = fmax(end_x - start_x, (T) 1.); 186 | T roi_h = fmax(end_y - start_y, (T) 1.); 187 | T bin_size_h = roi_h / static_cast(pool_h); 188 | T bin_size_w = roi_w / static_cast(pool_w); 189 | 190 | T *offset_grad_in = grad_in + (roi_batch_idx * channels + c) * h * w; 191 | 192 | const T *offset_grad_out = grad_out + (n * channels + c) * pool_h * pool_w; 193 | const T grad_out_this_bin = offset_grad_out[ph * pool_w + pw]; 194 | 195 | // We use roi_bin_grid to sample the grid and mimic integral 196 | int roi_bin_grid_h = (sampling > 0) ? sampling : ceilf(roi_h / pool_h); // e.g., = 2 197 | int roi_bin_grid_w = (sampling > 0) ? sampling : ceilf(roi_w / pool_w); 198 | 199 | // We do average (integral) pooling inside a bin 200 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 201 | // PreCalc data type 202 | T w1, w2, w3, w4; 203 | int pos1, pos2, pos3, pos4; 204 | for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 205 | { 206 | const T y = start_y + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / 207 | static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 208 | for (int ix = 0; ix < roi_bin_grid_w; ix++) { 209 | const T x = start_x + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / 210 | static_cast(roi_bin_grid_w); 211 | bilinear_interpolate_gradient(h, w, y, x, w1, w2, w3, w4, pos1, pos2, pos3, pos4); 212 | 213 | T g1 = grad_out_this_bin * w1 / count; 214 | T g2 = grad_out_this_bin * w2 / count; 215 | T g3 = grad_out_this_bin * w3 / count; 216 | T g4 = grad_out_this_bin * w4 / count; 217 | 218 | if (pos1 >= 0 && pos2 >= 0 && pos3 >= 0 && pos4 >= 0) { 219 | gpu_atomic_add(static_cast(g1), offset_grad_in + pos1); 220 | gpu_atomic_add(static_cast(g2), offset_grad_in + pos2); 221 | gpu_atomic_add(static_cast(g3), offset_grad_in + pos3); 222 | gpu_atomic_add(static_cast(g4), offset_grad_in + pos4); 223 | } // if 224 | } // ix 225 | } // iy 226 | } // CUDA_1D_KERNEL_LOOP 227 | } // RoIAlignBackward 228 | 229 | at::Tensor roi_align_backward_cuda(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 230 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, double scale, 231 | int64_t sampling) { 232 | AT_CHECK(rois.ndimension() == 2 && rois.size(1) == 5, "ROIs should be Kx5 forms"); 233 | AT_CHECK(rois.is_contiguous(), "ROIs should be contiguous"); 234 | 235 | auto rois_num = rois.size(0), rois_col = rois.size(1); 236 | 237 | auto grad_in = rois.type().tensor({b_size, channel, h, w}); 238 | grad_in.zero_(); 239 | 240 | int64_t total = grad_out.numel(); 241 | const int threads = 1024; 242 | const int64_t blocks = (total + threads - 1) / threads > 65535 ? 
65535 : (total + threads - 1) / threads; 243 | 244 | roi_align_backward_kernel << < blocks, threads >> > (grad_out.numel(), grad_out.data(), rois_num, 245 | static_cast(scale), channel, h, w, pool_h, pool_w, sampling, grad_in.data(), 246 | rois.data(), rois_col); 247 | 248 | AT_CHECK(cudaGetLastError() == cudaSuccess, "roi_align_forward_kernel failed"); 249 | return grad_in; 250 | } -------------------------------------------------------------------------------- /roialign/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 3 | 4 | setup( 5 | name='roi_align_cpp', 6 | ext_modules=[ 7 | CppExtension('roi_align_cpu', [ 8 | 'roi_align_binding.cpp' 9 | ]), 10 | CUDAExtension('roi_align_cuda', [ 11 | 'roi_align_cuda.cpp', 12 | 'roi_align_kernel.cu', 13 | ]), 14 | ], 15 | cmdclass={ 16 | 'build_ext': BuildExtension 17 | }) 18 | -------------------------------------------------------------------------------- /roialign/temp.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_TEMP_H 2 | #define ROI_TEMP_H 3 | 4 | #include 5 | 6 | at::Tensor roi_align_forward_cpu(const at::Tensor &feat, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 7 | double scale, int64_t sample); 8 | 9 | at::Tensor 10 | roi_align_backward_cpu(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 11 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, double scale, int64_t sample); 12 | 13 | #endif //ROI_TEMP_H 14 | -------------------------------------------------------------------------------- /roipool/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | project(roipool) 3 | 4 | set(CMAKE_CXX_STANDARD 11) 5 | 6 | # You should change it to your build pytorch lib 7 | set(ATEN_DIR "/your_pytorch_path/pytorch/build/lib.linux-x86_64-3.6/torch/lib") 8 | include_directories(${ATEN_DIR}/include) 9 | 10 | set(SOURCE_FILES main.cpp roi_pool_cpu.cpp temp.h) 11 | 12 | add_executable(roipool ${SOURCE_FILES}) 13 | 14 | target_link_libraries(roipool ${ATEN_DIR}/libATen.so) 15 | -------------------------------------------------------------------------------- /roipool/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "temp.h" 4 | 5 | using namespace std; 6 | using namespace at; 7 | 8 | int main() { 9 | auto feat = CPU(kFloat).arange(64).view({1, 1, 8, 8}); 10 | cout << feat << endl; 11 | float roi_data[] = {0, 1.6, 1.6, 9.2, 11.0}; 12 | auto roi = CPU(kFloat).tensorFromBlob(roi_data, {1, 5}); 13 | auto memory = CPU(kInt).zeros({0}); 14 | int64_t pool_h = 2, pool_w = 2; 15 | double scale = 0.5; 16 | auto output = roi_pool_forward_cpu(feat, roi, pool_h, pool_w, scale, memory); 17 | cout << output << endl; 18 | } -------------------------------------------------------------------------------- /roipool/roi_pool.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from torch.autograd import Function 3 | import roi_pool_cpu 4 | import roi_pool_cuda 5 | 6 | 7 | class ROIPoolFunction(Function): 8 | @staticmethod 9 | def forward(ctx, feat, rois, pool_h, pool_w, scale, train): 10 | ctx.rois = rois 11 | ctx.feat_size = feat.size() 12 | ctx.pool_h = pool_h 13 | ctx.pool_w = pool_w 
14 | if train: 15 | ctx.memory = torch.zeros((rois.size(0), feat.size(1), pool_h, pool_w), dtype=torch.int) 16 | else: 17 | ctx.memory = torch.zeros(0) 18 | if feat.is_cuda: 19 | ctx.memory = ctx.memory.cuda() 20 | output = roi_pool_cuda.forward_cuda(feat, rois, pool_h, pool_w, scale, ctx.memory) 21 | else: 22 | output = roi_pool_cpu.forward_cpu(feat, rois, pool_h, pool_w, scale, ctx.memory) 23 | return output 24 | 25 | @staticmethod 26 | def backward(ctx, grad_out): 27 | rois = ctx.rois 28 | feat_size = ctx.feat_size 29 | pool_h = ctx.pool_h 30 | pool_w = ctx.pool_w 31 | memory = ctx.memory 32 | grad_out = grad_out.contiguous() if not grad_out.is_contiguous() else grad_out 33 | if grad_out.is_cuda: 34 | grad_in = roi_pool_cuda.backward_cuda(rois, grad_out, feat_size[0], feat_size[1], feat_size[2], 35 | feat_size[3], pool_h, pool_w, memory) 36 | else: 37 | grad_in = roi_pool_cpu.backward_cpu(rois, grad_out, feat_size[0], feat_size[1], feat_size[2], 38 | feat_size[3], pool_h, pool_w, memory) 39 | # Note: the backward return number is corresponding to the ctx variable 40 | return grad_in, None, None, None, None, None 41 | 42 | 43 | class ROIPool(Module): 44 | def __init__(self, pool_h, pool_w, scale): 45 | super(ROIPool, self).__init__() 46 | self.pool_h, self.pool_w = int(pool_h), int(pool_w) 47 | self.scale = float(scale) 48 | 49 | # feat: BxCxHxW, rois: Kx5 (batch_idx, xmin, ymin, xmax, ymax) without normalize 50 | def forward(self, feat, rois): 51 | output = ROIPoolFunction.apply(feat, rois, self.pool_h, self.pool_w, self.scale, self.training) 52 | return output 53 | 54 | 55 | if __name__ == '__main__': 56 | import torch 57 | 58 | print('------------test on cpu------------') 59 | roi_pool = ROIPool(2, 2, 0.5) 60 | feat = torch.arange(64).view(1, 1, 8, 8) 61 | # Note: first element is batch_idx 62 | rois = torch.Tensor([0, 1.6, 1.6, 9.2, 11.0]).view(-1, 5) 63 | feat.requires_grad = True 64 | out = roi_pool(feat, rois) 65 | print(out) 66 | out.sum().backward() 67 | print(feat.grad) 68 | 69 | if torch.cuda.is_available(): 70 | print('------------test on gpu------------') 71 | feat = feat.detach().cuda() 72 | rois = rois.cuda() 73 | feat.requires_grad = True 74 | out = roi_pool(feat, rois) 75 | print(out) 76 | temp = out.sum() 77 | temp.backward() 78 | print(feat.grad) 79 | else: 80 | print('You device have not a GPU') 81 | -------------------------------------------------------------------------------- /roipool/roi_pool_binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "roi_pool_cpu.cpp" 3 | 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("forward_cpu", &roi_pool_forward_cpu, "roi_pool_forward_cpu"); 7 | m.def("backward_cpu", &roi_pool_backward_cpu, "roi_pool_backward_cpu"); 8 | } -------------------------------------------------------------------------------- /roipool/roi_pool_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using std::vector; 5 | using std::max; 6 | using std::min; 7 | 8 | /* -----------------------------begin of the forward--------------------------------- */ 9 | template 10 | void roi_pool_forward(const T *input, const T *rois, vector in_size, vector rois_size, T scale, 11 | T *output, int *memory) { 12 | int rois_num = rois_size[0], rois_col = rois_size[1], pool_h = rois_size[2], pool_w = rois_size[3]; 13 | int channels = in_size[1], height = in_size[2], width = in_size[3]; 14 | int chw = channels * height * 
width, chw_p = channels * pool_h * pool_w; 15 | int *memory_data; 16 | for (int n = 0; n < rois_num; ++n) { 17 | int roi_batch_id = rois[0]; 18 | int roi_start_w = round(rois[1] * scale); 19 | int roi_start_h = round(rois[2] * scale); 20 | int roi_end_w = round(rois[3] * scale); 21 | int roi_end_h = round(rois[4] * scale); 22 | // Force malformed ROIs to be 1x1 23 | int roi_height = max(roi_end_h - roi_start_h + 1, 1); 24 | int roi_width = max(roi_end_w - roi_start_w + 1, 1); 25 | 26 | const T bin_size_h = static_cast(roi_height) / static_cast(pool_h); 27 | const T bin_size_w = static_cast(roi_width) / static_cast(pool_w); 28 | 29 | const T *input_data = input + roi_batch_id * chw; 30 | if (memory) 31 | memory_data = memory + n * chw_p; 32 | 33 | for (int c = 0; c < channels; ++c) { 34 | for (int ph = 0; ph < pool_h; ++ph) { 35 | for (int pw = 0; pw < pool_w; ++pw) { 36 | // Compute pooling region for this output unit: 37 | int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); 38 | int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); 39 | int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); 40 | int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); 41 | 42 | // Add roi offsets and clip to input boundaries 43 | hstart = min(max(hstart + roi_start_h, 0), height); 44 | hend = min(max(hend + roi_start_h, 0), height); 45 | wstart = min(max(wstart + roi_start_w, 0), width); 46 | wend = min(max(wend + roi_start_w, 0), width); 47 | 48 | const int pool_index = ph * pool_w + pw; 49 | // Define an empty pooling region to be zero 50 | bool is_empty = (hend <= hstart) || (wend <= wstart); 51 | output[pool_index] = is_empty ? 0 : -FLT_MAX; 52 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd 53 | if (memory) 54 | memory_data[pool_index] = -1; 55 | 56 | for (int hi = hstart; hi < hend; ++hi) { 57 | for (int wi = wstart; wi < wend; ++wi) { 58 | const int index = hi * width + wi; 59 | if (input_data[index] > output[pool_index]) { 60 | output[pool_index] = input_data[index]; 61 | if (memory) 62 | memory_data[pool_index] = index; 63 | } 64 | } 65 | } 66 | } 67 | } 68 | // Increment all data pointers by one channel 69 | input_data += height * width; 70 | output += pool_h * pool_w; 71 | if (memory) memory_data += pool_h * pool_w; 72 | } 73 | rois += rois_col; 74 | } 75 | } 76 | 77 | at::Tensor roi_pool_forward_cpu(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 78 | double scale, at::Tensor &memory) { 79 | AT_CHECK(input.ndimension() == 4, "Feature should be BxCxHxW forms"); 80 | AT_CHECK(input.is_contiguous(), "Feature should be contiguous"); 81 | AT_CHECK(rois.ndimension() == 2, "ROI Proposals should be Kx5 forms"); 82 | AT_CHECK(rois.size(1) == 5, "ROI proposals should be Kx5 forms"); 83 | AT_CHECK(rois.is_contiguous(), "ROI proposals should be contiguous."); 84 | 85 | const vector rois_size = {rois.size(0), rois.size(1), pool_h, pool_w}; 86 | const vector input_size = {input.size(0), input.size(1), input.size(2), input.size(3)}; 87 | 88 | auto output = input.type().tensor({rois_size[0], input_size[1], pool_h, pool_w}); 89 | if (memory.data()) 90 | memory.zero_(); 91 | 92 | roi_pool_forward(input.data(), rois.data(), input_size, rois_size, static_cast(scale), 93 | output.data(), memory.data()); 94 | return output; 95 | } 96 | /* -----------------------------end of the forward--------------------------------- */ 97 | 98 | /* -----------------------------begin of the backward--------------------------------- */ 99 | 
template 100 | void roi_pool_backward(const int total, const T *grad_out, const T *rois, const int channels, const int h, const int w, 101 | const int pool_h, const int pool_w, T *grad_in, const int *memory) { 102 | for (int idx = 0; idx < total; ++idx) { 103 | int pw = idx % pool_w; 104 | int ph = (idx / pool_w) % pool_h; 105 | int c = (idx / pool_w / pool_h) % channels; 106 | int n = idx / pool_w / pool_h / channels; 107 | 108 | const T *offset_rois = rois + n * 5; 109 | int roi_batch_idx = offset_rois[0]; 110 | 111 | // offset of index 112 | int grad_in_offset = (roi_batch_idx * channels + c) * h * w; 113 | int grad_out_offset = (n * channels + c) * pool_h * pool_w; 114 | 115 | const T *offset_grad_out = grad_out + grad_out_offset; 116 | T *offset_grad_in = grad_in + grad_in_offset; 117 | //const int *offset_memory = memory + grad_in_offset; 118 | const int *offset_memory = memory + grad_out_offset 119 | 120 | int argmax = offset_memory[ph * pool_w + pw]; 121 | if (argmax != -1) 122 | offset_grad_in[argmax] += offset_grad_out[ph * pool_w + pw]; 123 | } 124 | } 125 | 126 | 127 | at::Tensor roi_pool_backward_cpu(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 128 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, at::Tensor &memory) { 129 | AT_CHECK(grad_out.ndimension() == 4, "Feature should be BxCxHxW forms"); 130 | AT_CHECK(grad_out.is_contiguous(), "Feature should be contiguous"); 131 | AT_CHECK(rois.ndimension() == 2, "ROI Proposals should be Kx5 forms"); 132 | AT_CHECK(rois.size(1) == 5 && rois.is_contiguous(), "ROI proposals should be Kx5 forms and contiguous"); 133 | AT_CHECK(memory.is_contiguous(), "Memory should be contiguous."); 134 | 135 | 136 | auto grad_in = grad_out.type().tensor({b_size, channel, h, w}); 137 | grad_in.zero_(); 138 | 139 | roi_pool_backward(grad_out.numel(), grad_out.data(), rois.data(), channel, h, w, pool_h, pool_w, 140 | grad_in.data(), memory.data()); 141 | return grad_in; 142 | } 143 | -------------------------------------------------------------------------------- /roipool/roi_pool_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | at::Tensor roi_pool_forward_cuda(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 5 | double scale, at::Tensor &memory); 6 | 7 | at::Tensor roi_pool_backward_cuda(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 8 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, const at::Tensor &memory); 9 | 10 | 11 | // C++ interface 12 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | 16 | at::Tensor roi_pool_forward(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 17 | double scale, at::Tensor &memory) { 18 | CHECK_INPUT(input); 19 | CHECK_INPUT(rois); 20 | CHECK_INPUT(memory); 21 | return roi_pool_forward_cuda(input, rois, pool_h, pool_w, scale, memory); 22 | } 23 | 24 | at::Tensor roi_pool_backward(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 25 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, const at::Tensor &memory) { 26 | CHECK_INPUT(grad_out); 27 | CHECK_INPUT(rois); 28 | CHECK_INPUT(memory); 29 | return roi_pool_backward_cuda(rois, grad_out, b_size, channel, h, 
w, pool_h, pool_w, memory); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward_cuda", &roi_pool_forward, "roi_pool_forward_cuda"); 34 | m.def("backward_cuda", &roi_pool_backward, "roi_pool_backward_cuda"); 35 | } -------------------------------------------------------------------------------- /roipool/roi_pool_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | template 8 | __device__ __forceinline__ 9 | 10 | T gpu_atomic_add(const T val, T *address) { 11 | return atomicAdd(address, val); 12 | } 13 | 14 | /* ------------------------------begin of the forward--------------------------- */ 15 | template 16 | __global__ void 17 | roi_pool_forward_kernel(const int total, const T *input, const T *rois, const T scale, const int channels, const int h, 18 | const int w, const int pool_h, const int pool_w, T *output, int *memory) { 19 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { 20 | int pw = idx % pool_w; 21 | int ph = (idx / pool_w) % pool_h; 22 | int c = (idx / pool_h / pool_w) % channels; 23 | int n = idx / pool_h / pool_w / channels; 24 | 25 | const T *offset_rois = rois + n * 5; 26 | int roi_batch_idx = offset_rois[0]; 27 | 28 | // using rounding 29 | int roi_start_w = round(offset_rois[1] * scale); 30 | int roi_start_h = round(offset_rois[2] * scale); 31 | int roi_end_w = round(offset_rois[3] * scale); 32 | int roi_end_h = round(offset_rois[4] * scale); 33 | 34 | // Force malformed ROIs to be 1x1 35 | int roi_w = max(roi_end_w - roi_start_w + 1, 1); 36 | int roi_h = max(roi_end_h - roi_start_h + 1, 1); 37 | T bin_size_h = static_cast(roi_h) / static_cast(pool_h); 38 | T bin_size_w = static_cast(roi_w) / static_cast(pool_w); 39 | 40 | int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); 41 | int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); 42 | int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); 43 | int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); 44 | 45 | // Add roi offsets and clip to input boundaries 46 | hstart = min(max(hstart + roi_start_h, 0), h); 47 | hend = min(max(hend + roi_start_h, 0), h); 48 | wstart = min(max(wstart + roi_start_w, 0), w); 49 | wend = min(max(wend + roi_start_w, 0), w); 50 | bool is_empty = (hend <= hstart) || (wend <= wstart); 51 | 52 | // Define an empty pooling region to be zero 53 | T maxval = is_empty ? 
0 : -FLT_MAX; 54 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd 55 | int maxidx = -1; 56 | const T *offset_input = input + (roi_batch_idx * channels * c) * h * w; 57 | for (int hi = hstart; hi < hend; ++hi) { 58 | for (int wi = wstart; wi < wend; ++wi) { 59 | int ind = hi * w + wi; 60 | if (offset_input[ind] > maxval) { 61 | maxval = offset_input[ind]; 62 | maxidx = ind; 63 | } 64 | } 65 | } 66 | output[idx] = maxval; 67 | if (memory) { 68 | memory[idx] = maxidx; 69 | } 70 | } 71 | } 72 | 73 | // TODO: there may be a bug 74 | at::Tensor roi_pool_forward_cuda(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 75 | double scale, at::Tensor &memory) { 76 | AT_CHECK(input.ndimension() == 4, "Input features should be BxCxHxW"); 77 | AT_CHECK(rois.ndimension() == 2 && rois.size(1) == 5, "ROIs should be Kx5 forms"); 78 | 79 | auto rois_num = rois.size(0); 80 | auto channel = input.size(1), h = input.size(2), w = input.size(3); 81 | 82 | auto output = input.type().tensor({rois_num, channel, pool_h, pool_w}); 83 | 84 | int64_t total = output.numel(); 85 | const int threads = 1024; 86 | const int64_t blocks = (total + threads - 1) / threads > 65535 ? 65535 : (total + threads - 1) / threads; 87 | 88 | roi_pool_forward_kernel << < blocks, threads >> > (output.numel(), input.data(), rois.data(), 89 | static_cast(scale), channel, h, w, pool_h, pool_w, output.data(), memory.data()); 90 | 91 | AT_CHECK(cudaGetLastError() == cudaSuccess, "roi_align_forward_kernel failed"); 92 | return output; 93 | } 94 | /* ------------------------------end of the forward--------------------------- */ 95 | 96 | /* ------------------------------begin of the backward--------------------------- */ 97 | template 98 | __global__ void roi_pool_backward_kernel(const int total, const T *grad_out, const T *rois, const int channels, 99 | const int h, const int w, const int pool_h, const int pool_w, T *grad_in, 100 | const int *memory) { 101 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { 102 | // (n, c, ph, pw) is an element in the pooled output 103 | int pw = idx % pool_w; 104 | int ph = (idx / pool_w) % pool_h; 105 | int c = (idx / pool_w / pool_h) % channels; 106 | int n = idx / pool_w / pool_h / channels; 107 | 108 | const T *offset_rois = rois + n * 5; 109 | int roi_batch_idx = offset_rois[0]; 110 | // offset of index 111 | int grad_in_offset = (roi_batch_idx * channels + c) * h * w; 112 | int grad_out_offset = (n * channels + c) * pool_h * pool_w; 113 | 114 | const T *offset_grad_out = grad_out + grad_out_offset; 115 | T *offset_grad_in = grad_in + grad_in_offset; 116 | const int *offset_memory = memory + grad_in_offset; 117 | 118 | int argmax = offset_memory[ph * pool_w + pw]; 119 | if (argmax != -1) 120 | gpu_atomic_add(static_cast(offset_grad_out[ph * pool_w + pw]), offset_grad_in + argmax); 121 | } 122 | } // RoIPoolBackward 123 | 124 | at::Tensor roi_pool_backward_cuda(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 125 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, const at::Tensor &memory) { 126 | AT_CHECK(rois.ndimension() == 2 && rois.size(1) == 5, "ROIs should be Kx5 forms"); 127 | AT_CHECK(rois.is_contiguous(), "ROIs should be contiguous"); 128 | 129 | auto grad_in = rois.type().tensor({b_size, channel, h, w}); 130 | grad_in.zero_(); 131 | 132 | int64_t total = grad_out.numel(); 133 | const int threads = 1024; 134 | const int64_t blocks = (total + threads - 1) / 
threads > 65535 ? 65535 : (total + threads - 1) / threads; 135 | 136 | roi_pool_backward_kernel << < blocks, threads, 0, at::globalContext().getCurrentCUDAStream() >> > (total, 137 | grad_out.data(), rois.data(), channel, h, w, pool_h, pool_w, grad_in.data(), 138 | memory.data()); 139 | 140 | AT_CHECK(cudaGetLastError() == cudaSuccess, "roi_align_forward_kernel failed"); 141 | return grad_in; 142 | } -------------------------------------------------------------------------------- /roipool/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 3 | 4 | setup( 5 | name='roi_pool_cpp', 6 | ext_modules=[ 7 | CppExtension('roi_pool_cpu', [ 8 | 'roi_pool_binding.cpp' 9 | ]), 10 | CUDAExtension('roi_pool_cuda', [ 11 | 'roi_pool_cuda.cpp', 12 | 'roi_pool_kernel.cu', 13 | ]), 14 | ], 15 | cmdclass={ 16 | 'build_ext': BuildExtension 17 | }) 18 | -------------------------------------------------------------------------------- /roipool/temp.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_TEMP_H 2 | #define ROI_TEMP_H 3 | 4 | #include 5 | 6 | at::Tensor roi_pool_forward_cpu(const at::Tensor &input, const at::Tensor &rois, int64_t pool_h, int64_t pool_w, 7 | double scale, at::Tensor &memory); 8 | 9 | at::Tensor roi_pool_backward_cpu(const at::Tensor &rois, const at::Tensor &grad_out, int64_t b_size, int64_t channel, 10 | int64_t h, int64_t w, int64_t pool_h, int64_t pool_w, at::Tensor &memory); 11 | 12 | #endif //ROI_TEMP_H 13 | --------------------------------------------------------------------------------
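As a side note, the RoI Align forward pass implemented above in `roialign/roi_align_cpu.cpp` and `roialign/roi_align_kernel.cu` can be summarized in a few lines of plain PyTorch. The sketch below is illustrative only and is not part of the repository: the function name `roi_align_reference` is made up, it handles a single image and a single RoI, and it omits the out-of-range branch that the C++ code uses to zero samples falling more than one pixel outside the feature map.

```python
# Illustrative, pure-PyTorch sketch of the RoI Align forward pass (CPU, one image,
# one RoI). Not part of this repository; `roi_align_reference` is a hypothetical name.
import math
import torch


def roi_align_reference(feat, roi, pool_h, pool_w, scale, sampling):
    """feat: 1xCxHxW tensor; roi: (batch_idx, xmin, ymin, xmax, ymax), un-normalized."""
    _, C, H, W = feat.shape
    x0, y0, x1, y1 = [v * scale for v in roi[1:]]   # do not round; this detail is critical
    roi_w = max(x1 - x0, 1.0)                       # force malformed ROIs to be 1x1
    roi_h = max(y1 - y0, 1.0)
    bin_w, bin_h = roi_w / pool_w, roi_h / pool_h
    grid_h = sampling if sampling > 0 else int(math.ceil(roi_h / pool_h))
    grid_w = sampling if sampling > 0 else int(math.ceil(roi_w / pool_w))
    out = feat.new_zeros(C, pool_h, pool_w)
    for ph in range(pool_h):
        for pw in range(pool_w):
            acc = feat.new_zeros(C)
            for iy in range(grid_h):                # regular sample grid inside each bin
                y = y0 + ph * bin_h + (iy + 0.5) * bin_h / grid_h
                for ix in range(grid_w):
                    x = x0 + pw * bin_w + (ix + 0.5) * bin_w / grid_w
                    # clamp to the feature map and bilinearly interpolate
                    # (the out-of-range -> 0 branch of the C++ code is omitted here)
                    yc, xc = min(max(y, 0.0), H - 1), min(max(x, 0.0), W - 1)
                    y_lo, x_lo = int(yc), int(xc)
                    y_hi, x_hi = min(y_lo + 1, H - 1), min(x_lo + 1, W - 1)
                    ly, lx = yc - y_lo, xc - x_lo
                    hy, hx = 1.0 - ly, 1.0 - lx
                    acc += (hy * hx * feat[0, :, y_lo, x_lo] + hy * lx * feat[0, :, y_lo, x_hi]
                            + ly * hx * feat[0, :, y_hi, x_lo] + ly * lx * feat[0, :, y_hi, x_hi])
            out[:, ph, pw] = acc / (grid_h * grid_w)   # average over the sample grid
    return out


if __name__ == '__main__':
    feat = torch.arange(64, dtype=torch.float32).view(1, 1, 8, 8)
    print(roi_align_reference(feat, [0, 1.6, 1.6, 9.2, 11.0], 2, 2, 0.5, 1))
```

Running it with the demo inputs from `roialign/main.cpp` (an 8x8 `arange` feature map, RoI `[0, 1.6, 1.6, 9.2, 11.0]`, 2x2 pooling, `scale=0.5`, `sample=1`) should match the extension's CPU output, which makes it a convenient sanity check after `python setup.py install`.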