├── .gitignore
├── README.md
├── build.py
├── functions
│   ├── __init__.py
│   └── deform_conv.py
├── make.sh
├── modules
│   ├── __init__.py
│   └── deform_conv.py
├── src
│   ├── deform_conv.c
│   ├── deform_conv.h
│   ├── deform_conv_cuda.c
│   ├── deform_conv_cuda.h
│   ├── deform_conv_cuda_kernel.cu
│   └── deform_conv_cuda_kernel.h
└── test.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | **/*.pyc 3 | **/_ext 4 | **/build 5 | **/dist 6 | **/*.egg-info 7 | **/.eggs 8 | .clang_complete 9 | *.o 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deformable Convolutional Networks in PyTorch 2 | This repo is a PyTorch implementation of [Deformable Convolution](https://arxiv.org/abs/1703.06211). 3 | Ported from the authors' MXNet [implementation](https://github.com/msracver/Deformable-ConvNets). 4 | 5 | ## Build 6 | 7 | ``` 8 | sh make.sh 9 | CC=g++ python build.py 10 | ``` 11 | 12 | `make.sh` compiles the CUDA kernel to `src/deform_conv_cuda_kernel.cu.o`; `build.py` then links it into the `_ext.deform_conv` FFI extension. See `test.py` for example usage. 13 | 14 | ### Notice 15 | Only `torch.cuda.FloatTensor` is supported. The build relies on `torch.utils.ffi`, which only exists in pre-1.0 versions of PyTorch. 16 | -------------------------------------------------------------------------------- /build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | this_file = os.path.dirname(__file__) 6 | 7 | sources = ['src/deform_conv.c'] 8 | headers = ['src/deform_conv.h'] 9 | defines = [] 10 | with_cuda = False 11 | 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/deform_conv_cuda.c'] 15 | headers += ['src/deform_conv_cuda.h'] 16 | defines += [('WITH_CUDA', None)] 17 | with_cuda = True 18 | 19 | this_file = os.path.dirname(os.path.realpath(__file__)) 20 | print(this_file) 21 | extra_objects = ['src/deform_conv_cuda_kernel.cu.o'] 22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 23 | 24 | ffi = create_extension( 25 | '_ext.deform_conv', 26 | headers=headers, 27 | sources=sources, 28 | define_macros=defines, 29 | relative_to=__file__, 30 | with_cuda=with_cuda, 31 | extra_objects=extra_objects 32 | ) 33 | 34 | if __name__ == '__main__': 35 | assert torch.cuda.is_available(), 'Please install CUDA for GPU support.'
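    # `sh make.sh` must be run before this script so that
    # src/deform_conv_cuda_kernel.cu.o exists: create_extension() only links
    # the precompiled kernel object listed in extra_objects, it does not
    # invoke nvcc itself.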
36 | ffi.build() 37 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import conv_offset2d 2 | -------------------------------------------------------------------------------- /functions/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from _ext import deform_conv 6 | 7 | 8 | def conv_offset2d(input, 9 | offset, 10 | weight, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | deform_groups=1): 15 | 16 | if input is not None and input.dim() != 4: 17 | raise ValueError( 18 | "Expected 4D tensor as input, got {}D tensor instead.".format( 19 | input.dim())) 20 | 21 | f = ConvOffset2dFunction( 22 | _pair(stride), _pair(padding), _pair(dilation), deform_groups) 23 | return f(input, offset, weight) 24 | 25 | 26 | class ConvOffset2dFunction(Function): 27 | def __init__(self, stride, padding, dilation, deformable_groups=1): 28 | super(ConvOffset2dFunction, self).__init__() 29 | self.stride = stride 30 | self.padding = padding 31 | self.dilation = dilation 32 | self.deformable_groups = deformable_groups 33 | 34 | def forward(self, input, offset, weight): 35 | self.save_for_backward(input, offset, weight) 36 | 37 | output = input.new(*self._output_size(input, weight)) 38 | 39 | self.bufs_ = [input.new(), input.new()] # columns, ones 40 | 41 | if not input.is_cuda: 42 | raise NotImplementedError 43 | else: 44 | if isinstance(input, torch.autograd.Variable): 45 | if not isinstance(input.data, torch.cuda.FloatTensor): 46 | raise NotImplementedError 47 | else: 48 | if not isinstance(input, torch.cuda.FloatTensor): 49 | raise NotImplementedError 50 | deform_conv.deform_conv_forward_cuda( 51 | input, weight, offset, output, self.bufs_[0], self.bufs_[1], 52 | weight.size(3), weight.size(2), self.stride[1], self.stride[0], 53 | self.padding[1], self.padding[0], self.dilation[0], 54 | self.dilation[1], self.deformable_groups)  # C side expects (dilationH, dilationW) 55 | return output 56 | 57 | def backward(self, grad_output): 58 | input, offset, weight = self.saved_tensors 59 | 60 | grad_input = grad_offset = grad_weight = None 61 | 62 | if not grad_output.is_cuda: 63 | raise NotImplementedError 64 | else: 65 | if isinstance(grad_output, torch.autograd.Variable): 66 | if not isinstance(grad_output.data, torch.cuda.FloatTensor): 67 | raise NotImplementedError 68 | else: 69 | if not isinstance(grad_output, torch.cuda.FloatTensor): 70 | raise NotImplementedError 71 | if self.needs_input_grad[0] or self.needs_input_grad[1]: 72 | grad_input = input.new(*input.size()).zero_() 73 | grad_offset = offset.new(*offset.size()).zero_() 74 | deform_conv.deform_conv_backward_input_cuda( 75 | input, offset, grad_output, grad_input, 76 | grad_offset, weight, self.bufs_[0], weight.size(3), 77 | weight.size(2), self.stride[1], self.stride[0], 78 | self.padding[1], self.padding[0], self.dilation[0], 79 | self.dilation[1], self.deformable_groups) 80 | 81 | if self.needs_input_grad[2]: 82 | grad_weight = weight.new(*weight.size()).zero_() 83 | deform_conv.deform_conv_backward_parameters_cuda( 84 | input, offset, grad_output, 85 | grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3), 86 | weight.size(2), self.stride[1], self.stride[0], 87 | self.padding[1], self.padding[0], self.dilation[0], 88 | self.dilation[1], self.deformable_groups, 1) 89 | 90 | 
return grad_input, grad_offset, grad_weight 91 | 92 | def _output_size(self, input, weight): 93 | channels = weight.size(0) 94 | 95 | output_size = (input.size(0), channels) 96 | for d in range(input.dim() - 2): 97 | in_size = input.size(d + 2) 98 | pad = self.padding[d] 99 | kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1 100 | stride = self.stride[d] 101 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) 102 | if not all(map(lambda s: s > 0, output_size)): 103 | raise ValueError( 104 | "convolution input is too small (output would be {})".format( 105 | 'x'.join(map(str, output_size)))) 106 | return output_size 107 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | cd src 2 | nvcc -c -o deform_conv_cuda_kernel.cu.o deform_conv_cuda_kernel.cu -x cu -Xcompiler -fPIC -std=c++11 3 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import ConvOffset2d 2 | -------------------------------------------------------------------------------- /modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.module import Module 6 | from torch.nn.modules.utils import _pair 7 | from functions import conv_offset2d 8 | 9 | 10 | class ConvOffset2d(Module): 11 | def __init__(self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | num_deformable_groups=1): 19 | super(ConvOffset2d, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = _pair(stride) 24 | self.padding = _pair(padding) 25 | self.dilation = _pair(dilation) 26 | self.num_deformable_groups = num_deformable_groups 27 | 28 | self.weight = nn.Parameter( 29 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | n = self.in_channels 35 | for k in self.kernel_size: 36 | n *= k 37 | stdv = 1. 
/ math.sqrt(n) 38 | self.weight.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, input, offset): 41 | return conv_offset2d(input, offset, self.weight, self.stride, 42 | self.padding, self.dilation, 43 | self.num_deformable_groups) 44 | -------------------------------------------------------------------------------- /src/deform_conv.c: -------------------------------------------------------------------------------- 1 | #include <TH/TH.h> 2 | // CPU path is not implemented; these stubs only exist so the extension links. 3 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, 4 | THFloatTensor *output) 5 | { 6 | // if (!THFloatTensor_isSameSizeAs(input1, input2)) 7 | // return 0; 8 | // THFloatTensor_resizeAs(output, input); 9 | // THFloatTensor_cadd(output, input1, 1.0, input2); 10 | return 1; 11 | } 12 | 13 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, 14 | THFloatTensor *grad_offset) 15 | { 16 | // THFloatTensor_resizeAs(grad_input, grad_output); 17 | // THFloatTensor_fill(grad_input, 1); 18 | return 1; 19 | } 20 | -------------------------------------------------------------------------------- /src/deform_conv.h: -------------------------------------------------------------------------------- 1 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, 2 | THFloatTensor *output); 3 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, 4 | THFloatTensor *grad_offset); 5 | -------------------------------------------------------------------------------- /src/deform_conv_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | 3 | #include "deform_conv_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | void shape_check(THCState *state, THCudaTensor *input, THCudaTensor *offset, 8 | THCudaTensor *gradOutput, THCudaTensor *weight, int kH, int kW, 9 | int dH, int dW, int padH, int padW, int dilationH, 10 | int dilationW, int deformable_group) { 11 | 12 | THArgCheck(weight->nDimension == 4, 5, 13 | "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " 14 | "but got: %d", 15 | weight->nDimension); 16 | 17 | THArgCheck(THCudaTensor_isContiguous(state, weight), 5, 18 | "weight tensor has to be contiguous"); 19 | 20 | THArgCheck(kW > 0 && kH > 0, 9, 21 | "kernel size should be greater than zero, but got kH: %d kW: %d", 22 | kH, kW); 23 | 24 | THArgCheck((weight->size[2] == kH && weight->size[3] == kW), 9, 25 | "kernel size should be consistent with weight, " 26 | "but got kH: %d kW: %d weight.size(2): %ld, weight.size(3): %ld", kH, 27 | kW, weight->size[2], weight->size[3]); 28 | 29 | THArgCheck(dW > 0 && dH > 0, 11, 30 | "stride should be greater than zero, but got dH: %d dW: %d", dH, 31 | dW); 32 | 33 | THArgCheck( 34 | dilationW > 0 && dilationH > 0, 14, 35 | "dilation should be greater than 0, but got dilationH: %d dilationW: %d", 36 | dilationH, dilationW); 37 | 38 | int ndim = input->nDimension; 39 | int dimf = 0; 40 | int dimh = 1; 41 | int dimw = 2; 42 | 43 | if (ndim == 4) { 44 | dimf++; 45 | dimh++; 46 | dimw++; 47 | } 48 | 49 | THArgCheck(ndim == 3 || ndim == 4, 2, 50 | "3D or 4D input tensor expected but got: %d", ndim); 51 | 52 | long nInputPlane = weight->size[1]; 53 | long inputHeight = input->size[dimh]; 54 | long inputWidth = input->size[dimw]; 55 | long nOutputPlane = weight->size[0]; 56 | long outputHeight = 57 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 58 | long outputWidth = 59 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 60 | 61 | THArgCheck(nInputPlane % 
deformable_group == 0, 2, 62 | "input channels must divide deformable group size"); 63 | 64 | if (outputWidth < 1 || outputHeight < 1) 65 | THError( 66 | "Given input size: (%ld x %ld x %ld). " 67 | "Calculated output size: (%ld x %ld x %ld). Output size is too small", 68 | nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, 69 | outputWidth); 70 | 71 | THArgCheck(input->size[1] == nInputPlane, 2, 72 | "invalid number of input planes, expected: %d, but got: %d", 73 | nInputPlane, input->size[1]); 74 | 75 | THArgCheck((inputHeight >= kH && inputWidth >= kW), 2, 76 | "input image is smaller than kernel"); 77 | 78 | THArgCheck( 79 | (offset->size[2] == outputHeight && offset->size[3] == outputWidth), 3, 80 | "invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", outputHeight, outputWidth, 81 | offset->size[2], offset->size[3]); 82 | 83 | THArgCheck((offset->size[1] == deformable_group * 2 * kH * kW), 3, 84 | "invalid number of channels of offset"); 85 | 86 | if (gradOutput != NULL) { 87 | THArgCheck(gradOutput->size[dimf] == nOutputPlane, 4, 88 | "invalid number of gradOutput planes, expected: %d, but got: %d", 89 | nOutputPlane, gradOutput->size[dimf]); 90 | 91 | THArgCheck((gradOutput->size[dimh] == outputHeight && 92 | gradOutput->size[dimw] == outputWidth), 93 | 4, "invalid size of gradOutput, expected height: %d width: %d , but got height: %d width: %d", outputHeight, outputWidth, 94 | gradOutput->size[dimh], gradOutput->size[dimw]); 95 | } 96 | } 97 | 98 | int deform_conv_forward_cuda(THCudaTensor *input, THCudaTensor *weight, 99 | THCudaTensor *offset, THCudaTensor *output, 100 | THCudaTensor *columns, THCudaTensor *ones, int kW, 101 | int kH, int dW, int dH, int padW, int padH, 102 | int dilationH, int dilationW, 103 | int deformable_group) { 104 | 105 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, weight, offset, 106 | output, columns, ones)); 107 | 108 | shape_check(state, input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, 109 | dilationH, dilationW, deformable_group); 110 | 111 | input = THCudaTensor_newContiguous(state, input); 112 | offset = THCudaTensor_newContiguous(state, offset); 113 | weight = THCudaTensor_newContiguous(state, weight); 114 | 115 | int batch = 1; 116 | if (input->nDimension == 3) { 117 | // Force batch 118 | batch = 0; 119 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], 120 | input->size[2]); 121 | THCudaTensor_resize4d(state, offset, 1, offset->size[0], offset->size[1], 122 | offset->size[2]); 123 | } 124 | 125 | long batchSize = input->size[0]; 126 | long nInputPlane = input->size[1]; 127 | long inputHeight = input->size[2]; 128 | long inputWidth = input->size[3]; 129 | 130 | long nOutputPlane = weight->size[0]; 131 | 132 | long outputWidth = 133 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 134 | long outputHeight = 135 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 136 | 137 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset"); 138 | 139 | // bias = bias ? 
THCudaTensor_newContiguous(state, bias) : bias; 140 | 141 | THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight, 142 | outputWidth); 143 | 144 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, 145 | outputHeight * outputWidth); 146 | 147 | if (ones->nDimension != 2 || 148 | ones->size[0] * ones->size[1] < outputHeight * outputWidth) { 149 | THCudaTensor_resize2d(state, ones, outputHeight, outputWidth); 150 | THCudaTensor_fill(state, ones, 1); 151 | } 152 | 153 | THCudaTensor *input_n = THCudaTensor_new(state); 154 | THCudaTensor *offset_n = THCudaTensor_new(state); 155 | THCudaTensor *output_n = THCudaTensor_new(state); 156 | 157 | for (int elt = 0; elt < batchSize; elt++) { 158 | 159 | THCudaTensor_select(state, input_n, input, 0, elt); 160 | THCudaTensor_select(state, offset_n, offset, 0, elt); 161 | THCudaTensor_select(state, output_n, output, 0, elt); 162 | 163 | // long m_ = nOutputPlane; 164 | // long n_ = outputHeight * outputWidth; 165 | // long k_ = 1; 166 | 167 | // TODO(BZ) add bias term 168 | // if (bias) { 169 | // THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 170 | // THCudaTensor_data(state, ones), k_, 171 | // THCudaTensor_data(state, bias), k_, 0.0f, 172 | // THCudaTensor_data(state, output_n), n_); 173 | // } else { 174 | // THCudaTensor_zero(state, output_n); 175 | // } 176 | 177 | THCudaTensor_zero(state, output_n); 178 | 179 | deformable_im2col( 180 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n), 181 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 182 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, 183 | deformable_group, THCudaTensor_data(state, columns)); 184 | 185 | long m = nOutputPlane; 186 | long n = columns->size[1]; 187 | long k = nInputPlane * kH * kW; 188 | 189 | THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, 190 | THCudaTensor_data(state, columns), n, 191 | THCudaTensor_data(state, weight), k, 1.0f, 192 | THCudaTensor_data(state, output_n), n); 193 | } 194 | 195 | THCudaTensor_free(state, input_n); 196 | THCudaTensor_free(state, offset_n); 197 | THCudaTensor_free(state, output_n); 198 | 199 | if (batch == 0) { 200 | THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight, 201 | outputWidth); 202 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 203 | THCudaTensor_resize3d(state, offset, offset->size[1], offset->size[2], 204 | offset->size[3]); 205 | } 206 | 207 | THCudaTensor_free(state, input); 208 | THCudaTensor_free(state, offset); 209 | THCudaTensor_free(state, weight); 210 | // if (bias) THCudaTensor_free(state, bias); 211 | 212 | return 1; 213 | } 214 | 215 | int deform_conv_backward_input_cuda( 216 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 217 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight, 218 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH, 219 | int dilationH, int dilationW, int deformable_group) { 220 | 221 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, gradOutput, weight, 222 | offset, columns, gradInput)); 223 | 224 | shape_check(state, input, offset, gradOutput, weight, kH, kW, dH, dW, padH, 225 | padW, dilationH, dilationW, deformable_group); 226 | 227 | input = THCudaTensor_newContiguous(state, input); 228 | offset = THCudaTensor_newContiguous(state, offset); 229 | gradOutput = THCudaTensor_newContiguous(state, gradOutput); 230 | weight = THCudaTensor_newContiguous(state, weight); 231 | 232 | int batch = 1; 233 | if 
(input->nDimension == 3) { 234 | // Force batch 235 | batch = 0; 236 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], 237 | input->size[2]); 238 | THCudaTensor_resize4d(state, offset, 1, offset->size[0], offset->size[1], 239 | offset->size[2]); 240 | THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], 241 | gradOutput->size[1], gradOutput->size[2]); 242 | } 243 | 244 | long batchSize = input->size[0]; 245 | long nInputPlane = input->size[1]; 246 | long inputHeight = input->size[2]; 247 | long inputWidth = input->size[3]; 248 | 249 | long nOutputPlane = weight->size[0]; 250 | 251 | long outputWidth = 252 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 253 | long outputHeight = 254 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 255 | 256 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset"); 257 | 258 | THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight, 259 | inputWidth); 260 | 261 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, 262 | outputHeight * outputWidth); 263 | 264 | THCudaTensor *gradInput_n = THCudaTensor_new(state); 265 | THCudaTensor *gradOffset_n = THCudaTensor_new(state); 266 | THCudaTensor *input_n = THCudaTensor_new(state); 267 | THCudaTensor *offset_n = THCudaTensor_new(state); 268 | THCudaTensor *gradOutput_n = THCudaTensor_new(state); 269 | 270 | for (int elt = 0; elt < batchSize; elt++) { 271 | THCudaTensor_select(state, gradInput_n, gradInput, 0, elt); 272 | THCudaTensor_select(state, gradOffset_n, gradOffset, 0, elt); 273 | THCudaTensor_select(state, input_n, input, 0, elt); 274 | THCudaTensor_select(state, offset_n, offset, 0, elt); 275 | THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt); 276 | 277 | long m = nInputPlane * kW * kH; 278 | long n = columns->size[1]; 279 | long k = nOutputPlane; 280 | 281 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 282 | THCudaTensor_data(state, gradOutput_n), n, 283 | THCudaTensor_data(state, weight), m, 0.0f, 284 | THCudaTensor_data(state, columns), n); 285 | 286 | deformable_col2im_coord( 287 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns), 288 | THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), 289 | nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, 290 | dilationH, dilationW, deformable_group, 291 | THCudaTensor_data(state, gradOffset_n)); 292 | 293 | deformable_col2im( 294 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns), 295 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 296 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, 297 | deformable_group, THCudaTensor_data(state, gradInput_n)); 298 | } 299 | 300 | THCudaTensor_free(state, gradInput_n); 301 | THCudaTensor_free(state, gradOffset_n); 302 | THCudaTensor_free(state, input_n); 303 | THCudaTensor_free(state, offset_n); 304 | THCudaTensor_free(state, gradOutput_n); 305 | 306 | if (batch == 0) { 307 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, 308 | outputWidth); 309 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 310 | THCudaTensor_resize3d(state, gradInput, nInputPlane, inputHeight, 311 | inputWidth); 312 | THCudaTensor_resize3d(state, offset, offset->size[1], offset->size[2], 313 | offset->size[3]); 314 | THCudaTensor_resize3d(state, gradOffset, offset->size[1], offset->size[2], 315 | offset->size[3]); 316 | } 317 | 318 | THCudaTensor_free(state, input); 319 | 
THCudaTensor_free(state, offset); 320 | THCudaTensor_free(state, gradOutput); 321 | THCudaTensor_free(state, weight); 322 | 323 | return 1; 324 | } 325 | 326 | int deform_conv_backward_parameters_cuda( 327 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 328 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */ 329 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH, 330 | int padW, int padH, int dilationH, int dilationW, int deformable_group, 331 | float scale) { 332 | 333 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, offset, gradOutput, 334 | gradWeight, columns)); 335 | 336 | shape_check(state, input, offset, gradOutput, gradWeight, kH, kW, dH, dW, 337 | padH, padW, dilationH, dilationW, deformable_group); 338 | 339 | input = THCudaTensor_newContiguous(state, input); 340 | offset = THCudaTensor_newContiguous(state, offset); 341 | gradOutput = THCudaTensor_newContiguous(state, gradOutput); 342 | 343 | int batch = 1; 344 | if (input->nDimension == 3) { 345 | // Force batch 346 | batch = 0; 347 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], 348 | input->size[2]); 349 | THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], 350 | gradOutput->size[1], gradOutput->size[2]); 351 | } 352 | 353 | long batchSize = input->size[0]; 354 | long nInputPlane = input->size[1]; 355 | long inputHeight = input->size[2]; 356 | long inputWidth = input->size[3]; 357 | 358 | long nOutputPlane = gradWeight->size[0]; 359 | 360 | long outputWidth = 361 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 362 | long outputHeight = 363 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 364 | 365 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset"); 366 | 367 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH, 368 | outputHeight * outputWidth); 369 | 370 | THCudaTensor *input_n = THCudaTensor_new(state); 371 | THCudaTensor *offset_n = THCudaTensor_new(state); 372 | THCudaTensor *gradOutput_n = THCudaTensor_new(state); 373 | 374 | for (int elt = 0; elt < batchSize; elt++) { 375 | THCudaTensor_select(state, input_n, input, 0, elt); 376 | THCudaTensor_select(state, offset_n, offset, 0, elt); 377 | THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt); 378 | 379 | deformable_im2col( 380 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n), 381 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight, 382 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, 383 | deformable_group, THCudaTensor_data(state, columns)); 384 | 385 | long m = nOutputPlane; 386 | long n = nInputPlane * kW * kH; 387 | long k = columns->size[1]; 388 | 389 | THCudaBlas_Sgemm(state, 't', 'n', n, m, k, scale, 390 | THCudaTensor_data(state, columns), k, 391 | THCudaTensor_data(state, gradOutput_n), k, 1.0f, 392 | THCudaTensor_data(state, gradWeight), n); 393 | } 394 | 395 | THCudaTensor_free(state, input_n); 396 | THCudaTensor_free(state, offset_n); 397 | THCudaTensor_free(state, gradOutput_n); 398 | 399 | if (batch == 0) { 400 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, 401 | outputWidth); 402 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 403 | } 404 | 405 | THCudaTensor_free(state, input); 406 | THCudaTensor_free(state, offset); 407 | THCudaTensor_free(state, gradOutput); 408 | return 1; 409 | } 410 | -------------------------------------------------------------------------------- 
/src/deform_conv_cuda.h: -------------------------------------------------------------------------------- 1 | int deform_conv_forward_cuda(THCudaTensor *input, 2 | THCudaTensor *weight, /*THCudaTensor * bias, */ 3 | THCudaTensor *offset, THCudaTensor *output, 4 | THCudaTensor *columns, THCudaTensor *ones, int kW, 5 | int kH, int dW, int dH, int padW, int padH, 6 | int dilationH, int dilationW, 7 | int deformable_group); 8 | 9 | int deform_conv_backward_input_cuda( 10 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 11 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight, 12 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH, 13 | int dilationH, int dilationW, int deformable_group); 14 | 15 | int deform_conv_backward_parameters_cuda( 16 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput, 17 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */ 18 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH, 19 | int padW, int padH, int dilationH, int dilationW, int deformable_group, 20 | float scale); 21 | -------------------------------------------------------------------------------- /src/deform_conv_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "deform_conv_cuda_kernel.h" 2 | 3 | #include <cstdio> 4 | 5 | #define CUDA_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | const int CUDA_NUM_THREADS = 1024; 10 | 11 | inline int GET_BLOCKS(const int N) { 12 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 13 | } 14 | 15 | template <typename DType> 16 | __device__ DType deformable_im2col_bilinear(const DType *bottom_data, 17 | const int data_width, 18 | const int height, const int width, 19 | DType h, DType w) { 20 | 21 | int h_low = floor(h); 22 | int w_low = floor(w); 23 | int h_high; 24 | int w_high; 25 | if (h_low >= height - 1) { 26 | h_high = h_low = height - 1; 27 | h = (DType)h_low; 28 | } else { 29 | h_high = h_low + 1; 30 | } 31 | 32 | if (w_low >= width - 1) { 33 | w_high = w_low = width - 1; 34 | w = (DType)w_low; 35 | } else { 36 | w_high = w_low + 1; 37 | } 38 | 39 | DType lh = h - h_low; 40 | DType lw = w - w_low; 41 | DType hh = 1 - lh, hw = 1 - lw; 42 | 43 | DType v1 = bottom_data[h_low * data_width + w_low]; 44 | DType v2 = bottom_data[h_low * data_width + w_high]; 45 | DType v3 = bottom_data[h_high * data_width + w_low]; 46 | DType v4 = bottom_data[h_high * data_width + w_high]; 47 | DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 48 | 49 | DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 50 | return val; 51 | } 52 | 53 | template <typename DType> 54 | __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, 55 | const int h, const int w, const int height, 56 | const int width) { 57 | 58 | if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { 59 | // empty 60 | return 0; 61 | } 62 | 63 | argmax_h = max(argmax_h, (DType)0.0f); 64 | argmax_w = max(argmax_w, (DType)0.0f); 65 | 66 | int argmax_h_low = (int)argmax_h; 67 | int argmax_w_low = (int)argmax_w; 68 | int argmax_h_high; 69 | int argmax_w_high; 70 | if (argmax_h_low >= height - 1) { 71 | argmax_h_high = argmax_h_low = height - 1; 72 | argmax_h = (DType)argmax_h_low; 73 | } else { 74 | argmax_h_high = argmax_h_low + 1; 75 | } 76 | if (argmax_w_low >= width - 1) { 77 | argmax_w_high = argmax_w_low = width - 1; 78 | argmax_w = (DType)argmax_w_low; 79 | } else { 80 | 
argmax_w_high = argmax_w_low + 1; 81 | } 82 | DType weight = 0; 83 | if (h == argmax_h_low) { 84 | if (w == argmax_w_low) { 85 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 86 | } else if (w == argmax_w_high) { 87 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 88 | } 89 | } else if (h == argmax_h_high) { 90 | if (w == argmax_w_low) { 91 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 92 | } else if (w == argmax_w_high) { 93 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 94 | } 95 | } 96 | return weight; 97 | } 98 | 99 | template <typename DType> 100 | __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, 101 | const int height, const int width, 102 | const DType *im_data, 103 | const int data_width, const int bp_dir) { 104 | 105 | if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { 106 | // empty 107 | return 0; 108 | } 109 | 110 | if (argmax_h < 0) 111 | argmax_h = 0; 112 | if (argmax_w < 0) 113 | argmax_w = 0; 114 | 115 | int argmax_h_low = (int)argmax_h; 116 | int argmax_w_low = (int)argmax_w; 117 | int argmax_h_high; 118 | int argmax_w_high; 119 | if (argmax_h_low >= height - 1) { 120 | argmax_h_high = argmax_h_low = height - 1; 121 | argmax_h = (DType)argmax_h_low; 122 | } else { 123 | argmax_h_high = argmax_h_low + 1; 124 | } 125 | if (argmax_w_low >= width - 1) { 126 | argmax_w_high = argmax_w_low = width - 1; 127 | argmax_w = (DType)argmax_w_low; 128 | } else { 129 | argmax_w_high = argmax_w_low + 1; 130 | } 131 | DType weight = 0; 132 | 133 | if (bp_dir == 0) { 134 | weight += -1 * (argmax_w_low + 1 - argmax_w) * 135 | im_data[argmax_h_low * data_width + argmax_w_low]; 136 | weight += -1 * (argmax_w - argmax_w_low) * 137 | im_data[argmax_h_low * data_width + argmax_w_high]; 138 | weight += (argmax_w_low + 1 - argmax_w) * 139 | im_data[argmax_h_high * data_width + argmax_w_low]; 140 | weight += (argmax_w - argmax_w_low) * 141 | im_data[argmax_h_high * data_width + argmax_w_high]; 142 | } else if (bp_dir == 1) { 143 | weight += -1 * (argmax_h_low + 1 - argmax_h) * 144 | im_data[argmax_h_low * data_width + argmax_w_low]; 145 | weight += (argmax_h_low + 1 - argmax_h) * 146 | im_data[argmax_h_low * data_width + argmax_w_high]; 147 | weight += -1 * (argmax_h - argmax_h_low) * 148 | im_data[argmax_h_high * data_width + argmax_w_low]; 149 | weight += (argmax_h - argmax_h_low) * 150 | im_data[argmax_h_high * data_width + argmax_w_high]; 151 | } 152 | 153 | return weight; 154 | } 155 | 156 | template <typename DType> 157 | __global__ void deformable_im2col_gpu_kernel( 158 | const int n, const DType *data_im, const DType *data_offset, 159 | const int height, const int width, const int kernel_h, const int kernel_w, 160 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 161 | const int dilation_h, const int dilation_w, 162 | const int channel_per_deformable_group, const int height_col, 163 | const int width_col, DType *data_col) { 164 | CUDA_KERNEL_LOOP(index, n) { 165 | // index: position in the output (column) matrix 166 | const int w_col = index % width_col; 167 | const int h_col = (index / width_col) % height_col; 168 | const int c_im = (index / width_col) / height_col; 169 | const int c_col = c_im * kernel_h * kernel_w; 170 | 171 | // compute deformable group index 172 | const int deformable_group_index = c_im / channel_per_deformable_group; 173 | 174 | const int h_in = h_col * stride_h - pad_h; 175 | const int w_in = w_col * stride_w - pad_w; 176 | DType *data_col_ptr = 177 | data_col + (c_col * height_col + h_col) * width_col + w_col; 178 | const 
DType *data_im_ptr = data_im + (c_im * height + h_in) * width + w_in; 179 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 * 180 | kernel_h * kernel_w * 181 | height_col * width_col; 182 | 183 | for (int i = 0; i < kernel_h; ++i) { 184 | for (int j = 0; j < kernel_w; ++j) { 185 | const int data_offset_h_ptr = 186 | ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; 187 | const int data_offset_w_ptr = 188 | ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + 189 | w_col; 190 | const DType offset_h = data_offset_ptr[data_offset_h_ptr]; 191 | const DType offset_w = data_offset_ptr[data_offset_w_ptr]; 192 | DType val = static_cast<DType>(0); 193 | const DType h_im = h_in + i * dilation_h + offset_h; 194 | const DType w_im = w_in + j * dilation_w + offset_w; 195 | if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 196 | const DType map_h = i * dilation_h + offset_h; 197 | const DType map_w = j * dilation_w + offset_w; 198 | const int cur_height = height - h_in; 199 | const int cur_width = width - w_in; 200 | val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, 201 | cur_width, map_h, map_w); 202 | } 203 | *data_col_ptr = val; 204 | data_col_ptr += height_col * width_col; 205 | } 206 | } 207 | } 208 | } 209 | 210 | template <typename DType> 211 | void deformable_im2col(cudaStream_t stream, const DType *data_im, 212 | const DType *data_offset, const int channels, 213 | const int height, const int width, const int ksize_h, 214 | const int ksize_w, const int pad_h, const int pad_w, 215 | const int stride_h, const int stride_w, 216 | const int dilation_h, const int dilation_w, 217 | const int deformable_group, DType *data_col) { 218 | // We are going to launch channels * height_col * width_col threads, each 219 | // thread responsible for copying a single-channel grid. 
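  // The resulting data_col buffer has shape
  // (channels * ksize_h * ksize_w) x (height_col * width_col): one row per
  // (input channel, kernel tap) pair, sampled at the offset-shifted
  // locations, so the caller can finish the convolution with a single GEMM
  // against the weight matrix.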
220 | int height_col = 221 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 222 | int width_col = 223 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 224 | int num_kernels = channels * height_col * width_col; 225 | int channel_per_deformable_group = channels / deformable_group; 226 | // Launch 227 | deformable_im2col_gpu_kernel<DType> 228 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( 229 | num_kernels, data_im, data_offset, height, width, ksize_h, ksize_w, pad_h, 230 | pad_w, stride_h, stride_w, dilation_h, dilation_w, 231 | channel_per_deformable_group, height_col, width_col, data_col); 232 | 233 | cudaError_t err = cudaGetLastError(); 234 | if (err != cudaSuccess) { 235 | printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); 236 | // TODO(BZ) panic 237 | } 238 | } 239 | 240 | template void deformable_im2col<float>( 241 | cudaStream_t stream, const float *data_im, const float *data_offset, 242 | const int channels, const int height, const int width, const int ksize_h, 243 | const int ksize_w, const int pad_h, const int pad_w, const int stride_h, 244 | const int stride_w, const int dilation_h, const int dilation_w, 245 | const int deformable_group, float *data_col); 246 | 247 | template <typename DType> 248 | __global__ void deformable_col2im_gpu_kernel( 249 | const int n, const DType *data_col, const DType *data_offset, 250 | const int channels, const int height, const int width, const int kernel_h, 251 | const int kernel_w, const int pad_h, const int pad_w, const int stride_h, 252 | const int stride_w, const int dilation_h, const int dilation_w, 253 | const int channel_per_deformable_group, const int height_col, 254 | const int width_col, DType *grad_im) { 255 | CUDA_KERNEL_LOOP(index, n) { 256 | const int j = (index / width_col / height_col) % kernel_w; 257 | const int i = (index / width_col / height_col / kernel_w) % kernel_h; 258 | const int c = index / width_col / height_col / kernel_w / kernel_h; 259 | // compute the start and end of the output 260 | 261 | const int deformable_group_index = c / channel_per_deformable_group; 262 | 263 | int w_out = index % width_col; 264 | int h_out = (index / width_col) % height_col; 265 | int w_in = w_out * stride_w - pad_w; 266 | int h_in = h_out * stride_h - pad_h; 267 | 268 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 * 269 | kernel_h * kernel_w * 270 | height_col * width_col; 271 | const int data_offset_h_ptr = 272 | ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 273 | const int data_offset_w_ptr = 274 | ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 275 | const DType offset_h = data_offset_ptr[data_offset_h_ptr]; 276 | const DType offset_w = data_offset_ptr[data_offset_w_ptr]; 277 | const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; 278 | const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; 279 | 280 | const DType cur_top_grad = data_col[index]; 281 | const int cur_h = (int)cur_inv_h_data; 282 | const int cur_w = (int)cur_inv_w_data; 283 | for (int dy = -2; dy <= 2; dy++) { 284 | for (int dx = -2; dx <= 2; dx++) { 285 | if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && 286 | cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && 287 | abs(cur_inv_w_data - (cur_w + dx)) < 1) { 288 | int cur_bottom_grad_pos = 289 | (c * height + cur_h + dy) * width + cur_w + dx; 290 | DType weight = 291 | get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, 292 | cur_w + dx, height, width); 293 | atomicAdd(grad_im + cur_bottom_grad_pos, 
weight * cur_top_grad); 294 | } 295 | } 296 | } 297 | } 298 | } 299 | 300 | template <typename DType> 301 | void deformable_col2im(cudaStream_t stream, const DType *data_col, 302 | const DType *data_offset, const int channels, 303 | const int height, const int width, const int ksize_h, 304 | const int ksize_w, const int pad_h, const int pad_w, 305 | const int stride_h, const int stride_w, 306 | const int dilation_h, const int dilation_w, 307 | const int deformable_group, DType *grad_im) { 308 | 309 | int height_col = 310 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 311 | int width_col = 312 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 313 | int num_kernels = channels * ksize_h * ksize_w * height_col * width_col; 314 | int channel_per_deformable_group = channels / deformable_group; 315 | // Each thread handles one column entry and scatters its contribution back 316 | // to the overlapping input pixels with atomicAdd. 317 | deformable_col2im_gpu_kernel<DType> 318 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( 319 | num_kernels, data_col, data_offset, channels, height, width, ksize_h, 320 | ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 321 | channel_per_deformable_group, height_col, width_col, grad_im); 322 | 323 | cudaError_t err = cudaGetLastError(); 324 | if (err != cudaSuccess) { 325 | printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); 326 | // TODO(BZ) panic 327 | } 328 | } 329 | 330 | template void deformable_col2im<float>( 331 | cudaStream_t stream, const float *data_col, const float *data_offset, 332 | const int channels, const int height, const int width, const int ksize_h, 333 | const int ksize_w, const int pad_h, const int pad_w, const int stride_h, 334 | const int stride_w, const int dilation_h, const int dilation_w, 335 | const int deformable_group, float *grad_im); 336 | 337 | template <typename DType> 338 | __global__ void deformable_col2im_coord_gpu_kernel( 339 | const int n, const DType *data_col, const DType *data_im, 340 | const DType *data_offset, const int channels, const int height, 341 | const int width, const int kernel_h, const int kernel_w, const int pad_h, 342 | const int pad_w, const int stride_h, const int stride_w, 343 | const int dilation_h, const int dilation_w, 344 | const int channel_per_deformable_group, const int height_col, 345 | const int width_col, DType *grad_offset) { 346 | CUDA_KERNEL_LOOP(index, n) { 347 | DType val = 0; 348 | int w = index % width_col; 349 | int h = (index / width_col) % height_col; 350 | int c = index / width_col / height_col; 351 | // compute the start and end of the output 352 | 353 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 354 | const int col_step = kernel_h * kernel_w; 355 | int cnt = 0; 356 | const DType *data_col_ptr = data_col + deformable_group_index * 357 | channel_per_deformable_group * 358 | width_col * height_col; 359 | const DType *data_im_ptr = 360 | data_im + deformable_group_index * channel_per_deformable_group / 361 | kernel_h / kernel_w * height * width; 362 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 * 363 | kernel_h * kernel_w * 364 | height_col * width_col; 365 | 366 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 367 | 368 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; 369 | col_c += col_step) { 370 | const int col_pos = ((col_c * height_col) + h) * width_col + w; 371 | const int bp_dir = offset_c % 2; 372 | 373 | int j = (col_pos / width_col / height_col) % kernel_w; 374 | int i = 
(col_pos / width_col / height_col / kernel_w) % kernel_h; 375 | int w_out = col_pos % width_col; 376 | int h_out = (col_pos / width_col) % height_col; 377 | int w_in = w_out * stride_w - pad_w; 378 | int h_in = h_out * stride_h - pad_h; 379 | const int data_offset_h_ptr = 380 | (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 381 | const int data_offset_w_ptr = 382 | (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + 383 | w_out); 384 | const DType offset_h = data_offset_ptr[data_offset_h_ptr]; 385 | const DType offset_w = data_offset_ptr[data_offset_w_ptr]; 386 | DType inv_h = h_in + i * dilation_h + offset_h; 387 | DType inv_w = w_in + j * dilation_w + offset_w; 388 | if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) { 389 | inv_h = inv_w = -1; 390 | } 391 | const DType weight = get_coordinate_weight( 392 | inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, 393 | width, bp_dir); 394 | val += weight * data_col_ptr[col_pos]; 395 | cnt += 1; 396 | } 397 | 398 | grad_offset[index] = val; 399 | } 400 | } 401 | 402 | template <typename DType> 403 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col, 404 | const DType *data_im, const DType *data_offset, 405 | const int channels, const int height, 406 | const int width, const int ksize_h, 407 | const int ksize_w, const int pad_h, 408 | const int pad_w, const int stride_h, 409 | const int stride_w, const int dilation_h, 410 | const int dilation_w, const int deformable_group, 411 | DType *grad_offset) { 412 | 413 | int height_col = 414 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 415 | int width_col = 416 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 417 | int num_kernels = 418 | height_col * width_col * 2 * ksize_h * ksize_w * deformable_group; 419 | int channel_per_deformable_group = 420 | channels * ksize_h * ksize_w / deformable_group; 421 | // To avoid atomic operations, we launch one thread per offset value and 422 | // accumulate its gradient inside the kernel. 
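  // num_kernels is height_col * width_col * 2 * ksize_h * ksize_w *
  // deformable_group: one thread per offset scalar (the x and y offsets are
  // separate channels), so no two threads ever write the same grad_offset
  // element.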
423 | deformable_col2im_coord_gpu_kernel<DType> 424 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( 425 | num_kernels, data_col, data_im, data_offset, channels, height, width, 426 | ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, 427 | dilation_w, channel_per_deformable_group, height_col, width_col, 428 | grad_offset); 429 | 430 | cudaError_t err = cudaGetLastError(); 431 | if (err != cudaSuccess) { 432 | printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err)); 433 | // TODO(BZ) panic 434 | } 435 | } 436 | 437 | template void 438 | deformable_col2im_coord<float>(cudaStream_t stream, const float *data_col, 439 | const float *data_im, const float *data_offset, 440 | const int channels, const int height, const int width, 441 | const int ksize_h, const int ksize_w, const int pad_h, 442 | const int pad_w, const int stride_h, const int stride_w, 443 | const int dilation_h, const int dilation_w, 444 | const int deformable_group, float *grad_offset); 445 | -------------------------------------------------------------------------------- /src/deform_conv_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | template <typename DType> 2 | void deformable_im2col(cudaStream_t stream, const DType *data_im, 3 | const DType *data_offset, const int channels, 4 | const int height, const int width, const int ksize_h, 5 | const int ksize_w, const int pad_h, const int pad_w, 6 | const int stride_h, const int stride_w, 7 | const int dilation_h, const int dilation_w, 8 | const int deformable_group, DType *data_col); 9 | 10 | template <typename DType> 11 | void deformable_col2im(cudaStream_t stream, const DType *data_col, 12 | const DType *data_offset, const int channels, 13 | const int height, const int width, const int ksize_h, 14 | const int ksize_w, const int pad_h, const int pad_w, 15 | const int stride_h, const int stride_w, 16 | const int dilation_h, const int dilation_w, 17 | const int deformable_group, DType *grad_im); 18 | 19 | template <typename DType> 20 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col, 21 | const DType *data_im, const DType *data_offset, 22 | const int channels, const int height, 23 | const int width, const int ksize_h, 24 | const int ksize_w, const int pad_h, 25 | const int pad_w, const int stride_h, 26 | const int stride_w, const int dilation_h, 27 | const int dilation_w, const int deformable_group, 28 | DType *grad_offset); 29 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from modules import ConvOffset2d 7 | 8 | num_deformable_groups = 2 9 | 10 | N, inC, inH, inW = 1, 6, 512, 512 11 | outC, outH, outW = 4, 512, 512 12 | kH, kW = 3, 3 13 | 14 | conv = nn.Conv2d( 15 | inC, 16 | num_deformable_groups * 2 * kH * kW, 17 | kernel_size=(kH, kW), 18 | stride=(1, 1), 19 | padding=(1, 1), 20 | bias=False).cuda() 21 | 22 | conv_offset2d = ConvOffset2d( 23 | inC, 24 | outC, (kH, kW), 25 | stride=1, 26 | padding=1, 27 | num_deformable_groups=num_deformable_groups).cuda() 28 | 29 | inputs = Variable(torch.randn(N, inC, inH, inW).cuda()) 30 | offset = conv(inputs) 31 | output = conv_offset2d(inputs, offset) 32 | output.backward(output.data) 33 | print(output.size()) 34 | --------------------------------------------------------------------------------
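A quick sanity check that can complement `test.py` (not part of the original repo): with the offset field fixed at zero, deformable convolution should reproduce an ordinary convolution, which exercises the whole build and CUDA path against `F.conv2d`. A minimal sketch, assuming the extension has been built as described in the README:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from modules import ConvOffset2d

N, inC, outC, kH, kW = 2, 4, 8, 3, 3

deform_conv = ConvOffset2d(
    inC, outC, (kH, kW), stride=1, padding=1,
    num_deformable_groups=1).cuda()

inputs = Variable(torch.randn(N, inC, 16, 16).cuda())
# All-zero offsets keep every sampling location on the regular grid.
offsets = Variable(torch.zeros(N, 2 * kH * kW, 16, 16).cuda())

out_deform = deform_conv(inputs, offsets)
out_plain = F.conv2d(inputs, deform_conv.weight, stride=1, padding=1)

# The two outputs should agree up to float32 rounding.
print((out_deform - out_plain).data.abs().max())
```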