├── .gitignore ├── LICENSE ├── README.md ├── demofig.png ├── nms ├── __init__.py ├── build.py ├── nms_wrapper.py ├── nmscpp.cpp ├── nmscppkernel.cu └── pth_nms.py └── roialign ├── __init__.py └── roi_align ├── __init__.py ├── build.py ├── crop_and_resize.py ├── roi_align.py ├── roicpp.cpp └── roicppkernel.cu /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-extension-nms-roi-cpp 2 | C++ port of the nms and roi_align extensions (CUDA version), compilable on Windows 10 with VS2017 3 | 4 | An exercise in porting two PyTorch C extensions to C++ and making them compatible with Win10 + VS2017. May contain bugs. For use with 5 | 6 | https://github.com/multimodallearning/pytorch-mask-rcnn 7 | 8 | Go into each folder, change include_dirs and library_dirs in build.py to your local VS2017 / Windows Kits / torch directories, and run ```python build.py install```. 9 | 10 | Then copy the whole folder, including the Python files, into pytorch-mask-rcnn and it should work (only the CUDA version is ported, so CUDA is required). 
11 | 12 | Result from running demo.py using the two extensions : 13 | 14 | ![Test Demo.py](./demofig.png) 15 | -------------------------------------------------------------------------------- /demofig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/darlliu/pytorch-extension-nms-roi-cpp/fed15d717f7e2ddac80239170861c66fcd72d342/demofig.png -------------------------------------------------------------------------------- /nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/darlliu/pytorch-extension-nms-roi-cpp/fed15d717f7e2ddac80239170861c66fcd72d342/nms/__init__.py -------------------------------------------------------------------------------- /nms/build.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup ( 5 | name = 'nmscuda', 6 | ext_modules = [ 7 | CUDAExtension( 8 | name = "nmscuda", 9 | sources = ["./nmscpp.cpp","./nmscppkernel.cu"], 10 | extra_compile_args = {'cxx':["-DMS_WIN64","-MD" ], "nvcc":["-O2"]}, 11 | include_dirs = [ 12 | "D:/Python37/Lib/site-packages/torch/lib/include", 13 | "D:/VisualStudio/VS/VC/Tools/MSVC/14.15.26726/include", 14 | "D:/Windows Kits/10/Include/10.0.17134.0/ucrt", 15 | "D:/Windows Kits/10/Include/10.0.17134.0/shared"], 16 | library_dirs = [ 17 | "D:/Python37/Lib/site-packages/torch/lib", 18 | "D:/Windows Kits/10/Lib/10.0.17134.0/ucrt/x64", 19 | "D:/Windows Kits/10/Lib/10.0.17134.0/um/x64", 20 | "D:/VisualStudio/VS/VC/Tools/MSVC/14.15.26726/lib/x64"], 21 | ) 22 | ], 23 | cmdclass = { 24 | "build_ext":BuildExtension 25 | } 26 | ) -------------------------------------------------------------------------------- /nms/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | # 
-------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from nms.pth_nms import pth_nms 12 | 13 | 14 | def nms(dets, thresh): 15 | """Dispatch to either CPU or GPU NMS implementations. 16 | Accept dets as tensor""" 17 | return pth_nms(dets, thresh) 18 | -------------------------------------------------------------------------------- /nms/nmscpp.cpp: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // Ported to C++, by YL 7 | // ------------------------------------------------------------------ 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | typedef float scalar_t ; 15 | typedef int64_t intlike_t ; 16 | 17 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 18 | int const threadsPerBlock = sizeof(intlike_t) * 8; 19 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | void _nms(int boxes_num, scalar_t * boxes_dev, 24 | intlike_t * mask_dev, float nms_overlap_thresh); 25 | int gpu_nms( 26 | at::Tensor& keep, 27 | at::Tensor& num_out, 28 | at::Tensor& boxes, 29 | float nms_overlap_thresh 30 | ) 31 | { 32 | CHECK_INPUT (boxes); 33 | 34 | // Number of ROIs 35 | auto boxes_num = boxes.size(0); 36 | auto boxes_dim = boxes.size(1); 37 | 38 | auto 
boxes_flat = boxes.data(); 39 | 40 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock); 41 | 42 | auto mask = at::CUDA(at::kLong).zeros({boxes_num, col_blocks}); 43 | auto mask_flat = mask.data(); 44 | _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh); 45 | 46 | 47 | // std :: cout << "completed the cuda kernel "<(); 50 | 51 | auto remv_cpu = at::CPU(at::kLong).zeros({col_blocks}); 52 | auto remv_cpu_flat = remv_cpu.data(); 53 | 54 | auto keep_flat = keep.data(); 55 | intlike_t num_to_keep = 0; 56 | 57 | // std :: cout << "completed setting up keep_flat "<(); 73 | * num_out_flat = num_to_keep; 74 | 75 | return 1; 76 | } 77 | 78 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 79 | m.def("gpu_nms", &gpu_nms, "NMS with CUDA kernel"); 80 | } -------------------------------------------------------------------------------- /nms/nmscppkernel.cu: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | #include 8 | #include 9 | #include 10 | #include 11 | typedef float scalar_t; 12 | typedef int64_t intlike_t ; 13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 14 | int const threadsPerBlock = sizeof(intlike_t) * 8; 15 | 16 | __device__ inline float devIoU(float const * const a, float const * const b) { 17 | float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); 18 | float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); 19 | float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); 20 | float interS = width * height; 21 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); 22 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 23 | return interS / (Sa + Sb - interS); 24 | } 25 | 26 | __global__ 
void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 27 | scalar_t *dev_boxes, intlike_t *dev_mask) { 28 | const int row_start = blockIdx.y; 29 | const int col_start = blockIdx.x; 30 | 31 | // if (row_start > col_start) return; 32 | 33 | const int row_size = 34 | fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 35 | const int col_size = 36 | fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 37 | 38 | __shared__ float block_boxes[threadsPerBlock * 5]; 39 | if (threadIdx.x < col_size) { 40 | block_boxes[threadIdx.x * 5 + 0] = 41 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; 42 | block_boxes[threadIdx.x * 5 + 1] = 43 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; 44 | block_boxes[threadIdx.x * 5 + 2] = 45 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; 46 | block_boxes[threadIdx.x * 5 + 3] = 47 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; 48 | block_boxes[threadIdx.x * 5 + 4] = 49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; 50 | } 51 | __syncthreads(); 52 | 53 | if (threadIdx.x < row_size) { 54 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 55 | const float *cur_box = dev_boxes + cur_box_idx * 5; 56 | int i = 0; 57 | intlike_t t = 0; 58 | int start = 0; 59 | if (row_start == col_start) { 60 | start = threadIdx.x + 1; 61 | } 62 | for (i = start; i < col_size; i++) { 63 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { 64 | t |= 1ULL << i; 65 | } 66 | } 67 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 68 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 69 | } 70 | } 71 | 72 | 73 | void _nms(int boxes_num, scalar_t * boxes_dev, 74 | intlike_t * mask_dev, float nms_overlap_thresh) { 75 | 76 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock), 77 | DIVUP(boxes_num, threadsPerBlock)); 78 | dim3 threads(threadsPerBlock); 79 | nms_kernel<<>>(boxes_num, 80 | nms_overlap_thresh, 81 | 
boxes_dev, 82 | mask_dev); 83 | } -------------------------------------------------------------------------------- /nms/pth_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nmscuda import gpu_nms as nms 3 | import numpy as np 4 | 5 | def pth_nms(dets, thresh): 6 | """ 7 | dets has to be a tensor 8 | """ 9 | if not dets.is_cuda: 10 | x1 = dets[:, 1] 11 | y1 = dets[:, 0] 12 | x2 = dets[:, 3] 13 | y2 = dets[:, 2] 14 | scores = dets[:, 4] 15 | 16 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 17 | order = scores.sort(0, descending=True)[1] 18 | # order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long() 19 | 20 | keep = torch.LongTensor(dets.size(0)) 21 | num_out = torch.LongTensor(1) 22 | nms.cpu_nms(keep, num_out, dets, order, areas, thresh) 23 | 24 | return keep[:num_out[0]] 25 | else: 26 | x1 = dets[:, 1] 27 | y1 = dets[:, 0] 28 | x2 = dets[:, 3] 29 | y2 = dets[:, 2] 30 | scores = dets[:, 4] 31 | 32 | dets_temp = torch.FloatTensor(dets.size()).cuda() 33 | dets_temp[:, 0] = dets[:, 1] 34 | dets_temp[:, 1] = dets[:, 0] 35 | dets_temp[:, 2] = dets[:, 3] 36 | dets_temp[:, 3] = dets[:, 2] 37 | dets_temp[:, 4] = dets[:, 4] 38 | 39 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 40 | order = scores.sort(0, descending=True)[1] 41 | # order = torch.from_numpy(np.ascontiguousarray(scores.cpu().numpy().argsort()[::-1])).long().cuda() 42 | 43 | dets = dets[order].contiguous() 44 | 45 | keep = torch.LongTensor(dets.size(0)) 46 | num_out = torch.LongTensor(1) 47 | # keep = torch.cuda.LongTensor(dets.size(0)) 48 | # num_out = torch.cuda.LongTensor(1) 49 | # print ("starting nms with shapes {}, {}".format(keep.shape, dets_temp.shape)) 50 | nms(keep, num_out, dets_temp, thresh) 51 | # print ("finished nms with shapes {}, {}, num_out is {}".format(keep.shape, dets_temp.shape, num_out)) 52 | return order[keep[:num_out[0]].cuda()].contiguous() 53 | # return order[keep[:num_out[0]]].contiguous() 54 | 55 | 
-------------------------------------------------------------------------------- /roialign/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/darlliu/pytorch-extension-nms-roi-cpp/fed15d717f7e2ddac80239170861c66fcd72d342/roialign/__init__.py -------------------------------------------------------------------------------- /roialign/roi_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/darlliu/pytorch-extension-nms-roi-cpp/fed15d717f7e2ddac80239170861c66fcd72d342/roialign/roi_align/__init__.py -------------------------------------------------------------------------------- /roialign/roi_align/build.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup ( 5 | name = 'roicuda', 6 | ext_modules = [ 7 | CUDAExtension( 8 | name = "roicuda", 9 | sources = ["./roicpp.cpp","./roicppkernel.cu"], 10 | extra_compile_args = {'cxx':["-DMS_WIN64","-MD" ], "nvcc":["-O2"]}, 11 | include_dirs = [ 12 | "D:/Python37/Lib/site-packages/torch/lib/include", 13 | "D:/VisualStudio/VS/VC/Tools/MSVC/14.15.26726/include", 14 | "D:/Windows Kits/10/Include/10.0.17134.0/ucrt", 15 | "D:/Windows Kits/10/Include/10.0.17134.0/shared"], 16 | library_dirs = [ 17 | "D:/Python37/Lib/site-packages/torch/lib", 18 | "D:/Windows Kits/10/Lib/10.0.17134.0/ucrt/x64", 19 | "D:/Windows Kits/10/Lib/10.0.17134.0/um/x64", 20 | "D:/VisualStudio/VS/VC/Tools/MSVC/14.15.26726/lib/x64"], 21 | ) 22 | ], 23 | cmdclass = { 24 | "build_ext":BuildExtension 25 | } 26 | ) -------------------------------------------------------------------------------- /roialign/roi_align/crop_and_resize.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as 
nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | from roicuda import gpu_roi_forward,gpu_roi_backward 8 | 9 | 10 | class CropAndResizeFunction(Function): 11 | 12 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 13 | self.crop_height = crop_height 14 | self.crop_width = crop_width 15 | self.extrapolation_value = extrapolation_value 16 | 17 | def forward(self, image, boxes, box_ind): 18 | crops = None 19 | 20 | if image.is_cuda: 21 | crops = gpu_roi_forward( 22 | image, boxes, box_ind, 23 | self.extrapolation_value, self.crop_height, self.crop_width) 24 | else: 25 | raise NotImplementedError("CPU version is currently not supported") 26 | 27 | # save for backward 28 | self.im_size = image.size() 29 | self.save_for_backward(boxes, box_ind) 30 | # print ("got crops back into python, shape {}".format(crops.shape)) 31 | return crops 32 | 33 | def backward(self, grad_outputs): 34 | boxes, box_ind = self.saved_tensors 35 | 36 | grad_outputs = grad_outputs.contiguous() 37 | grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) 38 | 39 | if grad_outputs.is_cuda: 40 | # print ("about to backprop grads_output with shape {}, grads_mage with shape {}".format(grad_outputs.shape, grad_image.shape)) 41 | grad_image = gpu_roi_backward( 42 | grad_outputs, boxes, box_ind, grad_image 43 | ) 44 | else: 45 | raise NotImplementedError("CPU version is currently not supported") 46 | 47 | return grad_image, None, None 48 | 49 | 50 | class CropAndResize(nn.Module): 51 | """ 52 | Crop and resize ported from tensorflow 53 | See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize 54 | """ 55 | 56 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 57 | super(CropAndResize, self).__init__() 58 | 59 | self.crop_height = crop_height 60 | self.crop_width = crop_width 61 | self.extrapolation_value = extrapolation_value 62 | 63 | def forward(self, image, boxes, box_ind): 64 | return 
CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) 65 | -------------------------------------------------------------------------------- /roialign/roi_align/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from crop_and_resize import CropAndResizeFunction, CropAndResize 5 | 6 | 7 | class RoIAlign(nn.Module): 8 | 9 | def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): 10 | super(RoIAlign, self).__init__() 11 | 12 | self.crop_height = crop_height 13 | self.crop_width = crop_width 14 | self.extrapolation_value = extrapolation_value 15 | self.transform_fpcoor = transform_fpcoor 16 | 17 | def forward(self, featuremap, boxes, box_ind): 18 | """ 19 | RoIAlign based on crop_and_resize. 20 | See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 21 | :param featuremap: NxCxHxW 22 | :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** 23 | :param box_ind: M 24 | :return: MxCxoHxoW 25 | """ 26 | x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) 27 | image_height, image_width = featuremap.size()[2:4] 28 | 29 | if self.transform_fpcoor: 30 | spacing_w = (x2 - x1) / float(self.crop_width) 31 | spacing_h = (y2 - y1) / float(self.crop_height) 32 | 33 | nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) 34 | ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) 35 | nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) 36 | nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) 37 | 38 | boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) 39 | else: 40 | x1 = x1 / float(image_width - 1) 41 | x2 = x2 / float(image_width - 1) 42 | y1 = y1 / float(image_height - 1) 43 | y2 = y2 / float(image_height - 1) 44 | boxes = torch.cat((y1, x1, y2, x2), 1) 45 | 46 
| boxes = boxes.detach().contiguous() 47 | box_ind = box_ind.detach() 48 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) 49 | -------------------------------------------------------------------------------- /roialign/roi_align/roicpp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | void CropAndResizeBackpropImageLaucher( 8 | const float *grads_ptr, const float *boxes_ptr, 9 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 10 | int image_width, int crop_height, int crop_width, int depth, 11 | float *grads_image_ptr); 12 | 13 | 14 | 15 | void CropAndResizeLaucher( 16 | const float *image_ptr, const float *boxes_ptr, 17 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 18 | int image_width, int crop_height, int crop_width, int depth, 19 | float extrapolation_value, float *crops_ptr); 20 | 21 | 22 | at::Tensor crop_and_resize_gpu_forward( 23 | at::Tensor image, 24 | at::Tensor boxes, // [y1, x1, y2, x2] 25 | at::Tensor box_index, // range in [0, batch_size) 26 | const float extrapolation_value, 27 | const int crop_height, 28 | const int crop_width 29 | ) { 30 | const int batch_size = image.size(0); 31 | const int depth = image.size(1); 32 | const int image_height = image.size(2); 33 | const int image_width = image.size(3); 34 | 35 | const int num_boxes = boxes.size(0); 36 | // init output space 37 | auto crops = torch::CUDA(at::kFloat).zeros({num_boxes, depth, crop_height, crop_width}); 38 | CropAndResizeLaucher( 39 | image.data(), 40 | boxes.data(), 41 | box_index.data(), 42 | num_boxes, batch_size, image_height, image_width, 43 | crop_height, crop_width, depth, extrapolation_value, 44 | crops.data() 45 | ); 46 | return crops; 47 | } 48 | 49 | 50 | at::Tensor crop_and_resize_gpu_backward( 51 | at::Tensor grads, 52 | at::Tensor boxes, // [y1, x1, y2, x2] 53 | 
at::Tensor box_index, // range in [0, batch_size) 54 | at::Tensor grads_image // resize to [bsize, c, hc, wc], CPU 55 | ) { 56 | // shape 57 | // std::cout << "about to launch backprop roi"<(), 71 | boxes.data(), 72 | box_index.data(), 73 | num_boxes, batch_size, image_height, image_width, 74 | crop_height, crop_width, depth, 75 | grads_cuda.data() 76 | ); 77 | return grads_cuda; //.toBackend(at::Backend::CPU); 78 | } 79 | 80 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 81 | m.def("gpu_roi_forward", &crop_and_resize_gpu_forward, "NMS with CUDA kernel"); 82 | m.def("gpu_roi_backward", &crop_and_resize_gpu_backward, "NMS with CUDA kernel"); 83 | } -------------------------------------------------------------------------------- /roialign/roi_align/roicppkernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 6 | i += blockDim.x * gridDim.x) 7 | 8 | 9 | __global__ 10 | void CropAndResizeKernel( 11 | const int nthreads, const float *image_ptr, const float *boxes_ptr, 12 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 13 | int image_width, int crop_height, int crop_width, int depth, 14 | float extrapolation_value, float *crops_ptr) 15 | { 16 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 17 | { 18 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 19 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 20 | int idx = out_idx; 21 | const int x = idx % crop_width; 22 | idx /= crop_width; 23 | const int y = idx % crop_height; 24 | idx /= crop_height; 25 | const int d = idx % depth; 26 | const int b = idx / depth; 27 | 28 | const float y1 = boxes_ptr[b * 4]; 29 | const float x1 = boxes_ptr[b * 4 + 1]; 30 | const float y2 = boxes_ptr[b * 4 + 2]; 31 | const float x2 = boxes_ptr[b * 4 + 3]; 32 | 33 | const int b_in = box_ind_ptr[b]; 34 | if (b_in < 0 
|| b_in >= batch) 35 | { 36 | continue; 37 | } 38 | 39 | const float height_scale = 40 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 41 | : 0; 42 | const float width_scale = 43 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 44 | 45 | const float in_y = (crop_height > 1) 46 | ? y1 * (image_height - 1) + y * height_scale 47 | : 0.5 * (y1 + y2) * (image_height - 1); 48 | if (in_y < 0 || in_y > image_height - 1) 49 | { 50 | crops_ptr[out_idx] = extrapolation_value; 51 | continue; 52 | } 53 | 54 | const float in_x = (crop_width > 1) 55 | ? x1 * (image_width - 1) + x * width_scale 56 | : 0.5 * (x1 + x2) * (image_width - 1); 57 | if (in_x < 0 || in_x > image_width - 1) 58 | { 59 | crops_ptr[out_idx] = extrapolation_value; 60 | continue; 61 | } 62 | 63 | const int top_y_index = floorf(in_y); 64 | const int bottom_y_index = ceilf(in_y); 65 | const float y_lerp = in_y - top_y_index; 66 | 67 | const int left_x_index = floorf(in_x); 68 | const int right_x_index = ceilf(in_x); 69 | const float x_lerp = in_x - left_x_index; 70 | 71 | const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; 72 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 73 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 74 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 75 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 76 | 77 | const float top = top_left + (top_right - top_left) * x_lerp; 78 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; 79 | crops_ptr[out_idx] = top + (bottom - top) * y_lerp; 80 | } 81 | } 82 | 83 | __global__ 84 | void CropAndResizeBackpropImageKernel( 85 | const int nthreads, const float *grads_ptr, const float *boxes_ptr, 86 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 87 | int image_width, int crop_height, int crop_width, int 
depth, 88 | float *grads_image_ptr) 89 | { 90 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 91 | { 92 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 93 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 94 | int idx = out_idx; 95 | const int x = idx % crop_width; 96 | idx /= crop_width; 97 | const int y = idx % crop_height; 98 | idx /= crop_height; 99 | const int d = idx % depth; 100 | const int b = idx / depth; 101 | 102 | const float y1 = boxes_ptr[b * 4]; 103 | const float x1 = boxes_ptr[b * 4 + 1]; 104 | const float y2 = boxes_ptr[b * 4 + 2]; 105 | const float x2 = boxes_ptr[b * 4 + 3]; 106 | 107 | const int b_in = box_ind_ptr[b]; 108 | if (b_in < 0 || b_in >= batch) 109 | { 110 | continue; 111 | } 112 | 113 | const float height_scale = 114 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 115 | : 0; 116 | const float width_scale = 117 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 118 | 119 | const float in_y = (crop_height > 1) 120 | ? y1 * (image_height - 1) + y * height_scale 121 | : 0.5 * (y1 + y2) * (image_height - 1); 122 | if (in_y < 0 || in_y > image_height - 1) 123 | { 124 | continue; 125 | } 126 | 127 | const float in_x = (crop_width > 1) 128 | ? 
x1 * (image_width - 1) + x * width_scale 129 | : 0.5 * (x1 + x2) * (image_width - 1); 130 | if (in_x < 0 || in_x > image_width - 1) 131 | { 132 | continue; 133 | } 134 | 135 | const int top_y_index = floorf(in_y); 136 | const int bottom_y_index = ceilf(in_y); 137 | const float y_lerp = in_y - top_y_index; 138 | 139 | const int left_x_index = floorf(in_x); 140 | const int right_x_index = ceilf(in_x); 141 | const float x_lerp = in_x - left_x_index; 142 | 143 | float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; 144 | const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; 145 | atomicAdd( 146 | pimage + top_y_index * image_width + left_x_index, 147 | (1 - x_lerp) * dtop 148 | ); 149 | atomicAdd( 150 | pimage + top_y_index * image_width + right_x_index, 151 | x_lerp * dtop 152 | ); 153 | 154 | const float dbottom = y_lerp * grads_ptr[out_idx]; 155 | atomicAdd( 156 | pimage + bottom_y_index * image_width + left_x_index, 157 | (1 - x_lerp) * dbottom 158 | ); 159 | atomicAdd( 160 | pimage + bottom_y_index * image_width + right_x_index, 161 | x_lerp * dbottom 162 | ); 163 | } 164 | } 165 | 166 | 167 | void CropAndResizeLaucher( 168 | const float *image_ptr, const float *boxes_ptr, 169 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 170 | int image_width, int crop_height, int crop_width, int depth, 171 | float extrapolation_value, float *crops_ptr) 172 | { 173 | const int total_count = num_boxes * crop_height * crop_width * depth; 174 | const int thread_per_block = 1024; 175 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 176 | cudaError_t err; 177 | 178 | if (total_count > 0) 179 | { 180 | CropAndResizeKernel<<>>( 181 | total_count, image_ptr, boxes_ptr, 182 | box_ind_ptr, num_boxes, batch, image_height, image_width, 183 | crop_height, crop_width, depth, extrapolation_value, crops_ptr); 184 | 185 | err = cudaGetLastError(); 186 | if (cudaSuccess != err) 187 | { 188 | fprintf(stderr, 
"cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 189 | exit(-1); 190 | } 191 | } 192 | } 193 | 194 | 195 | void CropAndResizeBackpropImageLaucher( 196 | const float *grads_ptr, const float *boxes_ptr, 197 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 198 | int image_width, int crop_height, int crop_width, int depth, 199 | float *grads_image_ptr) 200 | { 201 | const int total_count = num_boxes * crop_height * crop_width * depth; 202 | const int thread_per_block = 1024; 203 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 204 | cudaError_t err; 205 | 206 | if (total_count > 0) 207 | { 208 | CropAndResizeBackpropImageKernel<<>>( 209 | total_count, grads_ptr, boxes_ptr, 210 | box_ind_ptr, num_boxes, batch, image_height, image_width, 211 | crop_height, crop_width, depth, grads_image_ptr); 212 | 213 | err = cudaGetLastError(); 214 | if (cudaSuccess != err) 215 | { 216 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 217 | exit(-1); 218 | } 219 | } 220 | } --------------------------------------------------------------------------------