├── .gitignore ├── functions ├── __init__.py ├── deform_conv.py └── modulated_dcn_func.py ├── modules ├── __init__.py ├── deform_conv.py └── modulated_dcn.py ├── src ├── deform_conv.h ├── deform_conv.c ├── deform_conv_cuda.h ├── modulated_dcn.h ├── modulated_dcn.c ├── deform_conv_cuda_kernel.h ├── cuda │ ├── deform_psroi_pooling_cuda.h │ ├── modulated_deform_im2col_cuda.h │ ├── deform_psroi_pooling_cuda.cu │ └── modulated_deform_im2col_cuda.cu ├── modulated_dcn_cuda.h ├── modulated_dcn_cuda.c ├── deform_conv_cuda.c └── deform_conv_cuda_kernel.cu ├── make.sh ├── test.py ├── LICENSE ├── README.md └── test_modulated.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.so 3 | _ext* 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import DeformConvFunction, deform_conv_function 2 | from .modulated_dcn_func import DeformRoIPoolingFunction, ModulatedDeformConvFunction -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import DeformConv 2 | from .modulated_dcn import DeformRoIPooling, ModulatedDeformConv, ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack -------------------------------------------------------------------------------- /src/deform_conv.h: -------------------------------------------------------------------------------- 1 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, 2 | THFloatTensor *output); 3 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, 4 | THFloatTensor *grad_offset); 5 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | cd src 2 | nvcc -c -o deform_conv_cuda_kernel.cu.so deform_conv_cuda_kernel.cu -x cu -Xcompiler -fPIC -std=c++11 3 | 4 | cd cuda 5 | 6 | # compile modulated deform conv 7 | nvcc -c -o modulated_deform_im2col_cuda.cu.so modulated_deform_im2col_cuda.cu -x cu -Xcompiler -fPIC 8 | 9 | # compile deform-psroi-pooling 10 | nvcc -c -o deform_psroi_pooling_cuda.cu.so deform_psroi_pooling_cuda.cu -x cu -Xcompiler -fPIC 11 | 12 | cd ../.. 
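# Note (assumption): build.py and build_modulated.py are not included in this listing;
# they are expected to wrap the object files compiled above into the cffi extensions
# imported by the Python code (`_ext.deform_conv` and `_ext.modulated_dcn`),
# presumably via the PyTorch 0.4-era torch.utils.ffi build machinery.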
13 | CC=g++ python build.py 14 | python build_modulated.py 15 | -------------------------------------------------------------------------------- /src/deform_conv.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, 4 | THFloatTensor *output) 5 | { 6 | // if (!THFloatTensor_isSameSizeAs(input1, input2)) 7 | // return 0; 8 | // THFloatTensor_resizeAs(output, input); 9 | // THFloatTensor_cadd(output, input1, 1.0, input2); 10 | return 1; 11 | } 12 | 13 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, 14 | THFloatTensor *grad_offset) 15 | { 16 | // THFloatTensor_resizeAs(grad_input, grad_output); 17 | // THFloatTensor_fill(grad_input, 1); 18 | return 1; 19 | } 20 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from modules import DeformConv 7 | 8 | num_deformable_groups = 2 9 | 10 | N, inC, inH, inW = 2, 6, 512, 512 11 | outC, outH, outW = 4, 512, 512 12 | kH, kW = 3, 3 13 | 14 | conv = nn.Conv2d( 15 | inC, 16 | num_deformable_groups * 2 * kH * kW, 17 | kernel_size=(kH, kW), 18 | stride=(1, 1), 19 | padding=(1, 1), 20 | bias=False).cuda() 21 | 22 | conv_offset2d = DeformConv( 23 | inC, 24 | outC, (kH, kW), 25 | stride=1, 26 | padding=1, 27 | num_deformable_groups=num_deformable_groups).cuda() 28 | 29 | inputs = Variable(torch.randn(N, inC, inH, inW).cuda(), requires_grad=True) 30 | offset = conv(inputs) 31 | #offset = Variable(torch.randn(N, num_deformable_groups * 2 * kH * kW, inH, inW).cuda(), requires_grad=True) 32 | output = conv_offset2d(inputs, offset) 33 | output.backward(output.data) 34 | print(output.size()) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2020 Dazhi Cheng 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /src/deform_conv_cuda.h: -------------------------------------------------------------------------------- 1 | int deform_conv_forward_cuda(THCudaTensor *input, 2 | THCudaTensor *weight, /*THCudaTensor * bias, */ 3 | THCudaTensor *offset, THCudaTensor *output, 4 | THCudaTensor *columns, THCudaTensor *ones, int kW, 5 | int kH, int dW, int dH, int padW, int padH, 6 | int dilationW, int dilationH, 7 | int deformable_group, int im2col_step); 8 | 9 | int deform_conv_backward_input_cuda( 10 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 11 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight, 12 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH, 13 | int dilationW, int dilationH, int deformable_group, int im2col_step); 14 | 15 | int deform_conv_backward_parameters_cuda( 16 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 17 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */ 18 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH, 19 | int padW, int padH, int dilationW, int dilationH, int deformable_group, 20 | float scale, int im2col_step); 21 | -------------------------------------------------------------------------------- /src/modulated_dcn.h: -------------------------------------------------------------------------------- 1 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight, 2 | THFloatTensor *bias, THFloatTensor *ones, 3 | THFloatTensor *offset, THFloatTensor *mask, 4 | THFloatTensor *output, THFloatTensor *columns, 5 | const int pad_h, const int pad_w, 6 | const int stride_h, const int stride_w, 7 | const int dilation_h, const int dilation_w, 8 | const int deformable_group); 9 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight, 10 | THFloatTensor *bias, THFloatTensor *ones, 11 | THFloatTensor *offset, THFloatTensor *mask, 12 | THFloatTensor *output, THFloatTensor *columns, 13 | THFloatTensor *grad_input, THFloatTensor *grad_weight, 14 | THFloatTensor *grad_bias, THFloatTensor *grad_offset, 15 | THFloatTensor *grad_mask, THFloatTensor *grad_output, 16 | int kernel_h, int kernel_w, 17 | int stride_h, int stride_w, 18 | int pad_h, int pad_w, 19 | int dilation_h, int dilation_w, 20 | int deformable_group); -------------------------------------------------------------------------------- /modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.module import Module 6 | from torch.nn.modules.utils import _pair 7 | from functions import deform_conv_function 8 | 9 | 10 | class DeformConv(Module): 11 | def __init__(self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | num_deformable_groups=1): 19 | super(DeformConv, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = _pair(stride) 24 | self.padding = _pair(padding) 25 | self.dilation = _pair(dilation) 26 | self.num_deformable_groups = num_deformable_groups 27 | 28 | self.weight = nn.Parameter( 29 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | n = self.in_channels 35 | for k in self.kernel_size: 36 | n *= 
k 37 | stdv = 1. / math.sqrt(n) 38 | self.weight.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, input, offset): 41 | return deform_conv_function(input, offset, self.weight, self.stride, 42 | self.padding, self.dilation, 43 | self.num_deformable_groups) 44 | -------------------------------------------------------------------------------- /src/modulated_dcn.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight, 6 | THFloatTensor *bias, THFloatTensor *ones, 7 | THFloatTensor *offset, THFloatTensor *mask, 8 | THFloatTensor *output, THFloatTensor *columns, 9 | const int pad_h, const int pad_w, 10 | const int stride_h, const int stride_w, 11 | const int dilation_h, const int dilation_w, 12 | const int deformable_group) 13 | { 14 | printf("only implemented in GPU"); 15 | } 16 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight, 17 | THFloatTensor *bias, THFloatTensor *ones, 18 | THFloatTensor *offset, THFloatTensor *mask, 19 | THFloatTensor *output, THFloatTensor *columns, 20 | THFloatTensor *grad_input, THFloatTensor *grad_weight, 21 | THFloatTensor *grad_bias, THFloatTensor *grad_offset, 22 | THFloatTensor *grad_mask, THFloatTensor *grad_output, 23 | int kernel_h, int kernel_w, 24 | int stride_h, int stride_w, 25 | int pad_h, int pad_w, 26 | int dilation_h, int dilation_w, 27 | int deformable_group) 28 | { 29 | printf("only implemented in GPU"); 30 | } -------------------------------------------------------------------------------- /src/deform_conv_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | template 2 | void deformable_im2col(cudaStream_t stream, const DType *data_im, 3 | const DType *data_offset, const int channels, 4 | const int height, const int width, const int ksize_h, 5 | const int ksize_w, const int pad_h, const int pad_w, 6 | const int stride_h, const int stride_w, 7 | const int dilation_h, const int dilation_w, 8 | const int parallel_imgs, 9 | const int deformable_group, DType *data_col); 10 | 11 | template 12 | void deformable_col2im(cudaStream_t stream, const DType *data_col, 13 | const DType *data_offset, const int channels, 14 | const int height, const int width, const int ksize_h, 15 | const int ksize_w, const int pad_h, const int pad_w, 16 | const int stride_h, const int stride_w, 17 | const int dilation_h, const int dilation_w, 18 | const int parallel_imgs, 19 | const int deformable_group, DType *grad_im); 20 | 21 | template 22 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col, 23 | const DType *data_im, const DType *data_offset, 24 | const int channels, const int height, 25 | const int width, const int ksize_h, 26 | const int ksize_w, const int pad_h, 27 | const int pad_w, const int stride_h, 28 | const int stride_w, const int dilation_h, 29 | const int dilation_w, const int parallel_imgs, 30 | const int deformable_group, DType *grad_offset); 31 | -------------------------------------------------------------------------------- /src/cuda/deform_psroi_pooling_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | 10 | #ifndef DCN_V2_PSROI_POOLING_CUDA 11 | #define DCN_V2_PSROI_POOLING_CUDA 12 | 13 | #ifdef __cplusplus 14 | extern "C" 15 | { 16 | #endif 17 | 18 | void DeformablePSROIPoolForward(cudaStream_t stream, 19 | const float *data, 20 | const float *bbox, 21 | const float *trans, 22 | float *out, 23 | float *top_count, 24 | const int batch, 25 | const int channels, 26 | const int height, 27 | const int width, 28 | const int num_bbox, 29 | const int channels_trans, 30 | const int no_trans, 31 | const float spatial_scale, 32 | const int output_dim, 33 | const int group_size, 34 | const int pooled_size, 35 | const int part_size, 36 | const int sample_per_part, 37 | const float trans_std); 38 | 39 | void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, 40 | const float *out_grad, 41 | const float *data, 42 | const float *bbox, 43 | const float *trans, 44 | const float *top_count, 45 | float *in_grad, 46 | float *trans_grad, 47 | const int batch, 48 | const int channels, 49 | const int height, 50 | const int width, 51 | const int num_bbox, 52 | const int channels_trans, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); 61 | 62 | #ifdef __cplusplus 63 | } 64 | #endif 65 | 66 | #endif -------------------------------------------------------------------------------- /src/modulated_dcn_cuda.h: -------------------------------------------------------------------------------- 1 | // #ifndef DCN_V2_CUDA 2 | // #define DCN_V2_CUDA 3 | 4 | // #ifdef __cplusplus 5 | // extern "C" 6 | // { 7 | // #endif 8 | 9 | void modulated_deform_conv_cuda_forward(THCudaTensor *input, THCudaTensor *weight, 10 | THCudaTensor *bias, THCudaTensor *ones, 11 | THCudaTensor *offset, THCudaTensor *mask, 12 | THCudaTensor *output, THCudaTensor *columns, 13 | int kernel_h, int kernel_w, 14 | const int stride_h, const int stride_w, 15 | const int pad_h, const int pad_w, 16 | const int dilation_h, const int dilation_w, 17 | const int deformable_group); 18 | void modulated_deform_conv_cuda_backward(THCudaTensor *input, THCudaTensor *weight, 19 | THCudaTensor *bias, THCudaTensor *ones, 20 | THCudaTensor *offset, THCudaTensor *mask, 21 | THCudaTensor *columns, 22 | THCudaTensor *grad_input, THCudaTensor *grad_weight, 23 | THCudaTensor *grad_bias, THCudaTensor *grad_offset, 24 | THCudaTensor *grad_mask, THCudaTensor *grad_output, 25 | int kernel_h, int kernel_w, 26 | int stride_h, int stride_w, 27 | int pad_h, int pad_w, 28 | int dilation_h, int dilation_w, 29 | int deformable_group); 30 | 31 | void deform_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, 32 | THCudaTensor * trans, 33 | THCudaTensor * out, THCudaTensor * top_count, 34 | const int no_trans, 35 | const float spatial_scale, 36 | const int output_dim, 37 | const int group_size, 38 | const int pooled_size, 39 | const int part_size, 40 | const int sample_per_part, 41 | const float trans_std); 42 | 43 | void deform_psroi_pooling_cuda_backward(THCudaTensor * out_grad, 44 | THCudaTensor * input, THCudaTensor * bbox, 45 | THCudaTensor * trans, THCudaTensor * top_count, 46 | THCudaTensor * input_grad, 
THCudaTensor * trans_grad, 47 | const int no_trans, 48 | const float spatial_scale, 49 | const int output_dim, 50 | const int group_size, 51 | const int pooled_size, 52 | const int part_size, 53 | const int sample_per_part, 54 | const float trans_std); 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deformable-ConvNets-V2 in PyTorch 2 | 3 | This repo is an implementation of [Deformable Convolution V2](https://arxiv.org/abs/1811.11168). 4 | Ported from the original [MXNet implementation](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). 5 | 6 | Refer to the [mmdetection branch](https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/mmdetection) of this repo for a complete detection framework. Results of DCNv2 based on the mmdetection code base can be found in the [model zoo](https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/MODEL_ZOO.md#deformable-conv-v2). Many thanks to [mmdetection](https://github.com/open-mmlab/mmdetection) for their strong and clean framework. 7 | 8 | Operators in the master branch are compatible with PyTorch v0.4.1. For operators on PyTorch v1.0.0 (implemented by [Jiarui Xu](https://github.com/xvjiarui)), please refer to the [pytorch_1.0.0 branch](https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0). 9 | 10 | Thanks to [Kai Chen](https://github.com/hellock) and other contributors from mmlab, DCNv2 is now included in the official mmdetection repo, based on the master branch of this one. It is written with the new C++ extension APIs and supports both PyTorch 0.4.1 and 1.0, with some minor speed and memory optimizations. Results and models can be found at https://github.com/open-mmlab/mmdetection/tree/master/configs/dcn. 11 | 12 | ## Build 13 | 14 | ``` 15 | sh make.sh 16 | ``` 17 | 18 | See `test.py` and `test_modulated.py` for example usage. 19 | 20 | ## Notice 21 | 22 | This repo provides the deformable conv layer that can reproduce the results in the Deformable ConvNets v2 paper. The major changes are as follows: 23 | 24 | * Better handling of cases where sampling locations fall outside of the image boundary. 25 | 26 | In the previous operator, if a sampling location was outside of the feature map boundary, its sampled value was zero, so the gradient with respect to the learnable offset was also zero. We found that such a scheme may deteriorate performance in ImageNet classification (perhaps because the feature maps are of low resolution). For object detection on COCO, both the previous and the updated operators deliver the same results. 27 | 28 | In the new operator, if a sampling location is within one pixel outside of the feature map boundary, bilinear sampling is still applied, so the gradient with respect to the learnable offset can be non-zero for such locations. This is implemented by padding the feature maps with one row/column of zeros beyond their boundaries and performing bilinear sampling on the padded feature maps. 29 | 30 | 31 | * The efficiency of processing multiple images in a mini-batch is considerably improved.
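  The slice size discussed below is exposed through the *im2col_step* argument of `deform_conv_function` (see `functions/deform_conv.py`). A rough, non-authoritative sketch of setting it explicitly — tensor names follow `test.py`, and the step value of 32 is only an example (the effective step is min(im2col_step, N) and must divide the batch size N):

  ```python
  # Hypothetical sketch: choosing the slice size S used by the deformable im2col step.
  from functions import deform_conv_function

  output = deform_conv_function(inputs, offset, conv_offset2d.weight,
                                stride=1, padding=1, dilation=1,
                                deform_groups=num_deformable_groups,
                                im2col_step=32)
  ```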
32 | 33 | Both the previous and the updated operators share the following computation pipeline (illustrated by a 3x3 deformable convolution with input data of NxCxHxW and output data of NxC'xHxW): 34 | 35 | for i in range(N/S): 36 | step 1 (slicing): slicing the input data at the batch dimension from i*S to (i+1)*S, input (NxCxHxW) -> sliced input (SxCxHxW) 37 | step 2 (deformable im2col): sliced input (SxCxHxW)+sliced offset (Sx18xHxW) -> column (Cx9xSxHxW) 38 | step 3 (MatMul&reshape): weight matrix (C'x 9C) * column (9CxSHW) -> temp sliced output (C'xSxHxW) -> sliced output (SxC'xHxW) 39 | step 4 (Merge): merge sliced output to form the whole output data (NxC'xHxW) 40 | end 41 | 42 | In the previous operator, S is fixed at 1. In the updated operator, S can be set by the *im2col_step* parameter, whose default value is min(N, 64). The updated operator is significantly faster than the existing one when the image batch size is large. 43 | 44 | ## License 45 | 46 | This repo is released under the MIT license. 47 | -------------------------------------------------------------------------------- /test_modulated.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import gradcheck 10 | 11 | from modules.modulated_dcn import ModulatedDeformConvPack 12 | from modules.modulated_dcn import DeformRoIPooling 13 | from modules.modulated_dcn import ModulatedDeformRoIPoolingPack 14 | 15 | deformable_groups = 1 16 | N, inC, inH, inW = 2, 2, 4, 4 17 | outC = 2 18 | kH, kW = 3, 3 19 | 20 | 21 | def example_dconv(): 22 | from modules.modulated_dcn import ModulatedDeformConv 23 | input = torch.randn(2, 64, 128, 128).cuda() 24 | # wrap all things (offset and mask) in DCN 25 | dcn = ModulatedDeformConvPack(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2, no_bias=True).cuda() 26 | output = dcn(input) 27 | target = output.new(*output.size()) 28 | target.data.uniform_(-0.01, 0.01) 29 | error = (target - output).mean() 30 | error.backward() 31 | print(output.shape) 32 | 33 | def example_dpooling(): 34 | from modules.modulated_dcn import ModulatedDeformRoIPoolingPack 35 | input = torch.randn(2, 32, 64, 64).cuda() 36 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 37 | x = torch.randint(256, (20, 1)).cuda().float() 38 | y = torch.randint(256, (20, 1)).cuda().float() 39 | w = torch.randint(64, (20, 1)).cuda().float() 40 | h = torch.randint(64, (20, 1)).cuda().float() 41 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 42 | offset = torch.randn(20, 2, 7, 7).cuda() 43 | input.requires_grad = True 44 | offset.requires_grad = True 45 | 46 | # normal roi_align 47 | pooling = DeformRoIPooling(spatial_scale=1.0 / 4, 48 | pooled_size=7, 49 | output_dim=32, 50 | no_trans=True, 51 | group_size=1, 52 | trans_std=0.1).cuda() 53 | 54 | # deformable pooling 55 | dpooling = DeformRoIPooling(spatial_scale=1.0 / 4, 56 | pooled_size=7, 57 | output_dim=32, 58 | no_trans=False, 59 | group_size=1, 60 | trans_std=0.1).cuda() 61 | 62 | out = pooling(input, rois, offset) 63 | dout = dpooling(input, rois, offset) 64 | print(out.shape) 65 | print(dout.shape) 66 | 67 | target_out = out.new(*out.size()) 68 | target_out.data.uniform_(-0.01, 0.01) 69 | target_dout = dout.new(*dout.size()) 70 | target_dout.data.uniform_(-0.01, 0.01) 71 | e =
(target_out - out).mean() 72 | e.backward() 73 | e = (target_dout - dout).mean() 74 | e.backward() 75 | 76 | def example_mdpooling(): 77 | from modules.modulated_dcn import ModulatedDeformRoIPoolingPack 78 | input = torch.randn(2, 32, 64, 64).cuda() 79 | input.requires_grad = True 80 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 81 | x = torch.randint(256, (20, 1)).cuda().float() 82 | y = torch.randint(256, (20, 1)).cuda().float() 83 | w = torch.randint(64, (20, 1)).cuda().float() 84 | h = torch.randint(64, (20, 1)).cuda().float() 85 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 86 | 87 | # mdformable pooling (V2) 88 | dpooling = ModulatedDeformRoIPoolingPack(spatial_scale=1.0 / 4, 89 | pooled_size=7, 90 | output_dim=32, 91 | no_trans=False, 92 | group_size=1, 93 | trans_std=0.1).cuda() 94 | 95 | for i in range(2): 96 | dout = dpooling(input, rois) 97 | target = dout.new(*dout.size()) 98 | target.data.uniform_(-0.1, 0.1) 99 | error = (target - dout).mean() 100 | error.backward() 101 | print(dout.shape) 102 | 103 | if __name__ == '__main__': 104 | 105 | example_dconv() 106 | example_dpooling() 107 | example_mdpooling() 108 | -------------------------------------------------------------------------------- /functions/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from _ext import deform_conv 6 | 7 | 8 | def deform_conv_function(input, 9 | offset, 10 | weight, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | deform_groups=1, 15 | im2col_step=64): 16 | 17 | if input is not None and input.dim() != 4: 18 | raise ValueError( 19 | "Expected 4D tensor as input, got {}D tensor instead.".format( 20 | input.dim())) 21 | 22 | f = DeformConvFunction( 23 | _pair(stride), _pair(padding), _pair(dilation), deform_groups, im2col_step) 24 | return f(input, offset, weight) 25 | 26 | 27 | class DeformConvFunction(Function): 28 | def __init__(self, stride, padding, dilation, deformable_groups=1, im2col_step=64): 29 | super(DeformConvFunction, self).__init__() 30 | self.stride = stride 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.deformable_groups = deformable_groups 34 | self.im2col_step = im2col_step 35 | 36 | def forward(self, input, offset, weight): 37 | self.save_for_backward(input, offset, weight) 38 | 39 | output = input.new(*self._output_size(input, weight)) 40 | 41 | self.bufs_ = [input.new(), input.new()] # columns, ones 42 | 43 | if not input.is_cuda: 44 | raise NotImplementedError 45 | else: 46 | if isinstance(input, torch.autograd.Variable): 47 | if not isinstance(input.data, torch.cuda.FloatTensor): 48 | raise NotImplementedError 49 | else: 50 | if not isinstance(input, torch.cuda.FloatTensor): 51 | raise NotImplementedError 52 | 53 | cur_im2col_step = min(self.im2col_step, input.shape[0]) 54 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 55 | deform_conv.deform_conv_forward_cuda( 56 | input, weight, offset, output, self.bufs_[0], self.bufs_[1], 57 | weight.size(3), weight.size(2), self.stride[1], self.stride[0], 58 | self.padding[1], self.padding[0], self.dilation[1], 59 | self.dilation[0], self.deformable_groups, cur_im2col_step) 60 | return output 61 | 62 | def backward(self, grad_output): 63 | input, offset, weight = self.saved_tensors 64 | 65 | grad_input = grad_offset = grad_weight = None 66 | 67 | if not grad_output.is_cuda: 68 | raise 
NotImplementedError 69 | else: 70 | if isinstance(grad_output, torch.autograd.Variable): 71 | if not isinstance(grad_output.data, torch.cuda.FloatTensor): 72 | raise NotImplementedError 73 | else: 74 | if not isinstance(grad_output, torch.cuda.FloatTensor): 75 | raise NotImplementedError 76 | 77 | cur_im2col_step = min(self.im2col_step, input.shape[0]) 78 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 79 | 80 | if self.needs_input_grad[0] or self.needs_input_grad[1]: 81 | grad_input = input.new(*input.size()).zero_() 82 | grad_offset = offset.new(*offset.size()).zero_() 83 | deform_conv.deform_conv_backward_input_cuda( 84 | input, offset, grad_output, grad_input, 85 | grad_offset, weight, self.bufs_[0], weight.size(3), 86 | weight.size(2), self.stride[1], self.stride[0], 87 | self.padding[1], self.padding[0], self.dilation[1], 88 | self.dilation[0], self.deformable_groups, cur_im2col_step) 89 | 90 | 91 | if self.needs_input_grad[2]: 92 | grad_weight = weight.new(*weight.size()).zero_() 93 | deform_conv.deform_conv_backward_parameters_cuda( 94 | input, offset, grad_output, 95 | grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3), 96 | weight.size(2), self.stride[1], self.stride[0], 97 | self.padding[1], self.padding[0], self.dilation[1], 98 | self.dilation[0], self.deformable_groups, 1, cur_im2col_step) 99 | 100 | return grad_input, grad_offset, grad_weight 101 | 102 | def _output_size(self, input, weight): 103 | channels = weight.size(0) 104 | 105 | output_size = (input.size(0), channels) 106 | for d in range(input.dim() - 2): 107 | in_size = input.size(d + 2) 108 | pad = self.padding[d] 109 | kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1 110 | stride = self.stride[d] 111 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) 112 | if not all(map(lambda s: s > 0, output_size)): 113 | raise ValueError( 114 | "convolution input is too small (output would be {})".format( 115 | 'x'.join(map(str, output_size)))) 116 | return output_size 117 | -------------------------------------------------------------------------------- /src/cuda/modulated_deform_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 3 | * 4 | * COPYRIGHT 5 | * 6 | * All contributions by the University of California: 7 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 8 | * All rights reserved. 9 | * 10 | * All other contributions: 11 | * Copyright (c) 2014-2017, the respective contributors 12 | * All rights reserved. 13 | * 14 | * Caffe uses a shared copyright model: each contributor holds copyright over 15 | * their contributions to Caffe. The project versioning records all such 16 | * contribution and copyright details. If a contributor wants to further mark 17 | * their specific copyright on a particular contribution, they should indicate 18 | * their copyright solely in the commit message of the change when it is 19 | * committed. 20 | * 21 | * LICENSE 22 | * 23 | * Redistribution and use in source and binary forms, with or without 24 | * modification, are permitted provided that the following conditions are met: 25 | * 26 | * 1. Redistributions of source code must retain the above copyright notice, this 27 | * list of conditions and the following disclaimer. 28 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 29 | * this list of conditions and the following disclaimer in the documentation 30 | * and/or other materials provided with the distribution. 31 | * 32 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 33 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 34 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 35 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 36 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 37 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 38 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 39 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 40 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 41 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 42 | * 43 | * CONTRIBUTION AGREEMENT 44 | * 45 | * By contributing to the BVLC/caffe repository through pull-request, comment, 46 | * or otherwise, the contributor releases their content to the 47 | * license and copyright terms herein. 48 | * 49 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 50 | * 51 | * Copyright (c) 2018 Microsoft 52 | * Licensed under The MIT License [see LICENSE for details] 53 | * \file modulated_deformable_im2col.h 54 | * \brief Function definitions of converting an image to 55 | * column matrix based on kernel, padding, dilation, and offset. 56 | * These functions are mainly used in deformable convolution operators. 57 | * \ref: https://arxiv.org/abs/1811.11168 58 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 59 | */ 60 | 61 | /***************** Adapted by Charles Shang *********************/ 62 | 63 | #ifndef DCN_V2_IM2COL_CUDA 64 | #define DCN_V2_IM2COL_CUDA 65 | 66 | #ifdef __cplusplus 67 | extern "C" 68 | { 69 | #endif 70 | 71 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 72 | const float *data_im, const float *data_offset, const float *data_mask, 73 | const int batch_size, const int channels, const int height_im, const int width_im, 74 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 75 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 76 | const int dilation_h, const int dilation_w, 77 | const int deformable_group, float *data_col); 78 | 79 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 80 | const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 88 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 89 | const int batch_size, const int channels, const int height_im, const int width_im, 90 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 91 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 92 | const int dilation_h, const int dilation_w, 93 | 
const int deformable_group, 94 | float *grad_offset, float *grad_mask); 95 | 96 | #ifdef __cplusplus 97 | } 98 | #endif 99 | 100 | #endif -------------------------------------------------------------------------------- /functions/modulated_dcn_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | from torch.autograd import Function 8 | 9 | from _ext import modulated_dcn as _backend 10 | 11 | 12 | class ModulatedDeformConvFunction(Function): 13 | 14 | def __init__(self, stride, padding, dilation=1, deformable_groups=1): 15 | super(ModulatedDeformConvFunction, self).__init__() 16 | self.stride = stride 17 | self.padding = padding 18 | self.dilation = dilation 19 | self.deformable_groups = deformable_groups 20 | 21 | def forward(self, input, offset, mask, weight, bias): 22 | if not input.is_cuda: 23 | raise NotImplementedError 24 | if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: 25 | self.save_for_backward(input, offset, mask, weight, bias) 26 | output = input.new(*self._infer_shape(input, weight)) 27 | self._bufs = [input.new(), input.new()] 28 | _backend.modulated_deform_conv_cuda_forward(input, weight, 29 | bias, self._bufs[0], 30 | offset, mask, 31 | output, self._bufs[1], 32 | weight.shape[2], weight.shape[3], 33 | self.stride, self.stride, 34 | self.padding, self.padding, 35 | self.dilation, self.dilation, 36 | self.deformable_groups) 37 | return output 38 | 39 | def backward(self, grad_output): 40 | if not grad_output.is_cuda: 41 | raise NotImplementedError 42 | input, offset, mask, weight, bias = self.saved_tensors 43 | grad_input = input.new(*input.size()).zero_() 44 | grad_offset = offset.new(*offset.size()).zero_() 45 | grad_mask = mask.new(*mask.size()).zero_() 46 | grad_weight = weight.new(*weight.size()).zero_() 47 | grad_bias = bias.new(*bias.size()).zero_() 48 | _backend.modulated_deform_conv_cuda_backward(input, weight, 49 | bias, self._bufs[0], 50 | offset, mask, 51 | self._bufs[1], 52 | grad_input, grad_weight, 53 | grad_bias, grad_offset, 54 | grad_mask, grad_output, 55 | weight.shape[2], weight.shape[3], 56 | self.stride, self.stride, 57 | self.padding, self.padding, 58 | self.dilation, self.dilation, 59 | self.deformable_groups) 60 | 61 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias 62 | 63 | def _infer_shape(self, input, weight): 64 | n = input.size(0) 65 | channels_out = weight.size(0) 66 | height, width = input.shape[2:4] 67 | kernel_h, kernel_w = weight.shape[2:4] 68 | height_out = (height + 2 * self.padding - 69 | (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 70 | width_out = (width + 2 * self.padding - (self.dilation * 71 | (kernel_w - 1) + 1)) // self.stride + 1 72 | return (n, channels_out, height_out, width_out) 73 | 74 | 75 | class DeformRoIPoolingFunction(Function): 76 | 77 | def __init__(self, 78 | spatial_scale, 79 | pooled_size, 80 | output_dim, 81 | no_trans, 82 | group_size=1, 83 | part_size=None, 84 | sample_per_part=4, 85 | trans_std=.0): 86 | super(DeformRoIPoolingFunction, self).__init__() 87 | self.spatial_scale = spatial_scale 88 | self.pooled_size = pooled_size 89 | self.output_dim = output_dim 90 | self.no_trans = no_trans 91 | self.group_size = group_size 92 | self.part_size = pooled_size if part_size is None else part_size 93 | self.sample_per_part = 
sample_per_part 94 | self.trans_std = trans_std 95 | 96 | assert self.trans_std >= 0.0 and self.trans_std <= 1.0 97 | 98 | def forward(self, data, rois, offset): 99 | if not data.is_cuda: 100 | raise NotImplementedError 101 | 102 | output = data.new(*self._infer_shape(data, rois)) 103 | output_count = data.new(*self._infer_shape(data, rois)) 104 | _backend.deform_psroi_pooling_cuda_forward(data, rois, offset, 105 | output, output_count, 106 | self.no_trans, self.spatial_scale, 107 | self.output_dim, self.group_size, 108 | self.pooled_size, self.part_size, 109 | self.sample_per_part, self.trans_std) 110 | 111 | # if data.requires_grad or rois.requires_grad or offset.requires_grad: 112 | # self.save_for_backward(data, rois, offset, output_count) 113 | self.data = data 114 | self.rois = rois 115 | self.offset = offset 116 | self.output_count = output_count 117 | 118 | return output 119 | 120 | def backward(self, grad_output): 121 | if not grad_output.is_cuda: 122 | raise NotImplementedError 123 | 124 | # data, rois, offset, output_count = self.saved_tensors 125 | data = self.data 126 | rois = self.rois 127 | offset = self.offset 128 | output_count = self.output_count 129 | grad_input = data.new(*data.size()).zero_() 130 | grad_offset = offset.new(*offset.size()).zero_() 131 | 132 | _backend.deform_psroi_pooling_cuda_backward(grad_output, 133 | data, 134 | rois, 135 | offset, 136 | output_count, 137 | grad_input, 138 | grad_offset, 139 | self.no_trans, 140 | self.spatial_scale, 141 | self.output_dim, 142 | self.group_size, 143 | self.pooled_size, 144 | self.part_size, 145 | self.sample_per_part, 146 | self.trans_std) 147 | return grad_input, torch.zeros(rois.shape).cuda(), grad_offset 148 | 149 | def _infer_shape(self, data, rois): 150 | # _, c, h, w = data.shape[:4] 151 | c = data.shape[1] 152 | n = rois.shape[0] 153 | return (n, self.output_dim, self.pooled_size, self.pooled_size) 154 | -------------------------------------------------------------------------------- /modules/modulated_dcn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import math 8 | from torch import nn 9 | from torch.nn.modules.utils import _pair 10 | 11 | from functions.modulated_dcn_func import ModulatedDeformConvFunction 12 | from functions.modulated_dcn_func import DeformRoIPoolingFunction 13 | 14 | class ModulatedDeformConv(nn.Module): 15 | 16 | def __init__(self, in_channels, out_channels, 17 | kernel_size, stride, padding, dilation=1, deformable_groups=1, no_bias=True): 18 | super(ModulatedDeformConv, self).__init__() 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.kernel_size = _pair(kernel_size) 22 | self.stride = stride 23 | self.padding = padding 24 | self.dilation = dilation 25 | self.deformable_groups = deformable_groups 26 | self.no_bias = no_bias 27 | 28 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 29 | self.bias = nn.Parameter(torch.zeros(out_channels)) 30 | self.reset_parameters() 31 | if self.no_bias: 32 | self.bias.requires_grad = False 33 | 34 | def reset_parameters(self): 35 | n = self.in_channels 36 | for k in self.kernel_size: 37 | n *= k 38 | stdv = 1. 
/ math.sqrt(n) 39 | self.weight.data.uniform_(-stdv, stdv) 40 | self.bias.data.zero_() 41 | 42 | def forward(self, input, offset, mask): 43 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups) 44 | return func(input, offset, mask, self.weight, self.bias) 45 | 46 | 47 | class ModulatedDeformConvPack(ModulatedDeformConv): 48 | 49 | def __init__(self, in_channels, out_channels, 50 | kernel_size, stride, padding, 51 | dilation=1, deformable_groups=1, no_bias=False): 52 | super(ModulatedDeformConvPack, self).__init__(in_channels, out_channels, 53 | kernel_size, stride, padding, dilation, deformable_groups, no_bias) 54 | 55 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 56 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 57 | kernel_size=self.kernel_size, 58 | stride=(self.stride, self.stride), 59 | padding=(self.padding, self.padding), 60 | bias=True) 61 | self.init_offset() 62 | 63 | def init_offset(self): 64 | self.conv_offset_mask.weight.data.zero_() 65 | self.conv_offset_mask.bias.data.zero_() 66 | 67 | def forward(self, input): 68 | out = self.conv_offset_mask(input) 69 | o1, o2, mask = torch.chunk(out, 3, dim=1) 70 | offset = torch.cat((o1, o2), dim=1) 71 | mask = torch.sigmoid(mask) 72 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups) 73 | return func(input, offset, mask, self.weight, self.bias) 74 | 75 | 76 | class DeformRoIPooling(nn.Module): 77 | 78 | def __init__(self, 79 | spatial_scale, 80 | pooled_size, 81 | output_dim, 82 | no_trans, 83 | group_size=1, 84 | part_size=None, 85 | sample_per_part=4, 86 | trans_std=.0): 87 | super(DeformRoIPooling, self).__init__() 88 | self.spatial_scale = spatial_scale 89 | self.pooled_size = pooled_size 90 | self.output_dim = output_dim 91 | self.no_trans = no_trans 92 | self.group_size = group_size 93 | self.part_size = pooled_size if part_size is None else part_size 94 | self.sample_per_part = sample_per_part 95 | self.trans_std = trans_std 96 | self.func = DeformRoIPoolingFunction(self.spatial_scale, 97 | self.pooled_size, 98 | self.output_dim, 99 | self.no_trans, 100 | self.group_size, 101 | self.part_size, 102 | self.sample_per_part, 103 | self.trans_std) 104 | 105 | def forward(self, data, rois, offset): 106 | 107 | if self.no_trans: 108 | offset = data.new() 109 | return self.func(data, rois, offset) 110 | 111 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling): 112 | 113 | def __init__(self, 114 | spatial_scale, 115 | pooled_size, 116 | output_dim, 117 | no_trans, 118 | group_size=1, 119 | part_size=None, 120 | sample_per_part=4, 121 | trans_std=.0, 122 | deform_fc_dim=1024): 123 | super(ModulatedDeformRoIPoolingPack, self).__init__(spatial_scale, 124 | pooled_size, 125 | output_dim, 126 | no_trans, 127 | group_size, 128 | part_size, 129 | sample_per_part, 130 | trans_std) 131 | 132 | self.deform_fc_dim = deform_fc_dim 133 | 134 | if not no_trans: 135 | self.func_offset = DeformRoIPoolingFunction(self.spatial_scale, 136 | self.pooled_size, 137 | self.output_dim, 138 | True, 139 | self.group_size, 140 | self.part_size, 141 | self.sample_per_part, 142 | self.trans_std) 143 | self.offset_fc = nn.Sequential( 144 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), 145 | nn.ReLU(inplace=True), 146 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), 147 | nn.ReLU(inplace=True), 148 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 2) 149 | ) 150 | 
self.offset_fc[4].weight.data.zero_() 151 | self.offset_fc[4].bias.data.zero_() 152 | self.mask_fc = nn.Sequential( 153 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), 154 | nn.ReLU(inplace=True), 155 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 1), 156 | nn.Sigmoid() 157 | ) 158 | self.mask_fc[2].weight.data.zero_() 159 | self.mask_fc[2].bias.data.zero_() 160 | 161 | def forward(self, data, rois): 162 | if self.no_trans: 163 | offset = data.new() 164 | else: 165 | n = rois.shape[0] 166 | offset = data.new() 167 | x = self.func_offset(data, rois, offset) 168 | offset = self.offset_fc(x.view(n, -1)) 169 | offset = offset.view(n, 2, self.pooled_size, self.pooled_size) 170 | mask = self.mask_fc(x.view(n, -1)) 171 | mask = mask.view(n, 1, self.pooled_size, self.pooled_size) 172 | feat = self.func(data, rois, offset) * mask 173 | return feat 174 | return self.func(data, rois, offset) 175 | -------------------------------------------------------------------------------- /src/cuda/deform_psroi_pooling_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | #include "deform_psroi_pooling_cuda.h" 10 | #include 11 | #include 12 | #include 13 | 14 | #define CUDA_KERNEL_LOOP(i, n) \ 15 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 16 | i < (n); \ 17 | i += blockDim.x * gridDim.x) 18 | 19 | const int CUDA_NUM_THREADS = 1024; 20 | inline int GET_BLOCKS(const int N) 21 | { 22 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 23 | } 24 | 25 | __device__ float bilinear_interp( 26 | const float *data, 27 | const float x, 28 | const float y, 29 | const int width, 30 | const int height) 31 | { 32 | int x1 = floor(x); 33 | int x2 = ceil(x); 34 | int y1 = floor(y); 35 | int y2 = ceil(y); 36 | float dist_x = (float)(x - x1); 37 | float dist_y = (float)(y - y1); 38 | float value11 = data[y1 * width + x1]; 39 | float value12 = data[y2 * width + x1]; 40 | float value21 = data[y1 * width + x2]; 41 | float value22 = data[y2 * width + x2]; 42 | float value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; 43 | return value; 44 | } 45 | 46 | __global__ void DeformablePSROIPoolForwardKernel( 47 | const int count, 48 | const float *bottom_data, 49 | const float spatial_scale, 50 | const int channels, 51 | const int height, const int width, 52 | const int pooled_height, const int pooled_width, 53 | const float *bottom_rois, const float *bottom_trans, 54 | const int no_trans, 55 | const float trans_std, 56 | const int sample_per_part, 57 | const int output_dim, 58 | const int group_size, 59 | const int part_size, 60 | const int num_classes, 61 | const int channels_each_class, 62 | float *top_data, 63 | float *top_count) 64 | { 65 | CUDA_KERNEL_LOOP(index, count) 66 | { 67 | // The output is in order (n, ctop, ph, pw) 68 | int pw = index % pooled_width; 69 | int ph = (index / pooled_width) % pooled_height; 70 | int ctop = (index / pooled_width / pooled_height) % output_dim; 71 | int n = index / pooled_width / pooled_height / output_dim; 72 | 73 | // [start, end) interval for spatial sampling 74 | const float *offset_bottom_rois = bottom_rois + 
n * 5; 75 | int roi_batch_ind = offset_bottom_rois[0]; 76 | float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 77 | float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 78 | float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 79 | float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; 80 | 81 | // Force too small ROIs to be 1x1 82 | float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 83 | float roi_height = max(roi_end_h - roi_start_h, 0.1); 84 | 85 | // Compute w and h at bottom 86 | float bin_size_h = roi_height / (float)(pooled_height); 87 | float bin_size_w = roi_width / (float)(pooled_width); 88 | 89 | float sub_bin_size_h = bin_size_h / (float)(sample_per_part); 90 | float sub_bin_size_w = bin_size_w / (float)(sample_per_part); 91 | 92 | int part_h = floor((float)(ph) / pooled_height * part_size); 93 | int part_w = floor((float)(pw) / pooled_width * part_size); 94 | int class_id = ctop / channels_each_class; 95 | float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 96 | float trans_y = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 97 | 98 | float wstart = (float)(pw)*bin_size_w + roi_start_w; 99 | wstart += trans_x * roi_width; 100 | float hstart = (float)(ph)*bin_size_h + roi_start_h; 101 | hstart += trans_y * roi_height; 102 | 103 | float sum = 0; 104 | int count = 0; 105 | int gw = floor((float)(pw)*group_size / pooled_width); 106 | int gh = floor((float)(ph)*group_size / pooled_height); 107 | gw = min(max(gw, 0), group_size - 1); 108 | gh = min(max(gh, 0), group_size - 1); 109 | 110 | const float *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; 111 | for (int ih = 0; ih < sample_per_part; ih++) 112 | { 113 | for (int iw = 0; iw < sample_per_part; iw++) 114 | { 115 | float w = wstart + iw * sub_bin_size_w; 116 | float h = hstart + ih * sub_bin_size_h; 117 | // bilinear interpolation 118 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 119 | { 120 | continue; 121 | } 122 | w = min(max(w, 0.), width - 1.); 123 | h = min(max(h, 0.), height - 1.); 124 | int c = (ctop * group_size + gh) * group_size + gw; 125 | float val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); 126 | sum += val; 127 | count++; 128 | } 129 | } 130 | top_data[index] = count == 0 ? 
(float)(0) : sum / count; 131 | top_count[index] = count; 132 | } 133 | } 134 | 135 | __global__ void DeformablePSROIPoolBackwardAccKernel( 136 | const int count, 137 | const float *top_diff, 138 | const float *top_count, 139 | const int num_rois, 140 | const float spatial_scale, 141 | const int channels, 142 | const int height, const int width, 143 | const int pooled_height, const int pooled_width, 144 | const int output_dim, 145 | float *bottom_data_diff, float *bottom_trans_diff, 146 | const float *bottom_data, 147 | const float *bottom_rois, 148 | const float *bottom_trans, 149 | const int no_trans, 150 | const float trans_std, 151 | const int sample_per_part, 152 | const int group_size, 153 | const int part_size, 154 | const int num_classes, 155 | const int channels_each_class) 156 | { 157 | CUDA_KERNEL_LOOP(index, count) 158 | { 159 | // The output is in order (n, ctop, ph, pw) 160 | int pw = index % pooled_width; 161 | int ph = (index / pooled_width) % pooled_height; 162 | int ctop = (index / pooled_width / pooled_height) % output_dim; 163 | int n = index / pooled_width / pooled_height / output_dim; 164 | 165 | // [start, end) interval for spatial sampling 166 | const float *offset_bottom_rois = bottom_rois + n * 5; 167 | int roi_batch_ind = offset_bottom_rois[0]; 168 | float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 169 | float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 170 | float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 171 | float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; 172 | 173 | // Force too small ROIs to be 1x1 174 | float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 175 | float roi_height = max(roi_end_h - roi_start_h, 0.1); 176 | 177 | // Compute w and h at bottom 178 | float bin_size_h = roi_height / (float)(pooled_height); 179 | float bin_size_w = roi_width / (float)(pooled_width); 180 | 181 | float sub_bin_size_h = bin_size_h / (float)(sample_per_part); 182 | float sub_bin_size_w = bin_size_w / (float)(sample_per_part); 183 | 184 | int part_h = floor((float)(ph) / pooled_height * part_size); 185 | int part_w = floor((float)(pw) / pooled_width * part_size); 186 | int class_id = ctop / channels_each_class; 187 | float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 188 | float trans_y = no_trans ? 
(float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 189 | 190 | float wstart = (float)(pw)*bin_size_w + roi_start_w; 191 | wstart += trans_x * roi_width; 192 | float hstart = (float)(ph)*bin_size_h + roi_start_h; 193 | hstart += trans_y * roi_height; 194 | 195 | if (top_count[index] <= 0) 196 | { 197 | continue; 198 | } 199 | float diff_val = top_diff[index] / top_count[index]; 200 | const float *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; 201 | float *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; 202 | int gw = floor((float)(pw)*group_size / pooled_width); 203 | int gh = floor((float)(ph)*group_size / pooled_height); 204 | gw = min(max(gw, 0), group_size - 1); 205 | gh = min(max(gh, 0), group_size - 1); 206 | 207 | for (int ih = 0; ih < sample_per_part; ih++) 208 | { 209 | for (int iw = 0; iw < sample_per_part; iw++) 210 | { 211 | float w = wstart + iw * sub_bin_size_w; 212 | float h = hstart + ih * sub_bin_size_h; 213 | // bilinear interpolation 214 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 215 | { 216 | continue; 217 | } 218 | w = min(max(w, 0.), width - 1.); 219 | h = min(max(h, 0.), height - 1.); 220 | int c = (ctop * group_size + gh) * group_size + gw; 221 | // backward on feature 222 | int x0 = floor(w); 223 | int x1 = ceil(w); 224 | int y0 = floor(h); 225 | int y1 = ceil(h); 226 | float dist_x = w - x0, dist_y = h - y0; 227 | float q00 = (1 - dist_x) * (1 - dist_y); 228 | float q01 = (1 - dist_x) * dist_y; 229 | float q10 = dist_x * (1 - dist_y); 230 | float q11 = dist_x * dist_y; 231 | int bottom_index_base = c * height * width; 232 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); 233 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); 234 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); 235 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); 236 | 237 | if (no_trans) 238 | { 239 | continue; 240 | } 241 | float U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; 242 | float U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; 243 | float U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; 244 | float U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; 245 | float diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; 246 | diff_x *= roi_width; 247 | float diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; 248 | diff_y *= roi_height; 249 | 250 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); 251 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); 252 | } 253 | } 254 | } 255 | } 256 | 257 | void DeformablePSROIPoolForward(cudaStream_t stream, 258 | const float *data, 259 | const float *bbox, 260 | const float *trans, 261 | float *out, 262 | float *top_count, 263 | const int batch, 264 | const int channels, 265 | const int height, 266 | const int width, 267 | const int num_bbox, 268 | const int channels_trans, 269 | const int no_trans, 270 | const float spatial_scale, 271 | const int output_dim, 272 | const int group_size, 273 | const int pooled_size, 274 | const 
int part_size, 275 | const int sample_per_part, 276 | const float trans_std) 277 | { 278 | 279 | const float *bottom_data = data; 280 | const float *bottom_rois = bbox; 281 | const float *bottom_trans = no_trans ? NULL : trans; 282 | float *top_data = out; 283 | float *top_count_data = top_count; 284 | 285 | const int pooled_height = pooled_size; 286 | const int pooled_width = pooled_size; 287 | const int count = num_bbox * output_dim * pooled_height * pooled_width; 288 | const int num_classes = no_trans ? 1 : channels_trans / 2; 289 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; 290 | 291 | DeformablePSROIPoolForwardKernel<<>>( 292 | count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width, 293 | bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim, 294 | group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); 295 | 296 | cudaError_t err = cudaGetLastError(); 297 | if (err != cudaSuccess) 298 | { 299 | printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); 300 | } 301 | } 302 | 303 | void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, 304 | const float *out_grad, 305 | const float *data, 306 | const float *bbox, 307 | const float *trans, 308 | const float *top_count, 309 | float *in_grad, 310 | float *trans_grad, 311 | const int batch, 312 | const int channels, 313 | const int height, 314 | const int width, 315 | const int num_bbox, 316 | const int channels_trans, 317 | const int no_trans, 318 | const float spatial_scale, 319 | const int output_dim, 320 | const int group_size, 321 | const int pooled_size, 322 | const int part_size, 323 | const int sample_per_part, 324 | const float trans_std) 325 | { 326 | // LOG(INFO) << "DeformablePSROIPoolBackward"; 327 | const float *top_diff = out_grad; 328 | const float *bottom_data = data; 329 | const float *bottom_rois = bbox; 330 | const float *bottom_trans = no_trans ? NULL : trans; 331 | float *bottom_data_diff = in_grad; 332 | float *bottom_trans_diff = no_trans ? NULL : trans_grad; 333 | const float *top_count_data = top_count; 334 | 335 | const int num_rois = num_bbox; 336 | const int pooled_height = pooled_size; 337 | const int pooled_width = pooled_size; 338 | const int count = num_bbox * output_dim * pooled_height * pooled_width; 339 | const int num_classes = no_trans ? 1 : channels_trans / 2; 340 | const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; 341 | 342 | DeformablePSROIPoolBackwardAccKernel<<>>( 343 | count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width, 344 | pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, 345 | bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, 346 | group_size, part_size, num_classes, channels_each_class); 347 | 348 | cudaError_t err = cudaGetLastError(); 349 | if (err != cudaSuccess) 350 | { 351 | printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); 352 | } 353 | } -------------------------------------------------------------------------------- /src/modulated_dcn_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/modulated_deform_im2col_cuda.h" 3 | #include "cuda/deform_psroi_pooling_cuda.h" 4 | 5 | extern THCState *state; 6 | 7 | // author: Charles Shang 8 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 9 | 10 | void modulated_deform_conv_cuda_forward(THCudaTensor *input, THCudaTensor *weight, 11 | THCudaTensor *bias, THCudaTensor *ones, 12 | THCudaTensor *offset, THCudaTensor *mask, 13 | THCudaTensor *output, THCudaTensor *columns, 14 | int kernel_h, int kernel_w, 15 | const int stride_h, const int stride_w, 16 | const int pad_h, const int pad_w, 17 | const int dilation_h, const int dilation_w, 18 | const int deformable_group) 19 | { 20 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); 21 | THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); 22 | THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); 23 | 24 | const int batch = THCudaTensor_size(state, input, 0); 25 | const int channels = THCudaTensor_size(state, input, 1); 26 | const int height = THCudaTensor_size(state, input, 2); 27 | const int width = THCudaTensor_size(state, input, 3); 28 | 29 | const int channels_out = THCudaTensor_size(state, weight, 0); 30 | const int channels_kernel = THCudaTensor_size(state, weight, 1); 31 | const int kernel_h_ = THCudaTensor_size(state, weight, 2); 32 | const int kernel_w_ = THCudaTensor_size(state, weight, 3); 33 | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) 34 | THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", 35 | kernel_h_, kernel_w, kernel_h_, kernel_w_); 36 | if (channels != channels_kernel) 37 | THError("Input shape and kernel channels wont match: (%d vs %d).", 38 | channels, channels_kernel); 39 | 40 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 41 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 42 | 43 | if (THCudaTensor_nDimension(state, ones) != 2 || 44 | THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) 45 | { 46 | // Resize plane and fill with ones... 
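/* The `ones` plane (height_out x width_out, filled with 1) exists only so the
   bias can be broadcast with a single rank-1 GEMM just below:
   output_n[c][p] = bias[c] * ones[p] for every output pixel p, i.e. roughly

       for (int c = 0; c < channels_out; ++c)
         for (int p = 0; p < height_out * width_out; ++p)
           output_n[c * height_out * width_out + p] = bias[c];

   After that, modulated_deformable_im2col_cuda unrolls the offset- and
   mask-modulated input into `columns`, and one SGEMM with the weight matrix
   accumulates the convolution result on top of the bias. */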
47 | THCudaTensor_resize2d(state, ones, height_out, width_out); 48 | THCudaTensor_fill(state, ones, 1); 49 | } 50 | 51 | // resize output 52 | THCudaTensor_resize4d(state, output, batch, channels_out, height_out, width_out); 53 | // resize temporary columns 54 | THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); 55 | 56 | THCudaTensor *input_n = THCudaTensor_new(state); 57 | THCudaTensor *offset_n = THCudaTensor_new(state); 58 | THCudaTensor *mask_n = THCudaTensor_new(state); 59 | THCudaTensor *output_n = THCudaTensor_new(state); 60 | 61 | for (int b = 0; b < batch; b++) 62 | { 63 | THCudaTensor_select(state, input_n, input, 0, b); 64 | THCudaTensor_select(state, offset_n, offset, 0, b); 65 | THCudaTensor_select(state, mask_n, mask, 0, b); 66 | THCudaTensor_select(state, output_n, output, 0, b); 67 | 68 | // Do Bias first: 69 | // M,N,K are dims of matrix A and B 70 | // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) 71 | // (N x 1) (1 x M) 72 | long m_ = channels_out; 73 | long n_ = height_out * width_out; 74 | long k_ = 1; 75 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 76 | THCudaTensor_data(state, ones), k_, 77 | THCudaTensor_data(state, bias), k_, 0.0f, 78 | THCudaTensor_data(state, output_n), n_); 79 | 80 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 81 | THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), 82 | THCudaTensor_data(state, mask_n), 83 | 1, channels, height, width, 84 | height_out, width_out, kernel_h, kernel_w, 85 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 86 | deformable_group, THCudaTensor_data(state, columns)); 87 | 88 | //(k * m) x (m * n) 89 | // Y = WC 90 | long m = channels_out; 91 | long n = height_out * width_out; 92 | long k = channels * kernel_h * kernel_w; 93 | THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, 94 | THCudaTensor_data(state, columns), n, 95 | THCudaTensor_data(state, weight), k, 1.0f, 96 | THCudaTensor_data(state, output_n), n); 97 | } 98 | THCudaTensor_free(state, input_n); 99 | THCudaTensor_free(state, offset_n); 100 | THCudaTensor_free(state, mask_n); 101 | THCudaTensor_free(state, output_n); 102 | } 103 | 104 | void modulated_deform_conv_cuda_backward(THCudaTensor *input, THCudaTensor *weight, 105 | THCudaTensor *bias, THCudaTensor *ones, 106 | THCudaTensor *offset, THCudaTensor *mask, 107 | THCudaTensor *columns, 108 | THCudaTensor *grad_input, THCudaTensor *grad_weight, 109 | THCudaTensor *grad_bias, THCudaTensor *grad_offset, 110 | THCudaTensor *grad_mask, THCudaTensor *grad_output, 111 | int kernel_h, int kernel_w, 112 | int stride_h, int stride_w, 113 | int pad_h, int pad_w, 114 | int dilation_h, int dilation_w, 115 | int deformable_group) 116 | { 117 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, 118 | grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); 119 | THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); 120 | THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); 121 | 122 | const int batch = THCudaTensor_size(state, input, 0); 123 | const int channels = THCudaTensor_size(state, input, 1); 124 | const int height = THCudaTensor_size(state, input, 2); 125 | const int width = THCudaTensor_size(state, input, 3); 126 | 127 | const int channels_out = THCudaTensor_size(state, weight, 0); 128 | const int channels_kernel = THCudaTensor_size(state, 
weight, 1); 129 | const int kernel_h_ = THCudaTensor_size(state, weight, 2); 130 | const int kernel_w_ = THCudaTensor_size(state, weight, 3); 131 | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) 132 | THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", 133 | kernel_h_, kernel_w, kernel_h_, kernel_w_); 134 | if (channels != channels_kernel) 135 | THError("Input shape and kernel channels wont match: (%d vs %d).", 136 | channels, channels_kernel); 137 | 138 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 139 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 140 | 141 | if (THCudaTensor_nDimension(state, ones) != 2 || 142 | THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) 143 | { 144 | // Resize plane and fill with ones... 145 | THCudaTensor_resize2d(state, ones, height_out, width_out); 146 | THCudaTensor_fill(state, ones, 1.0f); 147 | } 148 | 149 | THCudaTensor_resize4d(state, grad_input, batch, channels, height, width); 150 | THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); 151 | 152 | THCudaTensor *input_n = THCudaTensor_new(state); 153 | THCudaTensor *offset_n = THCudaTensor_new(state); 154 | THCudaTensor *mask_n = THCudaTensor_new(state); 155 | 156 | THCudaTensor *grad_output_n = THCudaTensor_new(state); 157 | THCudaTensor *grad_input_n = THCudaTensor_new(state); 158 | THCudaTensor *grad_offset_n = THCudaTensor_new(state); 159 | THCudaTensor *grad_mask_n = THCudaTensor_new(state); 160 | 161 | for (int b = 0; b < batch; b++) 162 | { 163 | THCudaTensor_select(state, input_n, input, 0, b); 164 | THCudaTensor_select(state, offset_n, offset, 0, b); 165 | THCudaTensor_select(state, mask_n, mask, 0, b); 166 | THCudaTensor_select(state, grad_output_n, grad_output, 0, b); 167 | THCudaTensor_select(state, grad_input_n, grad_input, 0, b); 168 | THCudaTensor_select(state, grad_offset_n, grad_offset, 0, b); 169 | THCudaTensor_select(state, grad_mask_n, grad_mask, 0, b); 170 | 171 | long m = channels * kernel_h * kernel_w; 172 | long n = height_out * width_out; 173 | long k = channels_out; 174 | 175 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 176 | THCudaTensor_data(state, grad_output_n), n, 177 | THCudaTensor_data(state, weight), m, 0.0f, 178 | THCudaTensor_data(state, columns), n); 179 | 180 | // gradient w.r.t. input coordinate data 181 | modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), 182 | THCudaTensor_data(state, columns), 183 | THCudaTensor_data(state, input_n), 184 | THCudaTensor_data(state, offset_n), 185 | THCudaTensor_data(state, mask_n), 186 | 1, channels, height, width, 187 | height_out, width_out, kernel_h, kernel_w, 188 | pad_h, pad_w, stride_h, stride_w, 189 | dilation_h, dilation_w, deformable_group, 190 | THCudaTensor_data(state, grad_offset_n), 191 | THCudaTensor_data(state, grad_mask_n)); 192 | // gradient w.r.t. input data 193 | modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), 194 | THCudaTensor_data(state, columns), 195 | THCudaTensor_data(state, offset_n), 196 | THCudaTensor_data(state, mask_n), 197 | 1, channels, height, width, 198 | height_out, width_out, kernel_h, kernel_w, 199 | pad_h, pad_w, stride_h, stride_w, 200 | dilation_h, dilation_w, deformable_group, 201 | THCudaTensor_data(state, grad_input_n)); 202 | 203 | // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group 204 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 205 | THCudaTensor_data(state, input_n), 206 | THCudaTensor_data(state, offset_n), 207 | THCudaTensor_data(state, mask_n), 208 | 1, channels, height, width, 209 | height_out, width_out, kernel_h, kernel_w, 210 | pad_h, pad_w, stride_h, stride_w, 211 | dilation_h, dilation_w, deformable_group, 212 | THCudaTensor_data(state, columns)); 213 | long m_ = channels_out; 214 | long n_ = channels * kernel_h * kernel_w; 215 | long k_ = height_out * width_out; 216 | 217 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 218 | THCudaTensor_data(state, columns), k_, 219 | THCudaTensor_data(state, grad_output_n), k_, 1.0f, 220 | THCudaTensor_data(state, grad_weight), n_); 221 | 222 | // gradient w.r.t. bias 223 | // long m_ = channels_out; 224 | // long k__ = height_out * width_out; 225 | THCudaBlas_Sgemv(state, 226 | 't', 227 | k_, m_, 1.0f, 228 | THCudaTensor_data(state, grad_output_n), k_, 229 | THCudaTensor_data(state, ones), 1, 1.0f, 230 | THCudaTensor_data(state, grad_bias), 1); 231 | } 232 | 233 | THCudaTensor_free(state, input_n); 234 | THCudaTensor_free(state, offset_n); 235 | THCudaTensor_free(state, mask_n); 236 | 237 | THCudaTensor_free(state, grad_output_n); 238 | THCudaTensor_free(state, grad_input_n); 239 | THCudaTensor_free(state, grad_offset_n); 240 | THCudaTensor_free(state, grad_mask_n); 241 | } 242 | 243 | void deform_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, 244 | THCudaTensor * trans, 245 | THCudaTensor * out, THCudaTensor * top_count, 246 | const int no_trans, 247 | const float spatial_scale, 248 | const int output_dim, 249 | const int group_size, 250 | const int pooled_size, 251 | const int part_size, 252 | const int sample_per_part, 253 | const float trans_std) 254 | { 255 | THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); 256 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, bbox, trans, out, top_count)); 257 | 258 | const int batch = THCudaTensor_size(state, input, 0); 259 | const int channels = THCudaTensor_size(state, input, 1); 260 | const int height = THCudaTensor_size(state, input, 2); 261 | const int width = THCudaTensor_size(state, input, 3); 262 | const int channels_trans = no_trans? 
2 : THCudaTensor_size(state, trans, 1); 263 | 264 | const int num_bbox = THCudaTensor_size(state, bbox, 0); 265 | if (num_bbox != THCudaTensor_size(state, out, 0)) 266 | THError("Output shape and bbox number wont match: (%d vs %d).", 267 | THCudaTensor_size(state, out, 0), num_bbox); 268 | 269 | DeformablePSROIPoolForward(THCState_getCurrentStream(state), 270 | THCudaTensor_data(state, input), 271 | THCudaTensor_data(state, bbox), 272 | THCudaTensor_data(state, trans), 273 | THCudaTensor_data(state, out), 274 | THCudaTensor_data(state, top_count), 275 | batch, channels, height, width, 276 | num_bbox, 277 | channels_trans, 278 | no_trans, 279 | spatial_scale, 280 | output_dim, 281 | group_size, 282 | pooled_size, 283 | part_size, 284 | sample_per_part, 285 | trans_std); 286 | } 287 | 288 | void deform_psroi_pooling_cuda_backward(THCudaTensor * out_grad, 289 | THCudaTensor * input, THCudaTensor * bbox, 290 | THCudaTensor * trans, THCudaTensor * top_count, 291 | THCudaTensor * input_grad, THCudaTensor * trans_grad, 292 | const int no_trans, 293 | const float spatial_scale, 294 | const int output_dim, 295 | const int group_size, 296 | const int pooled_size, 297 | const int part_size, 298 | const int sample_per_part, 299 | const float trans_std) 300 | { 301 | THArgCheck(THCudaTensor_isContiguous(state, out_grad), 0, "out_grad tensor has to be contiguous"); 302 | THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); 303 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 7, input, bbox, trans, out_grad, top_count, 304 | input_grad, trans_grad)); 305 | 306 | const int batch = THCudaTensor_size(state, input, 0); 307 | const int channels = THCudaTensor_size(state, input, 1); 308 | const int height = THCudaTensor_size(state, input, 2); 309 | const int width = THCudaTensor_size(state, input, 3); 310 | const int channels_trans = no_trans? 
2 : THCudaTensor_size(state, trans, 1); 311 | 312 | const int num_bbox = THCudaTensor_size(state, bbox, 0); 313 | if (num_bbox != THCudaTensor_size(state, out_grad, 0)) 314 | THError("Output shape and bbox number wont match: (%d vs %d).", 315 | THCudaTensor_size(state, out_grad, 0), num_bbox); 316 | 317 | DeformablePSROIPoolBackwardAcc(THCState_getCurrentStream(state), 318 | THCudaTensor_data(state, out_grad), 319 | THCudaTensor_data(state, input), 320 | THCudaTensor_data(state, bbox), 321 | THCudaTensor_data(state, trans), 322 | THCudaTensor_data(state, top_count), 323 | THCudaTensor_data(state, input_grad), 324 | THCudaTensor_data(state, trans_grad), 325 | batch, channels, height, width, num_bbox, 326 | channels_trans, 327 | no_trans, 328 | spatial_scale, 329 | output_dim, 330 | group_size, 331 | pooled_size, 332 | part_size, 333 | sample_per_part, 334 | trans_std); 335 | } -------------------------------------------------------------------------------- /src/cuda/modulated_deform_im2col_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "modulated_deform_im2col_cuda.h" 2 | #include 3 | #include 4 | #include 5 | 6 | #define CUDA_KERNEL_LOOP(i, n) \ 7 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 8 | i < (n); \ 9 | i += blockDim.x * gridDim.x) 10 | 11 | const int CUDA_NUM_THREADS = 1024; 12 | inline int GET_BLOCKS(const int N) 13 | { 14 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 15 | } 16 | 17 | 18 | __device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, 19 | const int height, const int width, float h, float w) 20 | { 21 | int h_low = floor(h); 22 | int w_low = floor(w); 23 | int h_high = h_low + 1; 24 | int w_high = w_low + 1; 25 | 26 | float lh = h - h_low; 27 | float lw = w - w_low; 28 | float hh = 1 - lh, hw = 1 - lw; 29 | 30 | float v1 = 0; 31 | if (h_low >= 0 && w_low >= 0) 32 | v1 = bottom_data[h_low * data_width + w_low]; 33 | float v2 = 0; 34 | if (h_low >= 0 && w_high <= width - 1) 35 | v2 = bottom_data[h_low * data_width + w_high]; 36 | float v3 = 0; 37 | if (h_high <= height - 1 && w_low >= 0) 38 | v3 = bottom_data[h_high * data_width + w_low]; 39 | float v4 = 0; 40 | if (h_high <= height - 1 && w_high <= width - 1) 41 | v4 = bottom_data[h_high * data_width + w_high]; 42 | 43 | float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 44 | 45 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 46 | return val; 47 | } 48 | 49 | __device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, 50 | const int h, const int w, const int height, const int width) 51 | { 52 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 53 | { 54 | //empty 55 | return 0; 56 | } 57 | 58 | int argmax_h_low = floor(argmax_h); 59 | int argmax_w_low = floor(argmax_w); 60 | int argmax_h_high = argmax_h_low + 1; 61 | int argmax_w_high = argmax_w_low + 1; 62 | 63 | float weight = 0; 64 | if (h == argmax_h_low && w == argmax_w_low) 65 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 66 | if (h == argmax_h_low && w == argmax_w_high) 67 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 68 | if (h == argmax_h_high && w == argmax_w_low) 69 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 70 | if (h == argmax_h_high && w == argmax_w_high) 71 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 72 | return weight; 73 | } 74 | 75 | __device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, 76 | const int height, const int width, const float 
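/* dmcn_get_coordinate_weight returns d(bilinear_sample)/d(coordinate) at the
   fractional position (argmax_h, argmax_w): bp_dir == 0 differentiates with
   respect to h, bp_dir == 1 with respect to w. The derivative of the bilinear
   combination is a signed sum of the four neighbouring pixel values, each
   weighted by the complementary fractional distance, which is what the two
   branches in the body accumulate. */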
*im_data, 77 | const int data_width, const int bp_dir) 78 | { 79 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 80 | { 81 | //empty 82 | return 0; 83 | } 84 | 85 | int argmax_h_low = floor(argmax_h); 86 | int argmax_w_low = floor(argmax_w); 87 | int argmax_h_high = argmax_h_low + 1; 88 | int argmax_w_high = argmax_w_low + 1; 89 | 90 | float weight = 0; 91 | 92 | if (bp_dir == 0) 93 | { 94 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 95 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; 96 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 97 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; 98 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 99 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; 100 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 101 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 102 | } 103 | else if (bp_dir == 1) 104 | { 105 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 106 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; 107 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 108 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; 109 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 110 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; 111 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 112 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 113 | } 114 | 115 | return weight; 116 | } 117 | 118 | __global__ void modulated_deformable_im2col_gpu_kernel(const int n, 119 | const float *data_im, const float *data_offset, const float *data_mask, 120 | const int height, const int width, const int kernel_h, const int kernel_w, 121 | const int pad_h, const int pad_w, 122 | const int stride_h, const int stride_w, 123 | const int dilation_h, const int dilation_w, 124 | const int channel_per_deformable_group, 125 | const int batch_size, const int num_channels, const int deformable_group, 126 | const int height_col, const int width_col, 127 | float *data_col) 128 | { 129 | CUDA_KERNEL_LOOP(index, n) 130 | { 131 | // index index of output matrix 132 | const int w_col = index % width_col; 133 | const int h_col = (index / width_col) % height_col; 134 | const int b_col = (index / width_col / height_col) % batch_size; 135 | const int c_im = (index / width_col / height_col) / batch_size; 136 | const int c_col = c_im * kernel_h * kernel_w; 137 | 138 | // compute deformable group index 139 | const int deformable_group_index = c_im / channel_per_deformable_group; 140 | 141 | const int h_in = h_col * stride_h - pad_h; 142 | const int w_in = w_col * stride_w - pad_w; 143 | 144 | float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; 145 | //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; 146 | const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; 147 | const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 148 | 149 | const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * 
kernel_h * kernel_w * height_col * width_col; 150 | 151 | for (int i = 0; i < kernel_h; ++i) 152 | { 153 | for (int j = 0; j < kernel_w; ++j) 154 | { 155 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; 156 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; 157 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; 158 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 159 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 160 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 161 | float val = static_cast(0); 162 | const float h_im = h_in + i * dilation_h + offset_h; 163 | const float w_im = w_in + j * dilation_w + offset_w; 164 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 165 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) 166 | { 167 | //const float map_h = i * dilation_h + offset_h; 168 | //const float map_w = j * dilation_w + offset_w; 169 | //const int cur_height = height - h_in; 170 | //const int cur_width = width - w_in; 171 | //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); 172 | val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); 173 | } 174 | *data_col_ptr = val * mask; 175 | data_col_ptr += batch_size * height_col * width_col; 176 | //data_col_ptr += height_col * width_col; 177 | } 178 | } 179 | } 180 | } 181 | 182 | __global__ void modulated_deformable_col2im_gpu_kernel(const int n, 183 | const float *data_col, const float *data_offset, const float *data_mask, 184 | const int channels, const int height, const int width, 185 | const int kernel_h, const int kernel_w, 186 | const int pad_h, const int pad_w, 187 | const int stride_h, const int stride_w, 188 | const int dilation_h, const int dilation_w, 189 | const int channel_per_deformable_group, 190 | const int batch_size, const int deformable_group, 191 | const int height_col, const int width_col, 192 | float *grad_im) 193 | { 194 | CUDA_KERNEL_LOOP(index, n) 195 | { 196 | const int j = (index / width_col / height_col / batch_size) % kernel_w; 197 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; 198 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; 199 | // compute the start and end of the output 200 | 201 | const int deformable_group_index = c / channel_per_deformable_group; 202 | 203 | int w_out = index % width_col; 204 | int h_out = (index / width_col) % height_col; 205 | int b = (index / width_col / height_col) % batch_size; 206 | int w_in = w_out * stride_w - pad_w; 207 | int h_in = h_out * stride_h - pad_h; 208 | 209 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 210 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 211 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 212 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 213 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; 214 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 215 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 216 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 
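/* col2im is the transpose of the im2col step above: each element of the
   column buffer carries the gradient of one (channel, kernel-tap, output
   pixel) sample, is scaled by its mask value, and is scattered back onto the
   input feature map around its fractional sampling position
   (h_in + i*dilation_h + offset_h, w_in + j*dilation_w + offset_w). The small
   dy/dx window scanned below is a conservative bound on the integer cells
   whose bilinear weight (computed by dmcn_get_gradient_weight) can be
   non-zero, and atomicAdd is needed because samples from different output
   pixels overlap. */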
217 | const float cur_inv_h_data = h_in + i * dilation_h + offset_h; 218 | const float cur_inv_w_data = w_in + j * dilation_w + offset_w; 219 | 220 | const float cur_top_grad = data_col[index] * mask; 221 | const int cur_h = (int)cur_inv_h_data; 222 | const int cur_w = (int)cur_inv_w_data; 223 | for (int dy = -2; dy <= 2; dy++) 224 | { 225 | for (int dx = -2; dx <= 2; dx++) 226 | { 227 | if (cur_h + dy >= 0 && cur_h + dy < height && 228 | cur_w + dx >= 0 && cur_w + dx < width && 229 | abs(cur_inv_h_data - (cur_h + dy)) < 1 && 230 | abs(cur_inv_w_data - (cur_w + dx)) < 1) 231 | { 232 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; 233 | float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); 234 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); 235 | } 236 | } 237 | } 238 | } 239 | } 240 | 241 | __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, 242 | const float *data_col, const float *data_im, 243 | const float *data_offset, const float *data_mask, 244 | const int channels, const int height, const int width, 245 | const int kernel_h, const int kernel_w, 246 | const int pad_h, const int pad_w, 247 | const int stride_h, const int stride_w, 248 | const int dilation_h, const int dilation_w, 249 | const int channel_per_deformable_group, 250 | const int batch_size, const int offset_channels, const int deformable_group, 251 | const int height_col, const int width_col, 252 | float *grad_offset, float *grad_mask) 253 | { 254 | CUDA_KERNEL_LOOP(index, n) 255 | { 256 | float val = 0, mval = 0; 257 | int w = index % width_col; 258 | int h = (index / width_col) % height_col; 259 | int c = (index / width_col / height_col) % offset_channels; 260 | int b = (index / width_col / height_col) / offset_channels; 261 | // compute the start and end of the output 262 | 263 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 264 | const int col_step = kernel_h * kernel_w; 265 | int cnt = 0; 266 | const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; 267 | const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; 268 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 269 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 270 | 271 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 272 | 273 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) 274 | { 275 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; 276 | const int bp_dir = offset_c % 2; 277 | 278 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w; 279 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; 280 | int w_out = col_pos % width_col; 281 | int h_out = (col_pos / width_col) % height_col; 282 | int w_in = w_out * stride_w - pad_w; 283 | int h_in = h_out * stride_h - pad_h; 284 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 285 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); 286 | const int 
data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); 287 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 288 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 289 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 290 | float inv_h = h_in + i * dilation_h + offset_h; 291 | float inv_w = w_in + j * dilation_w + offset_w; 292 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) 293 | { 294 | inv_h = inv_w = -2; 295 | } 296 | else 297 | { 298 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); 299 | } 300 | const float weight = dmcn_get_coordinate_weight( 301 | inv_h, inv_w, 302 | height, width, data_im_ptr + cnt * height * width, width, bp_dir); 303 | val += weight * data_col_ptr[col_pos] * mask; 304 | cnt += 1; 305 | } 306 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val); 307 | grad_offset[index] = val; 308 | if (offset_c % 2 == 0) 309 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); 310 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; 311 | } 312 | } 313 | 314 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 315 | const float* data_im, const float* data_offset, const float* data_mask, 316 | const int batch_size, const int channels, const int height_im, const int width_im, 317 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 318 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 319 | const int dilation_h, const int dilation_w, 320 | const int deformable_group, float* data_col) { 321 | // num_axes should be smaller than block size 322 | const int channel_per_deformable_group = channels / deformable_group; 323 | const int num_kernels = channels * batch_size * height_col * width_col; 324 | modulated_deformable_im2col_gpu_kernel 325 | <<>>( 327 | num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, 328 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, 329 | batch_size, channels, deformable_group, height_col, width_col, data_col); 330 | 331 | cudaError_t err = cudaGetLastError(); 332 | if (err != cudaSuccess) 333 | { 334 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); 335 | } 336 | 337 | } 338 | 339 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 340 | const float* data_col, const float* data_offset, const float* data_mask, 341 | const int batch_size, const int channels, const int height_im, const int width_im, 342 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 343 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 344 | const int dilation_h, const int dilation_w, 345 | const int deformable_group, float* grad_im){ 346 | 347 | const int channel_per_deformable_group = channels / deformable_group; 348 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; 349 | modulated_deformable_col2im_gpu_kernel 350 | <<>>( 352 | num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, 353 | kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, 354 | dilation_h, dilation_w, channel_per_deformable_group, 355 | batch_size, 
deformable_group, height_col, width_col, grad_im); 356 | cudaError_t err = cudaGetLastError(); 357 | if (err != cudaSuccess) 358 | { 359 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); 360 | } 361 | 362 | } 363 | 364 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 365 | const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, 366 | const int batch_size, const int channels, const int height_im, const int width_im, 367 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 368 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 369 | const int dilation_h, const int dilation_w, 370 | const int deformable_group, 371 | float* grad_offset, float* grad_mask) { 372 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; 373 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; 374 | modulated_deformable_col2im_coord_gpu_kernel 375 | <<>>( 377 | num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, 378 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 379 | dilation_h, dilation_w, channel_per_deformable_group, 380 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, 381 | grad_offset, grad_mask); 382 | cudaError_t err = cudaGetLastError(); 383 | if (err != cudaSuccess) 384 | { 385 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); 386 | } 387 | } -------------------------------------------------------------------------------- /src/deform_conv_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "deform_conv_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | void shape_check(THCState *state, THCudaTensor *input, THCudaTensor *offset, 8 | THCudaTensor *gradOutput, THCudaTensor *weight, int kH, int kW, 9 | int dH, int dW, int padH, int padW, int dilationH, 10 | int dilationW, int deformable_group) { 11 | 12 | // THArgCheck(weight->nDimension == 4, 5, 13 | // "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " 14 | // "but got: %s", 15 | // weight->nDimension); 16 | THArgCheck(THCudaTensor_nDimension(state, weight) == 4, 5, 17 | "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " 18 | "but got: %s", 19 | THCudaTensor_nDimension(state, weight)); 20 | 21 | THArgCheck(THCudaTensor_isContiguous(state, weight), 5, 22 | "weight tensor has to be contiguous"); 23 | 24 | THArgCheck(kW > 0 && kH > 0, 9, 25 | "kernel size should be greater than zero, but got kH: %d kW: %d", 26 | kH, kW); 27 | 28 | // THArgCheck((weight->size[2] == kH && weight->size[3] == kW), 9, 29 | // "kernel size should be consistent with weight, ", 30 | // "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, 31 | // kW, weight->size[2], weight->size[3]); 32 | THArgCheck((THCudaTensor_size(state, weight, 2) == kH && 33 | THCudaTensor_size(state, weight, 3) == kW), 9, 34 | "kernel size should be consistent with weight, ", 35 | "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, 36 | kW, THCudaTensor_size(state, weight, 2), THCudaTensor_size(state, weight, 3)); 37 | 38 | 39 | THArgCheck(dW > 0 && dH > 0, 11, 40 | "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); 41 | 42 | THArgCheck(dilationW > 0 && dilationH > 0, 14, 43 | "dilation should be 
greater than 0, but got dilationH: %d dilationW: %d", 44 | dilationH, dilationW); 45 | 46 | // int ndim = input->nDimension; 47 | int ndim = THCudaTensor_nDimension(state, input); 48 | int dimf = 0; 49 | int dimh = 1; 50 | int dimw = 2; 51 | 52 | if (ndim == 4) { 53 | dimf++; 54 | dimh++; 55 | dimw++; 56 | } 57 | 58 | THArgCheck(ndim == 3 || ndim == 4, 2, 59 | "3D or 4D input tensor expected but got: %s", ndim); 60 | 61 | // long nInputPlane = weight->size[1]; 62 | // long inputHeight = input->size[dimh]; 63 | // long inputWidth = input->size[dimw]; 64 | // long nOutputPlane = weight->size[0]; 65 | long nInputPlane = THCudaTensor_size(state, weight, 1); 66 | long inputHeight = THCudaTensor_size(state, input, dimh); 67 | long inputWidth = THCudaTensor_size(state, input, dimw); 68 | long nOutputPlane = THCudaTensor_size(state, weight, 0); 69 | long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 70 | long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 71 | 72 | THArgCheck(nInputPlane % deformable_group == 0, 2, 73 | "input channels must divide deformable group size"); 74 | 75 | if (outputWidth < 1 || outputHeight < 1) 76 | THError( 77 | "Given input size: (%ld x %ld x %ld). " 78 | "Calculated output size: (%ld x %ld x %ld). Output size is too small", 79 | nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, 80 | outputWidth); 81 | 82 | THArgCheck(THCudaTensor_size(state, input, 1) == nInputPlane, 2, 83 | "invalid number of input planes, expected: %d, but got: %d", 84 | nInputPlane, THCudaTensor_size(state, input, 1)); 85 | 86 | THArgCheck((inputHeight >= kH && inputWidth >= kW), 2, 87 | "input image is smaller than kernel"); 88 | 89 | // THArgCheck( 90 | // (offset->size[2] == outputHeight && offset->size[3] == outputWidth), 3, 91 | // "invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", outputHeight, outputWidth, 92 | // offset->size[2], offset->size[3]); 93 | THArgCheck( 94 | (THCudaTensor_size(state, offset, 2) == outputHeight && 95 | THCudaTensor_size(state, offset, 3) == outputWidth), 3, 96 | "invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", 97 | outputHeight, outputWidth, THCudaTensor_size(state, offset, 2), 98 | THCudaTensor_size(state, offset, 3)); 99 | 100 | THArgCheck((THCudaTensor_size(state, offset, 1) == deformable_group * 2 * kH * kW), 3, 101 | "invalid number of channels of offset"); 102 | 103 | if (gradOutput != NULL) { 104 | THArgCheck(THCudaTensor_size(state, gradOutput, dimf) == nOutputPlane, 4, 105 | "invalid number of gradOutput planes, expected: %d, but got: %d", 106 | nOutputPlane, THCudaTensor_size(state, gradOutput, dimf)); 107 | 108 | THArgCheck((THCudaTensor_size(state, gradOutput, dimh) == outputHeight && 109 | THCudaTensor_size(state, gradOutput, dimw) == outputWidth), 110 | 4, "invalid size of gradOutput, expected height: %d width: %d , but got height: %d width: %d", 111 | outputHeight, outputWidth, THCudaTensor_size(state, gradOutput, dimh), 112 | THCudaTensor_size(state, gradOutput, dimw)); 113 | } 114 | } 115 | 116 | int deform_conv_forward_cuda(THCudaTensor *input, THCudaTensor *weight, 117 | THCudaTensor *offset, THCudaTensor *output, 118 | THCudaTensor *columns, THCudaTensor *ones, int kW, 119 | int kH, int dW, int dH, int padW, int padH, 120 | int dilationW, int dilationH, 121 | int deformable_group, int im2col_step) { 122 | 123 | // todo: resize columns to include im2col: done 124 | // todo: add 
im2col_step as input 125 | // todo: add new output buffer and transpose it to output (or directly transpose output) 126 | // todo: possibly change data indexing because of parallel_imgs 127 | 128 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, weight, offset, 129 | output, columns, ones)); 130 | 131 | shape_check(state, input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, 132 | dilationH, dilationW, deformable_group); 133 | 134 | input = THCudaTensor_newContiguous(state, input); 135 | offset = THCudaTensor_newContiguous(state, offset); 136 | weight = THCudaTensor_newContiguous(state, weight); 137 | 138 | int batch = 1; 139 | if (THCudaTensor_nDimension(state, input) == 3) { 140 | // Force batch 141 | batch = 0; 142 | THCudaTensor_resize4d(state, input, 1, THCudaTensor_size(state, input, 0), THCudaTensor_size(state, input, 1), 143 | THCudaTensor_size(state, input, 2)); 144 | THCudaTensor_resize4d(state, offset, 1, THCudaTensor_size(state, offset, 0), THCudaTensor_size(state, offset, 1), 145 | THCudaTensor_size(state, offset, 2)); 146 | } 147 | 148 | // todo: assert batchsize dividable by im2col_step 149 | 150 | long batchSize = THCudaTensor_size(state, input, 0); 151 | long nInputPlane = THCudaTensor_size(state, input, 1); 152 | long inputHeight = THCudaTensor_size(state, input, 2); 153 | long inputWidth = THCudaTensor_size(state, input, 3); 154 | 155 | long nOutputPlane = THCudaTensor_size(state, weight, 0); 156 | 157 | long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 158 | long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 159 | 160 | THArgCheck((THCudaTensor_size(state, offset, 0) == batchSize), 3, "invalid batch size of offset"); 161 | 162 | // bias = bias ? THCudaTensor_newContiguous(state, bias) : bias; 163 | 164 | THCudaTensor_resize5d(state, output, batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth); 165 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth); 166 | 167 | if (THCudaTensor_nDimension(state, ones) != 2 || THCudaTensor_size(state, ones, 0) * 168 | THCudaTensor_size(state, ones, 1) < outputHeight * outputWidth) { 169 | THCudaTensor_resize2d(state, ones, outputHeight, outputWidth); 170 | THCudaTensor_fill(state, ones, 1); 171 | } 172 | 173 | THCudaTensor *input_n = THCudaTensor_new(state); 174 | THCudaTensor *offset_n = THCudaTensor_new(state); 175 | THCudaTensor *output_n = THCudaTensor_new(state); 176 | 177 | THCudaTensor *output_buffer = THCudaTensor_new(state); 178 | THCudaTensor_resize4d(state, output_buffer, batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth); 179 | 180 | THCudaTensor_resize5d(state, input, batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth); 181 | THCudaTensor_resize5d(state, offset, batchSize / im2col_step, im2col_step, 182 | deformable_group * 2 * kH * kW, outputHeight, outputWidth); 183 | 184 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 185 | 186 | THCudaTensor_select(state, input_n, input, 0, elt); 187 | THCudaTensor_select(state, offset_n, offset, 0, elt); 188 | THCudaTensor_select(state, output_n, output_buffer, 0, elt); 189 | 190 | // long m_ = nOutputPlane; 191 | // long n_ = outputHeight * outputWidth; 192 | // long k_ = 1; 193 | 194 | // TODO(BZ) add bias term 195 | // if (bias) { 196 | // THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 197 | // THCudaTensor_data(state, ones), k_, 198 | // THCudaTensor_data(state, 
bias), k_, 0.0f, 199 | // THCudaTensor_data(state, output_n), n_); 200 | // } else { 201 | // THCudaTensor_zero(state, output_n); 202 | // } 203 | 204 | THCudaTensor_zero(state, output_n); 205 | 206 | deformable_im2col( 207 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n), 208 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 209 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, 210 | im2col_step, deformable_group, THCudaTensor_data(state, columns)); 211 | 212 | long m = nOutputPlane; 213 | long n = THCudaTensor_size(state, columns, 1); // todo: see if we need to change this 214 | long k = nInputPlane * kH * kW; 215 | 216 | // cublas use column major indexing 217 | THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, 218 | THCudaTensor_data(state, columns), n, 219 | THCudaTensor_data(state, weight), k, 1.0f, 220 | THCudaTensor_data(state, output_n), n); 221 | } 222 | 223 | // the reason I use seemingly redundant output_buffer is that THCudaTensor API handles successive transpose and resize poorly 224 | THCudaTensor_resize5d(state, output_buffer, batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth); 225 | THCudaTensor_transpose(state, output_buffer, NULL, 1, 2); 226 | THCudaTensor_copy(state, output, output_buffer); 227 | THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight, outputWidth); 228 | 229 | THCudaTensor_resize4d(state, input, batchSize, nInputPlane, inputHeight, inputWidth); 230 | THCudaTensor_resize4d(state, offset, batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth); 231 | 232 | THCudaTensor_free(state, input_n); 233 | THCudaTensor_free(state, offset_n); 234 | THCudaTensor_free(state, output_n); 235 | THCudaTensor_free(state, output_buffer); 236 | 237 | if (batch == 0) { 238 | THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight, outputWidth); 239 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 240 | THCudaTensor_resize3d(state, offset, THCudaTensor_size(state, offset, 1), 241 | THCudaTensor_size(state, offset, 2), THCudaTensor_size(state, offset, 3)); 242 | } 243 | 244 | THCudaTensor_free(state, input); 245 | THCudaTensor_free(state, offset); 246 | THCudaTensor_free(state, weight); 247 | // if (bias) THCudaTensor_free(state, bias); 248 | 249 | return 1; 250 | } 251 | 252 | int deform_conv_backward_input_cuda( 253 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 254 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight, 255 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH, 256 | int dilationW, int dilationH, int deformable_group, int im2col_step) { 257 | 258 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, gradOutput, weight, 259 | offset, columns, gradInput)); 260 | 261 | shape_check(state, input, offset, gradOutput, weight, kH, kW, dH, dW, padH, 262 | padW, dilationH, dilationW, deformable_group); 263 | 264 | input = THCudaTensor_newContiguous(state, input); 265 | offset = THCudaTensor_newContiguous(state, offset); 266 | gradOutput = THCudaTensor_newContiguous(state, gradOutput); 267 | weight = THCudaTensor_newContiguous(state, weight); 268 | 269 | int batch = 1; 270 | 271 | if (THCudaTensor_nDimension(state, input) == 3) { 272 | // Force batch 273 | batch = 0; 274 | THCudaTensor_resize4d(state, input, 1, THCudaTensor_size(state, input, 0), THCudaTensor_size(state, input, 1), 275 | THCudaTensor_size(state, input, 2)); 276 | THCudaTensor_resize4d(state, 
offset, 1, THCudaTensor_size(state, offset, 0), THCudaTensor_size(state, offset, 1), 277 | THCudaTensor_size(state, offset, 2)); 278 | THCudaTensor_resize4d(state, gradOutput, 1, THCudaTensor_size(state, gradOutput, 0), 279 | THCudaTensor_size(state, gradOutput, 1), THCudaTensor_size(state, gradOutput, 2)); 280 | } 281 | 282 | long batchSize = THCudaTensor_size(state, input, 0); 283 | long nInputPlane = THCudaTensor_size(state, input, 1); 284 | long inputHeight = THCudaTensor_size(state, input, 2); 285 | long inputWidth = THCudaTensor_size(state, input, 3); 286 | 287 | long nOutputPlane = THCudaTensor_size(state, weight, 0); 288 | 289 | long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 290 | long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 291 | 292 | THArgCheck((THCudaTensor_size(state, offset, 0) == batchSize), 3, "invalid batch size of offset"); 293 | THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth); 294 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth); 295 | 296 | 297 | THCudaTensor *gradInput_n = THCudaTensor_new(state); 298 | THCudaTensor *gradOffset_n = THCudaTensor_new(state); 299 | THCudaTensor *input_n = THCudaTensor_new(state); 300 | THCudaTensor *offset_n = THCudaTensor_new(state); 301 | THCudaTensor *gradOutput_n = THCudaTensor_new(state); 302 | 303 | // change order of grad output 304 | THCudaTensor_resize5d(state, gradOutput, batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth); 305 | THCudaTensor_transpose(state, gradOutput, NULL, 1, 2); 306 | 307 | THCudaTensor *gradOutputBuffer = THCudaTensor_new(state); 308 | THCudaTensor_resize5d(state, gradOutputBuffer, batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth); 309 | THCudaTensor_copy(state, gradOutputBuffer, gradOutput); 310 | THCudaTensor_resize4d(state, gradOutputBuffer, batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth); 311 | 312 | THCudaTensor_transpose(state, gradOutput, NULL, 1, 2); 313 | THCudaTensor_resize4d(state, gradOutput, batchSize, nOutputPlane, outputHeight, outputWidth); 314 | 315 | THCudaTensor_resize5d(state, gradInput, batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth); 316 | THCudaTensor_resize5d(state, input, batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth); 317 | THCudaTensor_resize5d(state, gradOffset, batchSize / im2col_step, im2col_step, 318 | deformable_group * 2 * kH * kW, outputHeight, outputWidth); 319 | THCudaTensor_resize5d(state, offset, batchSize / im2col_step, im2col_step, 320 | deformable_group * 2 * kH * kW, outputHeight, outputWidth); 321 | 322 | 323 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 324 | THCudaTensor_select(state, gradInput_n, gradInput, 0, elt); 325 | THCudaTensor_select(state, gradOffset_n, gradOffset, 0, elt); 326 | THCudaTensor_select(state, input_n, input, 0, elt); 327 | THCudaTensor_select(state, offset_n, offset, 0, elt); 328 | THCudaTensor_select(state, gradOutput_n, gradOutputBuffer, 0, elt); 329 | 330 | long m = nInputPlane * kW * kH; 331 | long n = THCudaTensor_size(state, columns, 1); 332 | long k = nOutputPlane; 333 | 334 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 335 | THCudaTensor_data(state, gradOutput_n), n, 336 | THCudaTensor_data(state, weight), m, 0.0f, 337 | THCudaTensor_data(state, columns), n); 338 | 339 | 340 | deformable_col2im_coord( 
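/* Backward w.r.t. input and offset: after the 'n','t' SGEMM above, `columns`
   holds W^T * grad_output for this im2col_step slice. deformable_col2im_coord
   turns those columns (together with the original input and offsets) into the
   gradient of the sampling offsets, while deformable_col2im scatters the same
   columns back onto gradInput with bilinear weights, mirroring the forward
   im2col. */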
341 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns), 342 | THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), 343 | nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, 344 | dilationH, dilationW, im2col_step, deformable_group, 345 | THCudaTensor_data(state, gradOffset_n)); 346 | 347 | deformable_col2im( 348 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns), 349 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 350 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, im2col_step, 351 | deformable_group, THCudaTensor_data(state, gradInput_n)); 352 | } 353 | 354 | THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth); 355 | THCudaTensor_resize4d(state, input, batchSize, nInputPlane, inputHeight, inputWidth); 356 | THCudaTensor_resize4d(state, gradOffset, batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth); 357 | THCudaTensor_resize4d(state, offset, batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth); 358 | 359 | THCudaTensor_free(state, gradInput_n); 360 | THCudaTensor_free(state, gradOffset_n); 361 | THCudaTensor_free(state, input_n); 362 | THCudaTensor_free(state, offset_n); 363 | THCudaTensor_free(state, gradOutput_n); 364 | THCudaTensor_free(state, gradOutputBuffer); 365 | 366 | if (batch == 0) { 367 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, 368 | outputWidth); 369 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 370 | THCudaTensor_resize3d(state, gradInput, nInputPlane, inputHeight, 371 | inputWidth); 372 | THCudaTensor_resize3d(state, offset, THCudaTensor_size(state, offset, 1), THCudaTensor_size(state, offset, 2), 373 | THCudaTensor_size(state, offset, 3)); 374 | THCudaTensor_resize3d(state, gradOffset, THCudaTensor_size(state, offset, 1), THCudaTensor_size(state, offset, 2), 375 | THCudaTensor_size(state, offset, 3)); 376 | } 377 | 378 | THCudaTensor_free(state, input); 379 | THCudaTensor_free(state, offset); 380 | THCudaTensor_free(state, gradOutput); 381 | THCudaTensor_free(state, weight); 382 | 383 | return 1; 384 | } 385 | 386 | int deform_conv_backward_parameters_cuda( 387 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 388 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */ 389 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH, 390 | int padW, int padH, int dilationW, int dilationH, int deformable_group, 391 | float scale, int im2col_step) { 392 | 393 | // todo: transpose and reshape outGrad 394 | // todo: reshape columns 395 | // todo: add im2col_step as input 396 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, offset, gradOutput, 397 | gradWeight, columns)); 398 | 399 | shape_check(state, input, offset, gradOutput, gradWeight, kH, kW, dH, dW, 400 | padH, padW, dilationH, dilationW, deformable_group); 401 | 402 | input = THCudaTensor_newContiguous(state, input); 403 | offset = THCudaTensor_newContiguous(state, offset); 404 | gradOutput = THCudaTensor_newContiguous(state, gradOutput); 405 | 406 | int batch = 1; 407 | 408 | if (THCudaTensor_nDimension(state, input) == 3) { 409 | // Force batch 410 | batch = 0; 411 | THCudaTensor_resize4d(state, input, 1, THCudaTensor_size(state, input, 0), THCudaTensor_size(state, input, 1), 412 | THCudaTensor_size(state, input, 2)); 413 | THCudaTensor_resize4d(state, gradOutput, 1, THCudaTensor_size(state, gradOutput, 0), 414 | THCudaTensor_size(state, 
gradOutput, 1), THCudaTensor_size(state, gradOutput, 2)); 415 | } 416 | 417 | long batchSize = THCudaTensor_size(state, input, 0); 418 | long nInputPlane = THCudaTensor_size(state, input, 1); 419 | long inputHeight = THCudaTensor_size(state, input, 2); 420 | long inputWidth = THCudaTensor_size(state, input, 3); 421 | 422 | long nOutputPlane = THCudaTensor_size(state, gradWeight, 0); 423 | 424 | long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 425 | long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 426 | 427 | THArgCheck((THCudaTensor_size(state, offset, 0) == batchSize), 3, "invalid batch size of offset"); 428 | 429 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, 430 | im2col_step * outputHeight * outputWidth); 431 | 432 | THCudaTensor *input_n = THCudaTensor_new(state); 433 | THCudaTensor *offset_n = THCudaTensor_new(state); 434 | THCudaTensor *gradOutput_n = THCudaTensor_new(state); 435 | 436 | THCudaTensor_resize5d(state, gradOutput, batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth); 437 | THCudaTensor_transpose(state, gradOutput, NULL, 1, 2); 438 | 439 | THCudaTensor *gradOutputBuffer = THCudaTensor_new(state); 440 | THCudaTensor_resize5d(state, gradOutputBuffer, batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth); 441 | THCudaTensor_copy(state, gradOutputBuffer, gradOutput); 442 | THCudaTensor_resize4d(state, gradOutputBuffer, batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth); 443 | 444 | THCudaTensor_transpose(state, gradOutput, NULL, 1, 2); 445 | THCudaTensor_resize4d(state, gradOutput, batchSize, nOutputPlane, outputHeight, outputWidth); 446 | 447 | 448 | THCudaTensor_resize5d(state, input, batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth); 449 | THCudaTensor_resize5d(state, offset, batchSize / im2col_step, im2col_step, 450 | deformable_group * 2 * kH * kW, outputHeight, outputWidth); 451 | 452 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 453 | THCudaTensor_select(state, input_n, input, 0, elt); 454 | THCudaTensor_select(state, offset_n, offset, 0, elt); 455 | THCudaTensor_select(state, gradOutput_n, gradOutputBuffer, 0, elt); 456 | 457 | deformable_im2col( 458 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n), 459 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 460 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, 461 | im2col_step, deformable_group, THCudaTensor_data(state, columns)); 462 | 463 | long m = nOutputPlane; 464 | long n = nInputPlane * kW * kH; 465 | long k = THCudaTensor_size(state, columns, 1); 466 | 467 | THCudaBlas_Sgemm(state, 't', 'n', n, m, k, scale, 468 | THCudaTensor_data(state, columns), k, 469 | THCudaTensor_data(state, gradOutput_n), k, 1.0f, 470 | THCudaTensor_data(state, gradWeight), n); 471 | } 472 | 473 | THCudaTensor_free(state, input_n); 474 | THCudaTensor_free(state, offset_n); 475 | THCudaTensor_free(state, gradOutput_n); 476 | THCudaTensor_free(state, gradOutputBuffer); 477 | 478 | THCudaTensor_resize4d(state, input, batchSize, nInputPlane, inputHeight, inputWidth); 479 | THCudaTensor_resize4d(state, offset, batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth); 480 | 481 | if (batch == 0) { 482 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, 483 | outputWidth); 484 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 485 | } 486 | 
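/* Weight gradient: for every im2col_step-sized slice, the loop above re-runs
   deformable_im2col to rebuild `columns`, then accumulates
   gradWeight += scale * grad_output_slice * columns^T via the 't','n' SGEMM
   (cuBLAS is column-major, so the operands appear swapped). gradOutputBuffer
   is the transposed copy of gradOutput that places the im2col_step images
   next to the spatial dimensions, so one GEMM per slice covers all of them;
   gradOutput itself was transposed back to its original 4-D layout before the
   loop. */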
487 | THCudaTensor_free(state, input); 488 | THCudaTensor_free(state, offset); 489 | THCudaTensor_free(state, gradOutput); 490 | return 1; 491 | } 492 | -------------------------------------------------------------------------------- /src/deform_conv_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 3 | * 4 | * COPYRIGHT 5 | * 6 | * All contributions by the University of California: 7 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 8 | * All rights reserved. 9 | * 10 | * All other contributions: 11 | * Copyright (c) 2014-2017, the respective contributors 12 | * All rights reserved. 13 | * 14 | * Caffe uses a shared copyright model: each contributor holds copyright over 15 | * their contributions to Caffe. The project versioning records all such 16 | * contribution and copyright details. If a contributor wants to further mark 17 | * their specific copyright on a particular contribution, they should indicate 18 | * their copyright solely in the commit message of the change when it is 19 | * committed. 20 | * 21 | * LICENSE 22 | * 23 | * Redistribution and use in source and binary forms, with or without 24 | * modification, are permitted provided that the following conditions are met: 25 | * 26 | * 1. Redistributions of source code must retain the above copyright notice, this 27 | * list of conditions and the following disclaimer. 28 | * 2. Redistributions in binary form must reproduce the above copyright notice, 29 | * this list of conditions and the following disclaimer in the documentation 30 | * and/or other materials provided with the distribution. 31 | * 32 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 33 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 34 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 35 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 36 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 37 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 38 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 39 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 40 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 41 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 42 | * 43 | * CONTRIBUTION AGREEMENT 44 | * 45 | * By contributing to the BVLC/caffe repository through pull-request, comment, 46 | * or otherwise, the contributor releases their content to the 47 | * license and copyright terms herein. 48 | * 49 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 50 | * 51 | * Copyright (c) 2018 Microsoft 52 | * Licensed under The MIT License [see LICENSE for details] 53 | * \file modulated_deformable_im2col.cuh 54 | * \brief Function definitions of converting an image to 55 | * column matrix based on kernel, padding, dilation, and offset. 56 | * These functions are mainly used in deformable convolution operators. 
57 |  * \ref: https://arxiv.org/abs/1703.06211
58 |  * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
59 |  */
60 | 
61 | #include "deform_conv_cuda_kernel.h"
62 | #include <algorithm>
63 | #include <cstdio>
64 | 
65 | #define CUDA_KERNEL_LOOP(i, n) \
66 |   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
67 |        i += blockDim.x * gridDim.x)
68 | 
69 | const int CUDA_NUM_THREADS = 1024;
70 | const int kMaxGridNum = 65535;
71 | inline int GET_BLOCKS(const int N) {
72 |   return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
73 | }
74 | 
75 | template <typename DType>
76 | __device__ DType deformable_im2col_bilinear(const DType *bottom_data, const int data_width,
77 |     const int height, const int width, DType h, DType w) {
78 | 
79 |   int h_low = floor(h);
80 |   int w_low = floor(w);
81 |   int h_high = h_low + 1;
82 |   int w_high = w_low + 1;
83 | 
84 |   DType lh = h - h_low;
85 |   DType lw = w - w_low;
86 |   DType hh = 1 - lh, hw = 1 - lw;
87 | 
88 |   DType v1 = 0;
89 |   if (h_low >= 0 && w_low >= 0)
90 |     v1 = bottom_data[h_low * data_width + w_low];
91 |   DType v2 = 0;
92 |   if (h_low >= 0 && w_high <= width - 1)
93 |     v2 = bottom_data[h_low * data_width + w_high];
94 |   DType v3 = 0;
95 |   if (h_high <= height - 1 && w_low >= 0)
96 |     v3 = bottom_data[h_high * data_width + w_low];
97 |   DType v4 = 0;
98 |   if (h_high <= height - 1 && w_high <= width - 1)
99 |     v4 = bottom_data[h_high * data_width + w_high];
100 | 
101 |   DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
102 | 
103 |   DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
104 |   return val;
105 | }
106 | 
107 | 
108 | template <typename DType>
109 | __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w,
110 |     const int h, const int w, const int height, const int width) {
111 | 
112 |   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
113 |     //empty
114 |     return 0;
115 |   }
116 | 
117 |   int argmax_h_low = floor(argmax_h);
118 |   int argmax_w_low = floor(argmax_w);
119 |   int argmax_h_high = argmax_h_low + 1;
120 |   int argmax_w_high = argmax_w_low + 1;
121 | 
122 |   DType weight = 0;
123 |   if (h == argmax_h_low && w == argmax_w_low)
124 |     weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
125 |   if (h == argmax_h_low && w == argmax_w_high)
126 |     weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
127 |   if (h == argmax_h_high && w == argmax_w_low)
128 |     weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
129 |   if (h == argmax_h_high && w == argmax_w_high)
130 |     weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
131 |   return weight;
132 | }
133 | 
134 | 
135 | template <typename DType>
136 | __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w,
137 |     const int height, const int width, const DType *im_data,
138 |     const int data_width, const int bp_dir) {
139 | 
140 |   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
141 |     //empty
142 |     return 0;
143 |   }
144 | 
145 |   int argmax_h_low = floor(argmax_h);
146 |   int argmax_w_low = floor(argmax_w);
147 |   int argmax_h_high = argmax_h_low + 1;
148 |   int argmax_w_high = argmax_w_low + 1;
149 | 
150 |   DType weight = 0;
151 | 
152 |   if (bp_dir == 0) {
153 |     if (argmax_h_low >= 0 && argmax_w_low >= 0)
154 |       weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
155 |     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
156 |       weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
157 |     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
158 |       weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
159 |     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
160 |       weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
161 |   } else if (bp_dir == 1) {
162 |     if (argmax_h_low >= 0 && argmax_w_low >= 0)
163 |       weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
164 |     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
165 |       weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
166 |     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
167 |       weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
168 |     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
169 |       weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
170 |   }
171 | 
172 |   return weight;
173 | }
174 | 
175 | 
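/* Worked example for the three helpers above. deformable_im2col_bilinear() computes the
 * standard bilinear weights from the fractional sample position; with illustrative values
 * h = 2.3, w = 4.6:
 *   h_low = 2, w_low = 4, lh = 0.3, lw = 0.6, hh = 0.7, hw = 0.4
 *   w1 = hh*hw = 0.28, w2 = hh*lw = 0.42, w3 = lh*hw = 0.12, w4 = lh*lw = 0.18
 * The four weights sum to 1, and taps falling outside the image contribute 0 because the
 * corresponding v1..v4 stay at zero. get_gradient_weight() returns the same per-pixel weight
 * for routing the output gradient back to the image, and get_coordinate_weight() returns its
 * derivative with respect to the sampling coordinate (bp_dir 0 for h, 1 for w), which is what
 * the offset gradient in deformable_col2im_coord needs. */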
176 | /*!
177 |  * \brief deformable_im2col gpu kernel.
178 |  * DO NOT call this directly. Use wrapper function deformable_im2col() instead;
179 |  */
180 | template <typename DType>
181 | __global__ void deformable_im2col_gpu_kernel(const int n, const DType *data_im, const DType *data_offset,
182 |     const int height, const int width, const int kernel_h, const int kernel_w,
183 |     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
184 |     const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
185 |     const int batch_size, const int num_channels, const int deformable_group,
186 |     const int height_col, const int width_col,
187 |     DType *data_col) {
188 |   CUDA_KERNEL_LOOP(index, n) {
189 |     // index index of output matrix
190 |     const int w_col = index % width_col;
191 |     const int h_col = (index / width_col) % height_col;
192 |     const int b_col = (index / width_col / height_col) % batch_size;
193 |     const int c_im = (index / width_col / height_col) / batch_size;
194 |     const int c_col = c_im * kernel_h * kernel_w;
195 | 
196 |     // compute deformable group index
197 |     const int deformable_group_index = c_im / channel_per_deformable_group;
198 | 
199 |     const int h_in = h_col * stride_h - pad_h;
200 |     const int w_in = w_col * stride_w - pad_w;
201 |     DType* data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
202 |     //const DType* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
203 |     const DType* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
204 |     const DType* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
205 | 
206 | 
207 |     for (int i = 0; i < kernel_h; ++i) {
208 |       for (int j = 0; j < kernel_w; ++j) {
209 |         const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
210 |         const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
211 |         const DType offset_h = data_offset_ptr[data_offset_h_ptr];
212 |         const DType offset_w = data_offset_ptr[data_offset_w_ptr];
213 |         DType val = static_cast<DType>(0);
214 |         const DType h_im = h_in + i * dilation_h + offset_h;
215 |         const DType w_im = w_in + j * dilation_w + offset_w;
216 |         if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
217 |           //const DType map_h = i * dilation_h + offset_h;
218 |           //const DType map_w = j * dilation_w + offset_w;
219 |           //const int cur_height = height - h_in;
220 |           //const int cur_width = width - w_in;
221 |           //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
222 |           val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
223 |         }
224 |         *data_col_ptr = val;
225 |         data_col_ptr += batch_size * height_col * width_col;
226 |       }
227 |     }
228 |   }
229 | }
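/* Layout note for the kernel above: each thread handles one (c_im, b_col, h_col, w_col)
 * position and writes kernel_h * kernel_w values, stepping data_col_ptr by
 * batch_size * height_col * width_col between kernel taps. The column buffer is therefore
 * (num_channels * kernel_h * kernel_w) rows by (batch_size * height_col * width_col) columns,
 * which matches the host-side sizing in deform_conv_cuda.c:
 *   THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH,
 *                         im2col_step * outputHeight * outputWidth); */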
230 | 
231 | 
232 | 
233 | 
234 | 
235 | 
236 | /*!\brief
237 |  * gpu function of deformable_im2col algorithm
238 |  * \param s device stream
239 |  * \param data_im pointer of images (N, C, H, W, ...) in the image batch
240 |  * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
241 |  * \param im_shape input image shape in dimensions (N, C, H, W,)
242 |  * \param col_shape column buffer shape (#channels, N, output_im_height, output_im_width, ...)
243 |  * \param kernel_shape kernel filter shape
244 |  * \param pad pad shape
245 |  * \param stride stride shape
246 |  * \param dilation dilation shape
247 |  * \param deformable_group #offset group that deformable convolution use
248 |  * \param data_col column buffer pointer
249 |  */
250 | template <typename DType>
251 | inline void deformable_im2col(cudaStream_t stream,
252 |     const DType *data_im, const DType *data_offset, const int channels,
253 |     const int height, const int width, const int ksize_h, const int ksize_w,
254 |     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
255 |     const int dilation_h, const int dilation_w, const int parallel_imgs,
256 |     const int deformable_group, DType *data_col) {
257 |   // num_axes should be smaller than block size
258 |   // todo: check parallel_imgs is correctly passed in
259 |   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
260 |   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
261 |   int num_kernels = channels * height_col * width_col * parallel_imgs;
262 |   int channel_per_deformable_group = channels / deformable_group;
263 | 
264 |   //index_t num_spatial_axes = kernel_shape.ndim();
265 |   //CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
266 |   //index_t channel_per_deformable_group = im_shape[1] / deformable_group;
267 |   //index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim());
268 |   //using namespace mxnet_op;
269 |   //switch (num_spatial_axes) {
270 |   //case 2:
271 |   // deformable_im2col_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators)
272 |   // <<::GetStream(s)>>>(
274 |   // num_kernels, data_im, data_offset, im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1],
275 |   // pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1], channel_per_deformable_group,
276 |   // col_shape[1], im_shape[1], deformable_group, col_shape[2], col_shape[3], data_col);
277 |   // MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_im2col_gpu_kernel);
278 |   // break;
279 |   //default:
280 |   // LOG(FATAL) << "im2col_nd_gpu does not support computation with "
281 |   // << num_spatial_axes << " spatial axes";
282 | 
283 |   deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
284 |       num_kernels, data_im, data_offset, height, width, ksize_h, ksize_w,
285 |       pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
286 |       parallel_imgs, channels, deformable_group, height_col, width_col, data_col);
287 | 
288 |   cudaError_t err = cudaGetLastError();
289 |   if (err != cudaSuccess) {
290 |     printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
291 |   }
292 | }
293 | 
294 | template void deformable_im2col<float>(
295 |     cudaStream_t stream, const float *data_im, const float *data_offset,
296 |     const int channels, const int height, const int width, const int ksize_h,
297 |     const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
298 |     const int stride_w, const int dilation_h, const int dilation_w,
299 |     const int parallel_imgs, const int deformable_group, float *data_col);
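/* The height_col/width_col arithmetic in deformable_im2col() above is the usual convolution
 * output-size formula. As an illustrative check, for a 512x512 input with a 3x3 kernel,
 * pad 1, stride 1 and dilation 1:
 *   height_col = (512 + 2*1 - (1*(3-1) + 1)) / 1 + 1 = 512
 * so num_kernels = channels * 512 * 512 * parallel_imgs threads are requested, one per input
 * channel and output position. */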
300 | 
301 | /*!
302 |  * \brief deformable_col2im gpu kernel.
303 |  * \brief DO NOT call this directly. Use wrapper function deformable_col2im() instead;
304 |  */
305 | template <typename DType>
306 | __global__ void deformable_col2im_gpu_kernel(const int n, const DType *data_col, const DType *data_offset,
307 |     const int channels, const int height, const int width,
308 |     const int kernel_h, const int kernel_w,
309 |     const int pad_h, const int pad_w,
310 |     const int stride_h, const int stride_w,
311 |     const int dilation_h, const int dilation_w,
312 |     const int channel_per_deformable_group,
313 |     const int batch_size, const int deformable_group,
314 |     const int height_col, const int width_col,
315 |     DType *grad_im) {
316 |   CUDA_KERNEL_LOOP(index, n) {
317 |     const int j = (index / width_col / height_col / batch_size) % kernel_w;
318 |     const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
319 |     const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
320 |     // compute the start and end of the output
321 | 
322 |     const int deformable_group_index = c / channel_per_deformable_group;
323 | 
324 |     int w_out = index % width_col;
325 |     int h_out = (index / width_col) % height_col;
326 |     int b = (index / width_col / height_col) % batch_size;
327 |     int w_in = w_out * stride_w - pad_w;
328 |     int h_in = h_out * stride_h - pad_h;
329 | 
330 |     const DType* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
331 |         2 * kernel_h * kernel_w * height_col * width_col;
332 |     const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
333 |     const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
334 |     const DType offset_h = data_offset_ptr[data_offset_h_ptr];
335 |     const DType offset_w = data_offset_ptr[data_offset_w_ptr];
336 |     const DType cur_inv_h_data = h_in + i * dilation_h + offset_h;
337 |     const DType cur_inv_w_data = w_in + j * dilation_w + offset_w;
338 | 
339 |     const DType cur_top_grad = data_col[index];
340 |     const int cur_h = (int)cur_inv_h_data;
341 |     const int cur_w = (int)cur_inv_w_data;
342 |     for (int dy = -2; dy <= 2; dy++) {
343 |       for (int dx = -2; dx <= 2; dx++) {
344 |         if (cur_h + dy >= 0 && cur_h + dy < height &&
345 |             cur_w + dx >= 0 && cur_w + dx < width &&
346 |             abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
347 |             abs(cur_inv_w_data - (cur_w + dx)) < 1
348 |             ) {
349 |           int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
350 |           DType weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
351 |           atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
352 |         }
353 |       }
354 |     }
355 |   }
356 | }
357 | 
358 | 
359 | /*!\brief
360 |  * gpu function of deformable_col2im algorithm
361 |  * \param s device stream
362 |  * \param data_col start pointer of the column buffer to be filled
363 |  * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
364 |  * \param im_shape input image shape in dimensions (N, C, H, W,)
365 |  * \param col_shape column buffer shape
366 |  * \param kernel_shape kernel filter shape
367 |  * \param pad pad shape
368 |  * \param stride stride shape
369 |  * \param dilation dilation shape
370 |  * \param deformable_group #offset group that deformable convolution use
371 |  * \param grad_im pointer of images (N, C, H, W,...) in the image batch
372 |  */
373 | template <typename DType>
374 | inline void deformable_col2im(cudaStream_t stream,
375 |     const DType *data_col, const DType *data_offset, const int channels,
376 |     const int height, const int width, const int ksize_h,
377 |     const int ksize_w, const int pad_h, const int pad_w,
378 |     const int stride_h, const int stride_w,
379 |     const int dilation_h, const int dilation_w,
380 |     const int parallel_imgs, const int deformable_group,
381 |     DType* grad_im) {
382 | 
383 | 
384 | 
385 |   // todo: make sure parallel_imgs is passed in correctly
386 |   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
387 |   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
388 |   int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
389 |   int channel_per_deformable_group = channels / deformable_group;
390 | 
391 |   // index_t num_spatial_axes = kernel_shape.ndim();
392 |   // index_t im_size = im_shape.ProdShape(1, im_shape.ndim());
393 |   // index_t channel_per_deformable_group = im_shape[1] / deformable_group;
394 |   // index_t num_kernels = col_shape.ProdShape(0, col_shape.ndim());
395 |   // num_axes should be smaller than block size
396 |   // CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
397 |   // using namespace mxnet_op;
398 |   // switch (num_spatial_axes) {
399 |   // case 2:
400 |   // // To avoid involving atomic operations, we will launch one kernel per
401 |   // // bottom dimension, and then in the kernel add up the top dimensions.
402 |   // // NOLINT_NEXT_LINE(whitespace/operators)
403 |   // deformable_col2im_gpu_kernel<<::GetStream(s)>>>(
405 |   // num_kernels, data_col, data_offset, im_shape[1], im_shape[2], im_shape[3],
406 |   // kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1],
407 |   // dilation[0], dilation[1], channel_per_deformable_group,
408 |   // col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im, req);
409 |   // MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_col2im_gpu_kernel);
410 |   // break;
411 |   // default:
412 |   // LOG(FATAL) << "col2im_nd_gpu does not support computation with "
413 |   // << num_spatial_axes << " spatial axes";
414 | 
415 |   deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
416 |       num_kernels, data_col, data_offset, channels, height, width, ksize_h,
417 |       ksize_w, pad_h, pad_w, stride_h, stride_w,
418 |       dilation_h, dilation_w, channel_per_deformable_group,
419 |       parallel_imgs, deformable_group, height_col, width_col, grad_im);
420 | 
421 |   cudaError_t err = cudaGetLastError();
422 |   if (err != cudaSuccess) {
423 |     printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
424 |   }
425 | }
426 | 
427 | template void deformable_col2im<float>(
428 |     cudaStream_t stream, const float *data_col, const float *data_offset,
429 |     const int channels, const int height, const int width, const int ksize_h,
430 |     const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
431 |     const int stride_w, const int dilation_h, const int dilation_w,
432 |     const int parallel_imgs, const int deformable_group, float *grad_im);
433 | 
434 | /*!
435 |  * \brief deformable_col2im_coord gpu kernel.
436 |  * \brief DO NOT call this directly. Use wrapper function deformable_col2im_coord() instead;
437 |  */
438 | template <typename DType>
439 | __global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType *data_col,
440 |     const DType *data_im, const DType *data_offset,
441 |     const int channels, const int height, const int width,
442 |     const int kernel_h, const int kernel_w,
443 |     const int pad_h, const int pad_w,
444 |     const int stride_h, const int stride_w,
445 |     const int dilation_h, const int dilation_w,
446 |     const int channel_per_deformable_group,
447 |     const int batch_size, const int offset_channels, const int deformable_group,
448 |     const int height_col, const int width_col, DType *grad_offset) {
449 |   CUDA_KERNEL_LOOP(index, n) {
450 |     DType val = 0;
451 |     int w = index % width_col;
452 |     int h = (index / width_col) % height_col;
453 |     int c = (index / width_col / height_col) % offset_channels;
454 |     int b = (index / width_col / height_col) / offset_channels;
455 |     // compute the start and end of the output
456 | 
457 |     const int deformable_group_index = c / (2 * kernel_h * kernel_w);
458 |     const int col_step = kernel_h * kernel_w;
459 |     int cnt = 0;
460 |     const DType *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
461 |         batch_size * width_col * height_col;
462 |     const DType *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
463 |         channel_per_deformable_group / kernel_h / kernel_w * height * width;
464 |     const DType *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
465 |         kernel_h * kernel_w * height_col * width_col;
466 | 
467 |     const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
468 | 
469 |     for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) {
470 |       const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
471 |       const int bp_dir = offset_c % 2;
472 | 
473 |       int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
474 |       int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
475 |       int w_out = col_pos % width_col;
476 |       int h_out = (col_pos / width_col) % height_col;
477 |       int w_in = w_out * stride_w - pad_w;
478 |       int h_in = h_out * stride_h - pad_h;
479 |       const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
480 |       const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
481 |       const DType offset_h = data_offset_ptr[data_offset_h_ptr];
482 |       const DType offset_w = data_offset_ptr[data_offset_w_ptr];
483 |       DType inv_h = h_in + i * dilation_h + offset_h;
484 |       DType inv_w = w_in + j * dilation_w + offset_w;
485 |       if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
486 |         inv_h = inv_w = -2;
487 |       }
488 |       const DType weight = get_coordinate_weight(
489 |           inv_h, inv_w,
490 |           height, width, data_im_ptr + cnt * height * width, width, bp_dir);
491 |       val += weight * data_col_ptr[col_pos];
492 |       cnt += 1;
493 |     }
494 | 
495 |     grad_offset[index] = val;
496 |   }
497 | }
498 | 
499 | /*!\brief
500 |  * gpu function of deformable_col2im_coord algorithm
501 |  * \param s device stream
502 |  * \param data_col start pointer of the column buffer to be filled
503 |  * \param data_im pointer of images (N, C, H, W, ...) in the image batch
504 |  * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
505 |  * \param im_shape input image shape in dimensions (N, C, H, W,)
506 |  * \param col_shape column buffer shape
507 |  * \param kernel_shape kernel filter shape
508 |  * \param pad pad shape
509 |  * \param stride stride shape
510 |  * \param dilation dilation shape
511 |  * \param deformable_group #offset group that deformable convolution use
512 |  * \param grad_offset pointer of the offsets (N, deformable_group*kernel_h*kernel_w*2, H, W,...) in the offset batch
513 |  */
514 | template <typename DType>
515 | inline void deformable_col2im_coord(cudaStream_t stream,
516 |     const DType *data_col, const DType *data_im, const DType *data_offset, const int channels,
517 |     const int height, const int width, const int ksize_h, const int ksize_w,
518 |     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
519 |     const int dilation_h, const int dilation_w, const int parallel_imgs,
520 |     const int deformable_group, DType *grad_offset) {
521 | 
522 |   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
523 |   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
524 |   int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
525 |   int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
526 | 
527 |   // index_t num_spatial_axes = kernel_shape.ndim();
528 |   // index_t num_kernels = col_shape[1] * col_shape[2] * col_shape[3] * 2 * kernel_shape[0] * kernel_shape[1] * deformable_group;
529 |   // index_t channel_per_deformable_group = col_shape[0] / deformable_group;
530 |   // num_axes should be smaller than block size
531 |   // CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
532 |   // using namespace mxnet_op;
533 |   // switch (num_spatial_axes) {
534 |   // case 2:
535 |   // To avoid involving atomic operations, we will launch one kernel per
536 |   // bottom dimension, and then in the kernel add up the top dimensions.
537 |   // NOLINT_NEXT_LINE(whitespace/operators)
538 | 
539 |   // deformable_col2im_coord_gpu_kernel << ::GetStream(s) >> >(
541 |   // num_kernels, data_col, data_im, data_offset, im_shape[1], im_shape[2], im_shape[3],
542 |   // kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1],
543 |   // dilation[0], dilation[1], channel_per_deformable_group,
544 |   // col_shape[1], 2 * kernel_shape[0] * kernel_shape[1] * deformable_group, deformable_group, col_shape[2], col_shape[3], grad_offset, req);
545 |   // MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_col2im_coord_gpu_kernel);
546 |   // break;
547 |   // default:
548 |   // LOG(FATAL) << "col2im_nd_gpu does not support computation with "
549 |   // << num_spatial_axes << " spatial axes";
550 | 
551 |   deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
552 |       num_kernels, data_col, data_im, data_offset, channels, height, width,
553 |       ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
554 |       dilation_h, dilation_w, channel_per_deformable_group,
555 |       parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
556 |       height_col, width_col, grad_offset);
557 | 
558 | }
559 | 
560 | template void
561 | deformable_col2im_coord<float>(cudaStream_t stream, const float *data_col,
562 |     const float *data_im, const float *data_offset,
563 |     const int channels, const int height, const int width,
564 |     const int ksize_h, const int ksize_w, const int pad_h,
565 |     const int pad_w, const int stride_h, const int stride_w,
566 |     const int dilation_h, const int dilation_w, const int parallel_imgs,
567 |     const int deformable_group, float *grad_offset);
568 | 
--------------------------------------------------------------------------------
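For orientation, a minimal host-side sketch of driving the float instantiations above follows. It is illustrative only: the function and buffer names are hypothetical, the shapes are example values, and it assumes deform_conv_cuda_kernel.h declares the templates defined in this file. In this repository the actual call sites are in src/deform_conv_cuda.c, where the device pointers come from THCudaTensor_data and the stream from THCState_getCurrentStream.

// Illustrative sketch: fill a column buffer for a hypothetical 6-channel 512x512 input,
// 3x3 kernel, pad = stride = dilation = 1, 2 deformable groups, im2col_step of 1.
#include <cuda_runtime.h>
#include "deform_conv_cuda_kernel.h"

void im2col_sketch(const float *d_im, const float *d_offset, float *d_col,
                   cudaStream_t stream) {
  // With these values height_col = width_col = (512 + 2 - 3) / 1 + 1 = 512, so
  // d_col must hold (6 * 3 * 3) * (1 * 512 * 512) floats and d_offset must hold
  // 2 * 2 * 3 * 3 * 512 * 512 floats per image.
  deformable_im2col<float>(stream, d_im, d_offset,
                           /*channels=*/6, /*height=*/512, /*width=*/512,
                           /*ksize_h=*/3, /*ksize_w=*/3, /*pad_h=*/1, /*pad_w=*/1,
                           /*stride_h=*/1, /*stride_w=*/1,
                           /*dilation_h=*/1, /*dilation_w=*/1,
                           /*parallel_imgs=*/1, /*deformable_group=*/2, d_col);
}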