├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── dcn_v2.py
├── make.sh
├── setup.py
├── src
│   ├── cpu
│   │   ├── dcn_v2_cpu.cpp
│   │   ├── dcn_v2_im2col_cpu.cpp
│   │   ├── dcn_v2_im2col_cpu.h
│   │   ├── dcn_v2_psroi_pooling_cpu.cpp
│   │   └── vision.h
│   ├── cuda
│   │   ├── dcn_v2_cuda.cu
│   │   ├── dcn_v2_im2col_cuda.cu
│   │   ├── dcn_v2_im2col_cuda.h
│   │   ├── dcn_v2_psroi_pooling_cuda.cu
│   │   └── vision.h
│   ├── dcn_v2.h
│   └── vision.cpp
└── test
    ├── test.py
    ├── testcpu.py
    └── testcuda.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | *.so 4 | *.o 5 | *pyc 6 | _ext 7 | build 8 | DCNv2.egg-info 9 | dist -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Charles Shang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deformable Convolutional Networks V2 with PyTorch 1.0 2 | 3 | ### Build 4 | ```bash 5 | ./make.sh # build 6 | python test/test.py # run examples and gradient check 7 | ``` 8 | 9 | ### An Example 10 | - deformable conv 11 | ```python 12 | from dcn_v2 import DCN 13 | input = torch.randn(2, 64, 128, 128).cuda() 14 | # offset and mask are computed inside DCN 15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() 16 | output = dcn(input) 17 | print(output.shape) 18 | ``` 19 | - deformable roi pooling 20 | ```python 21 | from dcn_v2 import DCNPooling 22 | input = torch.randn(2, 32, 64, 64).cuda() 23 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 24 | x = torch.randint(256, (20, 1)).cuda().float() 25 | y = torch.randint(256, (20, 1)).cuda().float() 26 | w = torch.randint(64, (20, 1)).cuda().float() 27 | h = torch.randint(64, (20, 1)).cuda().float() 28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 29 | 30 | # modulated deformable pooling (V2) 31 | # offset and mask are computed inside DCNPooling 32 | dpooling = DCNPooling(spatial_scale=1.0 / 4, 33 | pooled_size=7, 34 | output_dim=32, 35 | no_trans=False, 36 | group_size=1, 37 | trans_std=0.1).cuda() 38 | 39 | dout = dpooling(input, rois) 40 | ``` 41 | ### Note 42 | The master branch now targets PyTorch 1.0 (the new ATen API); to switch back to the PyTorch 0.4 version, run 43 | ```bash 44 | git checkout pytorch_0.4 45 | ``` 46 | 47 | ### Known Issues: 48 | 49 | - [x] Gradient check w.r.t. offset (solved) 50 | - [ ] Backward is not reentrant (minor) 51 | 52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). 53 | 54 | I have run the gradient check many times with DOUBLE precision. Every tensor **except offset** passes. 55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it due to some 56 | non-differentiable points? 57 | 58 | Update: all gradient checks pass with double precision. 59 | 60 | Another issue is that the gradient check raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for 61 | float, `<1e-15` for double), 62 | so it may not be a serious problem. 63 | 64 | Please post an issue or PR if you have any comments. 
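### Gradient check sketch

The snippet below is a minimal sketch of the kind of double-precision gradient check discussed above, calling the low-level `dcn_v2_conv` autograd function from `dcn_v2.py` directly, so the offset and mask tensors are supplied explicitly with the channel counts that `DCNv2.forward` asserts (`2 * deformable_groups * kH * kW` for the offset, `deformable_groups * kH * kW` for the mask). The tensor sizes and tolerances here are illustrative assumptions, not taken from the shipped test scripts; the files under `test/` are the reference checks.

```python
import torch
from torch.autograd import gradcheck

from dcn_v2 import dcn_v2_conv  # low-level autograd Function defined in dcn_v2.py

# small illustrative sizes (assumed)
N, C_in, C_out, H, W = 2, 4, 4, 8, 8
kH, kW = 3, 3
deformable_groups = 1

# double precision on the GPU, as the notes above say all checks pass with double
opts = dict(dtype=torch.float64, device="cuda", requires_grad=True)
input = torch.randn(N, C_in, H, W, **opts)
# offset: 2 * deformable_groups * kH * kW channels; mask: deformable_groups * kH * kW channels
offset = torch.randn(N, 2 * deformable_groups * kH * kW, H, W, **opts)
mask = torch.rand(N, deformable_groups * kH * kW, H, W, **opts)
weight = torch.randn(C_out, C_in, kH, kW, **opts)
bias = torch.randn(C_out, **opts)

# stride=1, padding=1, dilation=1 keep the output at H x W, matching the offset/mask shapes
ok = gradcheck(dcn_v2_conv,
               (input, offset, mask, weight, bias, 1, 1, 1, deformable_groups),
               eps=1e-6, atol=1e-4, rtol=1e-3)
print("gradcheck passed:", ok)
```

This assumes the extension has been built first (`./make.sh`) and that a CUDA device is available, as in the examples above.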
65 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifzhang/DCNv2/ab4d98efc0aafeb27bb05b803820385401d9921b/__init__.py -------------------------------------------------------------------------------- /dcn_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | from torch.autograd import Function 9 | from torch.autograd.function import once_differentiable 10 | from torch.nn.modules.utils import _pair 11 | 12 | import _ext as _backend 13 | 14 | 15 | class _DCNv2(Function): 16 | @staticmethod 17 | def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups): 18 | ctx.stride = _pair(stride) 19 | ctx.padding = _pair(padding) 20 | ctx.dilation = _pair(dilation) 21 | ctx.kernel_size = _pair(weight.shape[2:4]) 22 | ctx.deformable_groups = deformable_groups 23 | output = _backend.dcn_v2_forward( 24 | input, 25 | weight, 26 | bias, 27 | offset, 28 | mask, 29 | ctx.kernel_size[0], 30 | ctx.kernel_size[1], 31 | ctx.stride[0], 32 | ctx.stride[1], 33 | ctx.padding[0], 34 | ctx.padding[1], 35 | ctx.dilation[0], 36 | ctx.dilation[1], 37 | ctx.deformable_groups, 38 | ) 39 | ctx.save_for_backward(input, offset, mask, weight, bias) 40 | return output 41 | 42 | @staticmethod 43 | @once_differentiable 44 | def backward(ctx, grad_output): 45 | input, offset, mask, weight, bias = ctx.saved_tensors 46 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward( 47 | input, 48 | weight, 49 | bias, 50 | offset, 51 | mask, 52 | grad_output, 53 | ctx.kernel_size[0], 54 | ctx.kernel_size[1], 55 | ctx.stride[0], 56 | ctx.stride[1], 57 | ctx.padding[0], 58 | ctx.padding[1], 59 | ctx.dilation[0], 60 | ctx.dilation[1], 61 | ctx.deformable_groups, 62 | ) 63 | 64 | return ( 65 | grad_input, 66 | grad_offset, 67 | grad_mask, 68 | grad_weight, 69 | grad_bias, 70 | None, 71 | None, 72 | None, 73 | None, 74 | ) 75 | 76 | 77 | dcn_v2_conv = _DCNv2.apply 78 | 79 | 80 | class DCNv2(nn.Module): 81 | def __init__( 82 | self, 83 | in_channels, 84 | out_channels, 85 | kernel_size, 86 | stride, 87 | padding, 88 | dilation=1, 89 | deformable_groups=1, 90 | ): 91 | super(DCNv2, self).__init__() 92 | self.in_channels = in_channels 93 | self.out_channels = out_channels 94 | self.kernel_size = _pair(kernel_size) 95 | self.stride = _pair(stride) 96 | self.padding = _pair(padding) 97 | self.dilation = _pair(dilation) 98 | self.deformable_groups = deformable_groups 99 | 100 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 101 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 102 | self.reset_parameters() 103 | 104 | def reset_parameters(self): 105 | n = self.in_channels 106 | for k in self.kernel_size: 107 | n *= k 108 | stdv = 1.0 / math.sqrt(n) 109 | self.weight.data.uniform_(-stdv, stdv) 110 | self.bias.data.zero_() 111 | 112 | def forward(self, input, offset, mask): 113 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1] 114 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1] 115 | return dcn_v2_conv( 116 | input, 117 | offset, 118 | mask, 119 | self.weight, 120 | self.bias, 121 | 
self.stride, 122 | self.padding, 123 | self.dilation, 124 | self.deformable_groups, 125 | ) 126 | 127 | 128 | class DCN(DCNv2): 129 | def __init__( 130 | self, 131 | in_channels, 132 | out_channels, 133 | kernel_size, 134 | stride, 135 | padding, 136 | dilation=1, 137 | deformable_groups=1, 138 | ): 139 | super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups) 140 | 141 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 142 | self.conv_offset_mask = nn.Conv2d( 143 | self.in_channels, 144 | channels_, 145 | kernel_size=self.kernel_size, 146 | stride=self.stride, 147 | padding=self.padding, 148 | bias=True, 149 | ) 150 | self.init_offset() 151 | 152 | def init_offset(self): 153 | self.conv_offset_mask.weight.data.zero_() 154 | self.conv_offset_mask.bias.data.zero_() 155 | 156 | def forward(self, input): 157 | out = self.conv_offset_mask(input) 158 | o1, o2, mask = torch.chunk(out, 3, dim=1) 159 | offset = torch.cat((o1, o2), dim=1) 160 | mask = torch.sigmoid(mask) 161 | return dcn_v2_conv( 162 | input, 163 | offset, 164 | mask, 165 | self.weight, 166 | self.bias, 167 | self.stride, 168 | self.padding, 169 | self.dilation, 170 | self.deformable_groups, 171 | ) 172 | 173 | 174 | class _DCNv2Pooling(Function): 175 | @staticmethod 176 | def forward( 177 | ctx, 178 | input, 179 | rois, 180 | offset, 181 | spatial_scale, 182 | pooled_size, 183 | output_dim, 184 | no_trans, 185 | group_size=1, 186 | part_size=None, 187 | sample_per_part=4, 188 | trans_std=0.0, 189 | ): 190 | ctx.spatial_scale = spatial_scale 191 | ctx.no_trans = int(no_trans) 192 | ctx.output_dim = output_dim 193 | ctx.group_size = group_size 194 | ctx.pooled_size = pooled_size 195 | ctx.part_size = pooled_size if part_size is None else part_size 196 | ctx.sample_per_part = sample_per_part 197 | ctx.trans_std = trans_std 198 | 199 | output, output_count = _backend.dcn_v2_psroi_pooling_forward( 200 | input, 201 | rois, 202 | offset, 203 | ctx.no_trans, 204 | ctx.spatial_scale, 205 | ctx.output_dim, 206 | ctx.group_size, 207 | ctx.pooled_size, 208 | ctx.part_size, 209 | ctx.sample_per_part, 210 | ctx.trans_std, 211 | ) 212 | ctx.save_for_backward(input, rois, offset, output_count) 213 | return output 214 | 215 | @staticmethod 216 | @once_differentiable 217 | def backward(ctx, grad_output): 218 | input, rois, offset, output_count = ctx.saved_tensors 219 | grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward( 220 | grad_output, 221 | input, 222 | rois, 223 | offset, 224 | output_count, 225 | ctx.no_trans, 226 | ctx.spatial_scale, 227 | ctx.output_dim, 228 | ctx.group_size, 229 | ctx.pooled_size, 230 | ctx.part_size, 231 | ctx.sample_per_part, 232 | ctx.trans_std, 233 | ) 234 | 235 | return grad_input, None, grad_offset, None, None, None, None, None, None, None, None 236 | 237 | 238 | dcn_v2_pooling = _DCNv2Pooling.apply 239 | 240 | 241 | class DCNv2Pooling(nn.Module): 242 | def __init__( 243 | self, 244 | spatial_scale, 245 | pooled_size, 246 | output_dim, 247 | no_trans, 248 | group_size=1, 249 | part_size=None, 250 | sample_per_part=4, 251 | trans_std=0.0, 252 | ): 253 | super(DCNv2Pooling, self).__init__() 254 | self.spatial_scale = spatial_scale 255 | self.pooled_size = pooled_size 256 | self.output_dim = output_dim 257 | self.no_trans = no_trans 258 | self.group_size = group_size 259 | self.part_size = pooled_size if part_size is None else part_size 260 | self.sample_per_part = sample_per_part 261 | self.trans_std = trans_std 262 
| 263 | def forward(self, input, rois, offset): 264 | assert input.shape[1] == self.output_dim 265 | if self.no_trans: 266 | offset = input.new() 267 | return dcn_v2_pooling( 268 | input, 269 | rois, 270 | offset, 271 | self.spatial_scale, 272 | self.pooled_size, 273 | self.output_dim, 274 | self.no_trans, 275 | self.group_size, 276 | self.part_size, 277 | self.sample_per_part, 278 | self.trans_std, 279 | ) 280 | 281 | 282 | class DCNPooling(DCNv2Pooling): 283 | def __init__( 284 | self, 285 | spatial_scale, 286 | pooled_size, 287 | output_dim, 288 | no_trans, 289 | group_size=1, 290 | part_size=None, 291 | sample_per_part=4, 292 | trans_std=0.0, 293 | deform_fc_dim=1024, 294 | ): 295 | super(DCNPooling, self).__init__( 296 | spatial_scale, 297 | pooled_size, 298 | output_dim, 299 | no_trans, 300 | group_size, 301 | part_size, 302 | sample_per_part, 303 | trans_std, 304 | ) 305 | 306 | self.deform_fc_dim = deform_fc_dim 307 | 308 | if not no_trans: 309 | self.offset_mask_fc = nn.Sequential( 310 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), 311 | nn.ReLU(inplace=True), 312 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), 313 | nn.ReLU(inplace=True), 314 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3), 315 | ) 316 | self.offset_mask_fc[4].weight.data.zero_() 317 | self.offset_mask_fc[4].bias.data.zero_() 318 | 319 | def forward(self, input, rois): 320 | offset = input.new() 321 | 322 | if not self.no_trans: 323 | 324 | # do roi_align first 325 | n = rois.shape[0] 326 | roi = dcn_v2_pooling( 327 | input, 328 | rois, 329 | offset, 330 | self.spatial_scale, 331 | self.pooled_size, 332 | self.output_dim, 333 | True, # no trans 334 | self.group_size, 335 | self.part_size, 336 | self.sample_per_part, 337 | self.trans_std, 338 | ) 339 | 340 | # build mask and offset 341 | offset_mask = self.offset_mask_fc(roi.view(n, -1)) 342 | offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size) 343 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) 344 | offset = torch.cat((o1, o2), dim=1) 345 | mask = torch.sigmoid(mask) 346 | 347 | # do pooling with offset and mask 348 | return ( 349 | dcn_v2_pooling( 350 | input, 351 | rois, 352 | offset, 353 | self.spatial_scale, 354 | self.pooled_size, 355 | self.output_dim, 356 | self.no_trans, 357 | self.group_size, 358 | self.part_size, 359 | self.sample_per_part, 360 | self.trans_std, 361 | ) 362 | * mask 363 | ) 364 | # only roi_align 365 | return dcn_v2_pooling( 366 | input, 367 | rois, 368 | offset, 369 | self.spatial_scale, 370 | self.pooled_size, 371 | self.output_dim, 372 | self.no_trans, 373 | self.group_size, 374 | self.part_size, 375 | self.sample_per_part, 376 | self.trans_std, 377 | ) 378 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python setup.py build develop 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | import sys 6 | 7 | import torch 8 | from setuptools import find_packages, setup 9 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 10 | 11 | requirements = ["torch", "torchvision"] 12 | 13 | 14 | def get_extensions(): 15 | this_dir = os.path.dirname(os.path.abspath(__file__)) 16 | 
extensions_dir = os.path.join(this_dir, "src") 17 | 18 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 19 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 20 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 21 | 22 | os.environ["CC"] = "g++" 23 | sources = main_file + source_cpu 24 | extension = CppExtension 25 | extra_compile_args = {"cxx": []} 26 | define_macros = [] 27 | 28 | if torch.cuda.is_available() and CUDA_HOME is not None: 29 | extension = CUDAExtension 30 | sources += source_cuda 31 | define_macros += [("WITH_CUDA", None)] 32 | extra_compile_args["nvcc"] = [ 33 | "-DCUDA_HAS_FP16=1", 34 | "-D__CUDA_NO_HALF_OPERATORS__", 35 | "-D__CUDA_NO_HALF_CONVERSIONS__", 36 | "-D__CUDA_NO_HALF2_OPERATORS__", 37 | ] 38 | else: 39 | # raise NotImplementedError('Cuda is not available') 40 | pass 41 | 42 | extra_compile_args['cxx'].append('-fopenmp') 43 | 44 | sources = [os.path.join(extensions_dir, s) for s in sources] 45 | include_dirs = [extensions_dir] 46 | ext_modules = [ 47 | extension( 48 | "_ext", 49 | sources, 50 | include_dirs=include_dirs, 51 | define_macros=define_macros, 52 | extra_compile_args=extra_compile_args, 53 | ) 54 | ] 55 | return ext_modules 56 | 57 | 58 | setup( 59 | name="DCNv2", 60 | version="0.1", 61 | author="charlesshang", 62 | url="https://github.com/charlesshang/DCNv2", 63 | description="deformable convolutional networks", 64 | packages=find_packages( 65 | exclude=( 66 | "configs", 67 | "tests", 68 | ) 69 | ), 70 | # install_requires=requirements, 71 | ext_modules=get_extensions(), 72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 73 | ) 74 | -------------------------------------------------------------------------------- /src/cpu/dcn_v2_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cpu/dcn_v2_im2col_cpu.h" 3 | #include 4 | 5 | #include 6 | //#include 7 | 8 | #include 9 | //#include 10 | //#include 11 | 12 | //extern THCState *state; 13 | 14 | // author: Charles Shang 15 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 16 | 17 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 18 | 19 | // edit by: James Bockman and Matthew Howe 20 | // modified for torch implementation to remove use of deprecated torch access to Blas 21 | 22 | at::Tensor 23 | dcn_v2_cpu_forward(const at::Tensor &input, 24 | const at::Tensor &weight, 25 | const at::Tensor &bias, 26 | const at::Tensor &offset, 27 | const at::Tensor &mask, 28 | const int kernel_h, 29 | const int kernel_w, 30 | const int stride_h, 31 | const int stride_w, 32 | const int pad_h, 33 | const int pad_w, 34 | const int dilation_h, 35 | const int dilation_w, 36 | const int deformable_group) 37 | { 38 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); 39 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 40 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); 41 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); 42 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); 43 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/ 44 | 45 | const int batch = input.size(0); 46 | const int channels = input.size(1); 47 | const int height = input.size(2); 48 | const int width = input.size(3); 49 | 50 | const int channels_out = weight.size(0); 51 | const int channels_kernel = weight.size(1); 52 | const int kernel_h_ = weight.size(2); 53 | const int kernel_w_ = weight.size(3); 54 | 55 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); 56 | // printf("Channels: %d %d\n", channels, channels_kernel); 57 | // printf("Channels: %d %d\n", channels_out, channels_kernel); 58 | 59 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 60 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 61 | 62 | AT_ASSERTM(channels == channels_kernel, 63 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 64 | 65 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 66 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 67 | 68 | // auto ones = at::ones({height_out, width_out}, input.options()); 69 | auto ones = at::ones({bias.sizes()[0], height_out, width_out}, input.options()); 70 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 71 | auto output = at::zeros({batch, channels_out, height_out, width_out}, input.options()); 72 | 73 | using scalar_t = float; 74 | for (int b = 0; b < batch; b++) 75 | { 76 | auto input_n = input.select(0, b); 77 | auto offset_n = offset.select(0, b); 78 | auto mask_n = mask.select(0, b); 79 | auto output_n = output.select(0, b); 80 | // std::cout << "output_n: " << output_n << "output.select(0,b): " << output.select(0,b) << "\n"; 81 | 82 | // Do Bias first: 83 | // M,N,K are dims of matrix A and B 84 | // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) 85 | // (N x 1) (1 x M) 86 | 87 | // torch implementation 88 | auto ones_T = at::transpose(ones.contiguous(), 2, 0); 89 | ones_T = at::mul(ones_T, bias.contiguous()); 90 | ones_T = at::transpose(ones_T, 2, 0); 91 | output_n = at::add(output_n, ones_T); 92 | 93 | modulated_deformable_im2col_cpu(input_n.data_ptr(), 94 | offset_n.data_ptr(), 95 | mask_n.data_ptr(), 96 | 1, channels, height, width, 97 | height_out, width_out, kernel_h, kernel_w, 98 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 99 | deformable_group, 100 | columns.data_ptr()); 101 | 102 | //(k * m) x (m * n) 103 | // Y = 
WC 104 | 105 | // torch implementation 106 | auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w}); 107 | auto product = at::matmul(weight_flat, columns); 108 | output.select(0, b) = at::add(output_n, product.view({channels_out, height_out, width_out})); 109 | } 110 | return output; 111 | } 112 | 113 | std::vector dcn_v2_cpu_backward(const at::Tensor &input, 114 | const at::Tensor &weight, 115 | const at::Tensor &bias, 116 | const at::Tensor &offset, 117 | const at::Tensor &mask, 118 | const at::Tensor &grad_output, 119 | int kernel_h, int kernel_w, 120 | int stride_h, int stride_w, 121 | int pad_h, int pad_w, 122 | int dilation_h, int dilation_w, 123 | int deformable_group) 124 | { 125 | 126 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); 127 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); 128 | 129 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 130 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); 131 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); 132 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); 133 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/ 134 | 135 | const int batch = input.size(0); 136 | const int channels = input.size(1); 137 | const int height = input.size(2); 138 | const int width = input.size(3); 139 | 140 | const int channels_out = weight.size(0); 141 | const int channels_kernel = weight.size(1); 142 | const int kernel_h_ = weight.size(2); 143 | const int kernel_w_ = weight.size(3); 144 | 145 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 146 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 147 | 148 | AT_ASSERTM(channels == channels_kernel, 149 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 150 | 151 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 152 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 153 | 154 | auto ones = at::ones({height_out, width_out}, input.options()); 155 | auto columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 156 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 157 | 158 | auto grad_input = at::zeros_like(input); 159 | auto grad_weight = at::zeros_like(weight); 160 | auto grad_bias = at::zeros_like(bias); 161 | auto grad_offset = at::zeros_like(offset); 162 | auto grad_mask = at::zeros_like(mask); 163 | 164 | using scalar_t = float; 165 | 166 | for (int b = 0; b < batch; b++) 167 | { 168 | auto input_n = input.select(0, b); 169 | auto offset_n = offset.select(0, b); 170 | auto mask_n = mask.select(0, b); 171 | auto grad_output_n = grad_output.select(0, b); 172 | auto grad_input_n = grad_input.select(0, b); 173 | auto grad_offset_n = grad_offset.select(0, b); 174 | auto grad_mask_n = grad_mask.select(0, b); 175 | 176 | 177 | 178 | // Torch implementation 179 | auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w}); 180 | weight_flat = at::transpose(weight_flat, 1, 0); 181 | auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out}); 182 | columns = at::matmul(weight_flat, grad_output_n_flat); 183 | 184 | // gradient w.r.t. 
input coordinate data 185 | modulated_deformable_col2im_coord_cpu(columns.data_ptr(), 186 | input_n.data_ptr(), 187 | offset_n.data_ptr(), 188 | mask_n.data_ptr(), 189 | 1, channels, height, width, 190 | height_out, width_out, kernel_h, kernel_w, 191 | pad_h, pad_w, stride_h, stride_w, 192 | dilation_h, dilation_w, deformable_group, 193 | grad_offset_n.data_ptr(), 194 | grad_mask_n.data_ptr()); 195 | // gradient w.r.t. input data 196 | modulated_deformable_col2im_cpu(columns.data_ptr(), 197 | offset_n.data_ptr(), 198 | mask_n.data_ptr(), 199 | 1, channels, height, width, 200 | height_out, width_out, kernel_h, kernel_w, 201 | pad_h, pad_w, stride_h, stride_w, 202 | dilation_h, dilation_w, deformable_group, 203 | grad_input_n.data_ptr()); 204 | 205 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group 206 | modulated_deformable_im2col_cpu(input_n.data_ptr(), 207 | offset_n.data_ptr(), 208 | mask_n.data_ptr(), 209 | 1, channels, height, width, 210 | height_out, width_out, kernel_h, kernel_w, 211 | pad_h, pad_w, stride_h, stride_w, 212 | dilation_h, dilation_w, deformable_group, 213 | columns.data_ptr()); 214 | 215 | // Torch implementation 216 | auto product = at::matmul(grad_output_n_flat, at::transpose(columns, 1, 0)); 217 | grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w})); 218 | 219 | 220 | // Torch implementation 221 | auto ones_flat = ones.view({height_out*width_out}); 222 | product = at::matmul(grad_output_n_flat, ones_flat); 223 | grad_bias = at::add(grad_bias, product); 224 | } 225 | 226 | return { 227 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias 228 | }; 229 | } 230 | -------------------------------------------------------------------------------- /src/cpu/dcn_v2_im2col_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "dcn_v2_im2col_cpu.h" 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | //#include 8 | 9 | #include 10 | //#include 11 | //#include 12 | 13 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 14 | 15 | /*#define CUDA_KERNEL_LOOP(i, n) \ 16 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 17 | i < (n); \ 18 | i += blockDim.x * gridDim.x) 19 | 20 | const int CUDA_NUM_THREADS = 1024; 21 | inline int GET_BLOCKS(const int N) 22 | { 23 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 24 | }*/ 25 | 26 | 27 | float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width, 28 | const int height, const int width, float h, float w) 29 | { 30 | int h_low = floor(h); 31 | int w_low = floor(w); 32 | int h_high = h_low + 1; 33 | int w_high = w_low + 1; 34 | 35 | float lh = h - h_low; 36 | float lw = w - w_low; 37 | float hh = 1 - lh, hw = 1 - lw; 38 | 39 | float v1 = 0; 40 | if (h_low >= 0 && w_low >= 0) 41 | v1 = bottom_data[h_low * data_width + w_low]; 42 | float v2 = 0; 43 | if (h_low >= 0 && w_high <= width - 1) 44 | v2 = bottom_data[h_low * data_width + w_high]; 45 | float v3 = 0; 46 | if (h_high <= height - 1 && w_low >= 0) 47 | v3 = bottom_data[h_high * data_width + w_low]; 48 | float v4 = 0; 49 | if (h_high <= height - 1 && w_high <= width - 1) 50 | v4 = bottom_data[h_high * data_width + w_high]; 51 | 52 | float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 53 | 54 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 55 | return val; 56 | } 57 | 58 | float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w, 59 | const int h, const int w, const int height, const int width) 60 | { 61 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 62 | { 63 | //empty 64 | return 0; 65 | } 66 | 67 | int argmax_h_low = floor(argmax_h); 68 | int argmax_w_low = floor(argmax_w); 69 | int argmax_h_high = argmax_h_low + 1; 70 | int argmax_w_high = argmax_w_low + 1; 71 | 72 | float weight = 0; 73 | if (h == argmax_h_low && w == argmax_w_low) 74 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 75 | if (h == argmax_h_low && w == argmax_w_high) 76 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 77 | if (h == argmax_h_high && w == argmax_w_low) 78 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 79 | if (h == argmax_h_high && w == argmax_w_high) 80 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 81 | return weight; 82 | } 83 | 84 | float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w, 85 | const int height, const int width, const float *im_data, 86 | const int data_width, const int bp_dir) 87 | { 88 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 89 | { 90 | //empty 91 | return 0; 92 | } 93 | 94 | int argmax_h_low = floor(argmax_h); 95 | int argmax_w_low = floor(argmax_w); 96 | int argmax_h_high = argmax_h_low + 1; 97 | int argmax_w_high = argmax_w_low + 1; 98 | 99 | float weight = 0; 100 | 101 | if (bp_dir == 0) 102 | { 103 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 104 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; 105 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 106 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; 107 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 108 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; 109 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 110 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 111 | } 112 | else if (bp_dir == 1) 113 | { 114 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 115 | weight += -1 
* (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; 116 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 117 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; 118 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 119 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; 120 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 121 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 122 | } 123 | 124 | return weight; 125 | } 126 | 127 | void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask, 128 | const int height, const int width, const int kernel_h, const int kernel_w, 129 | const int pad_h, const int pad_w, 130 | const int stride_h, const int stride_w, 131 | const int dilation_h, const int dilation_w, 132 | const int channel_per_deformable_group, 133 | const int batch_size, const int num_channels, const int deformable_group, 134 | const int height_col, const int width_col, 135 | float *data_col) 136 | { 137 | // launch channels * batch_size * height_col * width_col cores 138 | for(int index=0; index(0); 178 | const float h_im = h_in + i * dilation_h + offset_h; 179 | const float w_im = w_in + j * dilation_w + offset_w; 180 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 181 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) 182 | { 183 | //const float map_h = i * dilation_h + offset_h; 184 | //const float map_w = j * dilation_w + offset_w; 185 | //const int cur_height = height - h_in; 186 | //const int cur_width = width - w_in; 187 | //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w); 188 | val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im); 189 | } 190 | *data_col_ptr = val * mask; 191 | // data_col_ptr += batch_size * height_col * width_col; 192 | data_col_ptr += height_col * width_col; 193 | } 194 | } 195 | } 196 | } 197 | 198 | void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask, 199 | const int channels, const int height, const int width, 200 | const int kernel_h, const int kernel_w, 201 | const int pad_h, const int pad_w, 202 | const int stride_h, const int stride_w, 203 | const int dilation_h, const int dilation_w, 204 | const int channel_per_deformable_group, 205 | const int batch_size, const int deformable_group, 206 | const int height_col, const int width_col, 207 | float *grad_im) 208 | { 209 | for(int index = 0; index < n; index++) 210 | { 211 | const int j = (index / width_col / height_col / batch_size) % kernel_w; 212 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; 213 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; 214 | // compute the start and end of the output 215 | 216 | const int deformable_group_index = c / channel_per_deformable_group; 217 | 218 | int w_out = index % width_col; 219 | int h_out = (index / width_col) % height_col; 220 | int b = (index / width_col / height_col) % batch_size; 221 | int w_in = w_out * stride_w - pad_w; 222 | int h_in = h_out * stride_h - pad_h; 223 | 224 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 225 | const float *data_mask_ptr = data_mask 
+ (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 226 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 227 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 228 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; 229 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 230 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 231 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 232 | const float cur_inv_h_data = h_in + i * dilation_h + offset_h; 233 | const float cur_inv_w_data = w_in + j * dilation_w + offset_w; 234 | 235 | const float cur_top_grad = data_col[index] * mask; 236 | const int cur_h = (int)cur_inv_h_data; 237 | const int cur_w = (int)cur_inv_w_data; 238 | 239 | for (int dy = -2; dy <= 2; dy++) 240 | { 241 | for (int dx = -2; dx <= 2; dx++) 242 | { 243 | if (cur_h + dy >= 0 && cur_h + dy < height && 244 | cur_w + dx >= 0 && cur_w + dx < width && 245 | abs(cur_inv_h_data - (cur_h + dy)) < 1 && 246 | abs(cur_inv_w_data - (cur_w + dx)) < 1) 247 | { 248 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; 249 | float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); 250 | //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); 251 | *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad; 252 | 253 | } 254 | } 255 | } 256 | } 257 | } 258 | 259 | void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im, 260 | const float *data_offset, const float *data_mask, 261 | const int channels, const int height, const int width, 262 | const int kernel_h, const int kernel_w, 263 | const int pad_h, const int pad_w, 264 | const int stride_h, const int stride_w, 265 | const int dilation_h, const int dilation_w, 266 | const int channel_per_deformable_group, 267 | const int batch_size, const int offset_channels, const int deformable_group, 268 | const int height_col, const int width_col, 269 | float *grad_offset, float *grad_mask) 270 | { 271 | for(int index = 0; index < n; index++) 272 | { 273 | float val = 0, mval = 0; 274 | int w = index % width_col; 275 | int h = (index / width_col) % height_col; 276 | int c = (index / width_col / height_col) % offset_channels; 277 | int b = (index / width_col / height_col) / offset_channels; 278 | // compute the start and end of the output 279 | 280 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 281 | const int col_step = kernel_h * kernel_w; 282 | int cnt = 0; 283 | const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; 284 | const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; 285 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 286 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 287 | 288 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 289 | 290 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) 291 | { 292 | const int col_pos = 
(((col_c * batch_size + b) * height_col) + h) * width_col + w; 293 | const int bp_dir = offset_c % 2; 294 | 295 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w; 296 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; 297 | int w_out = col_pos % width_col; 298 | int h_out = (col_pos / width_col) % height_col; 299 | int w_in = w_out * stride_w - pad_w; 300 | int h_in = h_out * stride_h - pad_h; 301 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 302 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); 303 | const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); 304 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 305 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 306 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 307 | float inv_h = h_in + i * dilation_h + offset_h; 308 | float inv_w = w_in + j * dilation_w + offset_w; 309 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) 310 | { 311 | inv_h = inv_w = -2; 312 | } 313 | else 314 | { 315 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); 316 | } 317 | const float weight = dmcn_get_coordinate_weight_cpu( 318 | inv_h, inv_w, 319 | height, width, data_im_ptr + cnt * height * width, width, bp_dir); 320 | val += weight * data_col_ptr[col_pos] * mask; 321 | cnt += 1; 322 | } 323 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val); 324 | grad_offset[index] = val; 325 | if (offset_c % 2 == 0) 326 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); 327 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; 328 | } 329 | } 330 | 331 | void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask, 332 | const int batch_size, const int channels, const int height_im, const int width_im, 333 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 334 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 335 | const int dilation_h, const int dilation_w, 336 | const int deformable_group, float* data_col) { 337 | // num_axes should be smaller than block size 338 | const int channel_per_deformable_group = channels / deformable_group; 339 | const int num_kernels = channels * batch_size * height_col * width_col; 340 | modulated_deformable_im2col_cpu_kernel( 341 | num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, 342 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, 343 | batch_size, channels, deformable_group, height_col, width_col, data_col); 344 | 345 | /*cudaError_t err = cudaGetLastError(); 346 | if (err != cudaSuccess) 347 | { 348 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); 349 | }*/ 350 | 351 | } 352 | 353 | void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask, 354 | const int batch_size, const int channels, const int height_im, const int width_im, 355 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 356 | const int pad_h, 
const int pad_w, const int stride_h, const int stride_w, 357 | const int dilation_h, const int dilation_w, 358 | const int deformable_group, float* grad_im){ 359 | 360 | const int channel_per_deformable_group = channels / deformable_group; 361 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; 362 | modulated_deformable_col2im_cpu_kernel( 363 | num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, 364 | kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, 365 | dilation_h, dilation_w, channel_per_deformable_group, 366 | batch_size, deformable_group, height_col, width_col, grad_im); 367 | /*cudaError_t err = cudaGetLastError(); 368 | if (err != cudaSuccess) 369 | { 370 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); 371 | }*/ 372 | 373 | } 374 | 375 | void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, 376 | const int batch_size, const int channels, const int height_im, const int width_im, 377 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 378 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 379 | const int dilation_h, const int dilation_w, 380 | const int deformable_group, 381 | float* grad_offset, float* grad_mask) { 382 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; 383 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; 384 | modulated_deformable_col2im_coord_cpu_kernel( 385 | num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, 386 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 387 | dilation_h, dilation_w, channel_per_deformable_group, 388 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, 389 | grad_offset, grad_mask); 390 | /*cudaError_t err = cudaGetLastError(); 391 | if (err != cudaSuccess) 392 | { 393 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); 394 | }*/ 395 | } -------------------------------------------------------------------------------- /src/cpu/dcn_v2_im2col_cpu.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. 
Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 64 | 65 | #ifndef DCN_V2_IM2COL_CPU 66 | #define DCN_V2_IM2COL_CPU 67 | 68 | #ifdef __cplusplus 69 | extern "C" 70 | { 71 | #endif 72 | 73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 88 | const int batch_size, const int channels, const int height_im, const int width_im, 89 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 91 | const int dilation_h, const int dilation_w, 92 | const int deformable_group, 93 | float *grad_offset, float *grad_mask); 94 | 95 | #ifdef __cplusplus 96 | } 97 | #endif 98 | 99 | #endif -------------------------------------------------------------------------------- /src/cpu/dcn_v2_psroi_pooling_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | //#include 17 | 18 | #include 19 | //#include 20 | //#include 21 | 22 | /*#define CUDA_KERNEL_LOOP(i, n) \ 23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 24 | i < (n); \ 25 | i += blockDim.x * gridDim.x) 26 | 27 | const int CUDA_NUM_THREADS = 1024; 28 | inline int GET_BLOCKS(const int N) 29 | { 30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 31 | }*/ 32 | 33 | template 34 | T bilinear_interp_cpu( 35 | const T *data, 36 | const T x, 37 | const T y, 38 | const int width, 39 | const int height) 40 | { 41 | int x1 = floor(x); 42 | int x2 = ceil(x); 43 | int y1 = floor(y); 44 | int y2 = ceil(y); 45 | T dist_x = static_cast(x - x1); 46 | T dist_y = static_cast(y - y1); 47 | T value11 = data[y1 * width + x1]; 48 | T value12 = data[y2 * width + x1]; 49 | T value21 = data[y1 * width + x2]; 50 | T value22 = data[y2 * width + x2]; 51 | T value = (1 - dist_x) * (1 - dist_y) * value11 + 52 | (1 - dist_x) * dist_y * value12 + 53 | dist_x * (1 - dist_y) * value21 + 54 | dist_x * dist_y * value22; 55 | return value; 56 | } 57 | 58 | template 59 | void DeformablePSROIPoolForwardKernelCpu( 60 | const int count, 61 | const T *bottom_data, 62 | const T spatial_scale, 63 | const int channels, 64 | const int height, const int width, 65 | const int pooled_height, const int pooled_width, 66 | const T *bottom_rois, const T *bottom_trans, 67 | const int no_trans, 68 | const T trans_std, 69 | const int sample_per_part, 70 | const int output_dim, 71 | const int group_size, 72 | const int part_size, 73 | const int num_classes, 74 | const int channels_each_class, 75 | T *top_data, 76 | T *top_count) 77 | { 78 | for(int index = 0; index < count; index++) 79 | { 80 | // The output is in order (n, ctop, ph, pw) 81 | int pw = index % pooled_width; 82 | int ph = (index / pooled_width) % pooled_height; 83 | int ctop = (index / pooled_width / pooled_height) % output_dim; 84 | int n = index / pooled_width / pooled_height / output_dim; 85 | 86 | // [start, end) interval for spatial sampling 87 | const T *offset_bottom_rois = bottom_rois + n * 5; 88 | int roi_batch_ind = offset_bottom_rois[0]; 89 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 90 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 91 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 92 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; 93 | 94 | // Force too small ROIs to be 1x1 95 | T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 96 | T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); 97 | 98 | // Compute w and h at bottom 99 | T bin_size_h = roi_height / static_cast(pooled_height); 100 | T bin_size_w = roi_width / static_cast(pooled_width); 101 | 102 | T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); 103 | T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); 104 | 105 | int part_h = floor(static_cast(ph) / pooled_height * part_size); 106 | int part_w = floor(static_cast(pw) / pooled_width * part_size); 107 | int class_id = ctop / channels_each_class; 108 | T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 109 | T trans_y = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 110 | 111 | T wstart = static_cast(pw) * bin_size_w + roi_start_w; 112 | wstart += trans_x * roi_width; 113 | T hstart = static_cast(ph) * bin_size_h + roi_start_h; 114 | hstart += trans_y * roi_height; 115 | 116 | T sum = 0; 117 | int count = 0; 118 | int gw = floor(static_cast(pw) * group_size / pooled_width); 119 | int gh = floor(static_cast(ph) * group_size / pooled_height); 120 | gw = std::min(std::max(gw, 0), group_size - 1); 121 | gh = std::min(std::max(gh, 0), group_size - 1); 122 | 123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; 124 | for (int ih = 0; ih < sample_per_part; ih++) 125 | { 126 | for (int iw = 0; iw < sample_per_part; iw++) 127 | { 128 | T w = wstart + iw * sub_bin_size_w; 129 | T h = hstart + ih * sub_bin_size_h; 130 | // bilinear interpolation 131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 132 | { 133 | continue; 134 | } 135 | w = std::min(std::max(w, T(0.)), width - T(1.)); 136 | h = std::min(std::max(h, T(0.)), height - T(1.)); 137 | int c = (ctop * group_size + gh) * group_size + gw; 138 | T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height); 139 | sum += val; 140 | count++; 141 | } 142 | } 143 | top_data[index] = count == 0 ? static_cast(0) : sum / count; 144 | top_count[index] = count; 145 | } 146 | } 147 | 148 | template 149 | void DeformablePSROIPoolBackwardAccKernelCpu( 150 | const int count, 151 | const T *top_diff, 152 | const T *top_count, 153 | const int num_rois, 154 | const T spatial_scale, 155 | const int channels, 156 | const int height, const int width, 157 | const int pooled_height, const int pooled_width, 158 | const int output_dim, 159 | T *bottom_data_diff, T *bottom_trans_diff, 160 | const T *bottom_data, 161 | const T *bottom_rois, 162 | const T *bottom_trans, 163 | const int no_trans, 164 | const T trans_std, 165 | const int sample_per_part, 166 | const int group_size, 167 | const int part_size, 168 | const int num_classes, 169 | const int channels_each_class) 170 | { 171 | for(int index = 0; index < count; index++) 172 | { 173 | // The output is in order (n, ctop, ph, pw) 174 | int pw = index % pooled_width; 175 | int ph = (index / pooled_width) % pooled_height; 176 | int ctop = (index / pooled_width / pooled_height) % output_dim; 177 | int n = index / pooled_width / pooled_height / output_dim; 178 | 179 | // [start, end) interval for spatial sampling 180 | const T *offset_bottom_rois = bottom_rois + n * 5; 181 | int roi_batch_ind = offset_bottom_rois[0]; 182 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 183 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 184 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 185 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; 186 | 187 | // Force too small ROIs to be 1x1 188 | T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 189 | T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); 190 | 191 | // Compute w and h at bottom 192 | T bin_size_h = roi_height / static_cast(pooled_height); 193 | T bin_size_w = roi_width / static_cast(pooled_width); 194 | 195 | T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); 196 | T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); 197 | 198 | int part_h = floor(static_cast(ph) / pooled_height * part_size); 199 | int part_w = floor(static_cast(pw) / pooled_width * part_size); 200 | int class_id = ctop / channels_each_class; 201 | T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 202 | T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 203 | 204 | T wstart = static_cast(pw) * bin_size_w + roi_start_w; 205 | wstart += trans_x * roi_width; 206 | T hstart = static_cast(ph) * bin_size_h + roi_start_h; 207 | hstart += trans_y * roi_height; 208 | 209 | if (top_count[index] <= 0) 210 | { 211 | continue; 212 | } 213 | T diff_val = top_diff[index] / top_count[index]; 214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; 215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; 216 | int gw = floor(static_cast(pw) * group_size / pooled_width); 217 | int gh = floor(static_cast(ph) * group_size / pooled_height); 218 | gw = std::min(std::max(gw, 0), group_size - 1); 219 | gh = std::min(std::max(gh, 0), group_size - 1); 220 | 221 | for (int ih = 0; ih < sample_per_part; ih++) 222 | { 223 | for (int iw = 0; iw < sample_per_part; iw++) 224 | { 225 | T w = wstart + iw * sub_bin_size_w; 226 | T h = hstart + ih * sub_bin_size_h; 227 | // bilinear interpolation 228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 229 | { 230 | continue; 231 | } 232 | w = std::min(std::max(w, T(0.)), width - T(1.)); 233 | h = std::min(std::max(h, T(0.)), height - T(1.)); 234 | int c = (ctop * group_size + gh) * group_size + gw; 235 | // backward on feature 236 | int x0 = floor(w); 237 | int x1 = ceil(w); 238 | int y0 = floor(h); 239 | int y1 = ceil(h); 240 | T dist_x = w - x0, dist_y = h - y0; 241 | T q00 = (1 - dist_x) * (1 - dist_y); 242 | T q01 = (1 - dist_x) * dist_y; 243 | T q10 = dist_x * (1 - dist_y); 244 | T q11 = dist_x * dist_y; 245 | int bottom_index_base = c * height * width; 246 | /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); 247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); 248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); 249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/ 250 | *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val; 251 | *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val; 252 | *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val; 253 | *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val; 254 | 255 | 256 | if (no_trans) 257 | { 258 | continue; 259 | } 260 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; 261 | T U01 = 
offset_bottom_data[bottom_index_base + y1 * width + x0]; 262 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; 263 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; 264 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; 265 | diff_x *= roi_width; 266 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; 267 | diff_y *= roi_height; 268 | 269 | /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); 270 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/ 271 | *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x; 272 | *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y; 273 | } 274 | } 275 | } 276 | } 277 | 278 | std::tuple 279 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 280 | const at::Tensor &bbox, 281 | const at::Tensor &trans, 282 | const int no_trans, 283 | const float spatial_scale, 284 | const int output_dim, 285 | const int group_size, 286 | const int pooled_size, 287 | const int part_size, 288 | const int sample_per_part, 289 | const float trans_std) 290 | { 291 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 292 | AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor"); 293 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");*/ 294 | 295 | // const int batch = input.size(0); 296 | const int channels = input.size(1); 297 | const int height = input.size(2); 298 | const int width = input.size(3); 299 | const int channels_trans = no_trans ? 2 : trans.size(1); 300 | const int num_bbox = bbox.size(0); 301 | 302 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 303 | auto pooled_height = pooled_size; 304 | auto pooled_width = pooled_size; 305 | 306 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 307 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 308 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 309 | 310 | const int num_classes = no_trans ? 1 : channels_trans / 2; 311 | const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; 312 | 313 | //cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 314 | 315 | if (out.numel() == 0) 316 | { 317 | //THCudaCheck(cudaGetLastError()); 318 | return std::make_tuple(out, top_count); 319 | } 320 | 321 | /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 322 | dim3 block(512);*/ 323 | 324 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cpu_forward", [&] { 325 | DeformablePSROIPoolForwardKernelCpu( 326 | out_size, 327 | input.contiguous().data_ptr(), 328 | spatial_scale, 329 | channels, 330 | height, width, 331 | pooled_height, 332 | pooled_width, 333 | bbox.contiguous().data_ptr(), 334 | trans.contiguous().data_ptr(), 335 | no_trans, 336 | trans_std, 337 | sample_per_part, 338 | output_dim, 339 | group_size, 340 | part_size, 341 | num_classes, 342 | channels_each_class, 343 | out.data_ptr(), 344 | top_count.data_ptr()); 345 | }); 346 | //THCudaCheck(cudaGetLastError()); 347 | return std::make_tuple(out, top_count); 348 | } 349 | 350 | std::tuple 351 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 352 | const at::Tensor &input, 353 | const at::Tensor &bbox, 354 | const at::Tensor &trans, 355 | const at::Tensor &top_count, 356 | const int no_trans, 357 | const float spatial_scale, 358 | const int output_dim, 359 | const int group_size, 360 | const int pooled_size, 361 | const int part_size, 362 | const int sample_per_part, 363 | const float trans_std) 364 | { 365 | /*AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor"); 366 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 367 | AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor"); 368 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor"); 369 | AT_ASSERTM(top_count.is_cuda(), "top_count must be a CUDA tensor");*/ 370 | 371 | const int batch = input.size(0); 372 | const int channels = input.size(1); 373 | const int height = input.size(2); 374 | const int width = input.size(3); 375 | const int channels_trans = no_trans ? 2 : trans.size(1); 376 | const int num_bbox = bbox.size(0); 377 | 378 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 379 | auto pooled_height = pooled_size; 380 | auto pooled_width = pooled_size; 381 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 382 | const int num_classes = no_trans ? 1 : channels_trans / 2; 383 | const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; 384 | 385 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); 386 | auto trans_grad = at::zeros_like(trans); 387 | 388 | if (input_grad.numel() == 0) 389 | { 390 | //THCudaCheck(cudaGetLastError()); 391 | return std::make_tuple(input_grad, trans_grad); 392 | } 393 | 394 | /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 395 | dim3 block(512); 396 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/ 397 | 398 | AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cpu_backward", [&] { 399 | DeformablePSROIPoolBackwardAccKernelCpu( 400 | out_size, 401 | out_grad.contiguous().data_ptr(), 402 | top_count.contiguous().data_ptr(), 403 | num_bbox, 404 | spatial_scale, 405 | channels, 406 | height, 407 | width, 408 | pooled_height, 409 | pooled_width, 410 | output_dim, 411 | input_grad.contiguous().data_ptr(), 412 | trans_grad.contiguous().data_ptr(), 413 | input.contiguous().data_ptr(), 414 | bbox.contiguous().data_ptr(), 415 | trans.contiguous().data_ptr(), 416 | no_trans, 417 | trans_std, 418 | sample_per_part, 419 | group_size, 420 | part_size, 421 | num_classes, 422 | channels_each_class); 423 | }); 424 | //THCudaCheck(cudaGetLastError()); 425 | return std::make_tuple(input_grad, trans_grad); 426 | } -------------------------------------------------------------------------------- /src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /src/cuda/dcn_v2_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/dcn_v2_im2col_cuda.h" 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | THCState *state = at::globalContext().lazyInitCUDA(); 12 | 13 | // author: Charles 
Shang 14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 15 | 16 | // [batch gemm] 17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu 18 | 19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b, 20 | float **columns_b, const float **ones_b, 21 | const float **weight_b, const float **bias_b, 22 | float *input, float *output, 23 | float *columns, float *ones, 24 | float *weight, float *bias, 25 | const int input_stride, const int output_stride, 26 | const int columns_stride, const int ones_stride, 27 | const int num_batches) 28 | { 29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 30 | if (idx < num_batches) 31 | { 32 | input_b[idx] = input + idx * input_stride; 33 | output_b[idx] = output + idx * output_stride; 34 | columns_b[idx] = columns + idx * columns_stride; 35 | ones_b[idx] = ones + idx * ones_stride; 36 | // share weights and bias within a Mini-Batch 37 | weight_b[idx] = weight; 38 | bias_b[idx] = bias; 39 | } 40 | } 41 | 42 | at::Tensor 43 | dcn_v2_cuda_forward(const at::Tensor &input, 44 | const at::Tensor &weight, 45 | const at::Tensor &bias, 46 | const at::Tensor &offset, 47 | const at::Tensor &mask, 48 | const int kernel_h, 49 | const int kernel_w, 50 | const int stride_h, 51 | const int stride_w, 52 | const int pad_h, 53 | const int pad_w, 54 | const int dilation_h, 55 | const int dilation_w, 56 | const int deformable_group) 57 | { 58 | using scalar_t = float; 59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); 60 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 61 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); 62 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); 63 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); 64 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor"); 65 | 66 | const int batch = input.size(0); 67 | const int channels = input.size(1); 68 | const int height = input.size(2); 69 | const int width = input.size(3); 70 | 71 | const int channels_out = weight.size(0); 72 | const int channels_kernel = weight.size(1); 73 | const int kernel_h_ = weight.size(2); 74 | const int kernel_w_ = weight.size(3); 75 | 76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); 77 | // printf("Channels: %d %d\n", channels, channels_kernel); 78 | // printf("Channels: %d %d\n", channels_out, channels_kernel); 79 | 80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 82 | 83 | AT_ASSERTM(channels == channels_kernel, 84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 85 | 86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 88 | 89 | auto ones = at::ones({batch, height_out, width_out}, input.options()); 90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 91 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 92 | 93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing 94 | // when batch size is large. 
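// Note on the batched-GEMM setup below (descriptive comment; identifiers are the ones defined in this file):
// the six THCudaMalloc'd arrays hold one device pointer per sample. createBatchGemmBuffer fills
// input_b[i] / output_b[i] / columns_b[i] / ones_b[i] with the address of sample i inside the
// corresponding contiguous tensor, while weight_b[i] and bias_b[i] all alias the shared weight and
// bias. The two THCudaBlas_SgemmBatched calls that follow can then apply the bias (ones x bias) and
// the im2col'd convolution (weight x columns) for the whole mini-batch in a single batched GEMM each.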
95 | // launch batch threads 96 | int matrices_size = batch * sizeof(float *); 97 | auto input_b = static_cast(THCudaMalloc(state, matrices_size)); 98 | auto output_b = static_cast(THCudaMalloc(state, matrices_size)); 99 | auto columns_b = static_cast(THCudaMalloc(state, matrices_size)); 100 | auto ones_b = static_cast(THCudaMalloc(state, matrices_size)); 101 | auto weight_b = static_cast(THCudaMalloc(state, matrices_size)); 102 | auto bias_b = static_cast(THCudaMalloc(state, matrices_size)); 103 | 104 | const int block = 128; 105 | const int grid = (batch + block - 1) / block; 106 | 107 | createBatchGemmBuffer<<>>( 108 | input_b, output_b, 109 | columns_b, ones_b, 110 | weight_b, bias_b, 111 | input.data_ptr(), 112 | output.data_ptr(), 113 | columns.data_ptr(), 114 | ones.data_ptr(), 115 | weight.data_ptr(), 116 | bias.data_ptr(), 117 | channels * width * height, 118 | channels_out * width_out * height_out, 119 | channels * kernel_h * kernel_w * height_out * width_out, 120 | height_out * width_out, 121 | batch); 122 | 123 | long m_ = channels_out; 124 | long n_ = height_out * width_out; 125 | long k_ = 1; 126 | THCudaBlas_SgemmBatched(state, 127 | 't', 128 | 'n', 129 | n_, 130 | m_, 131 | k_, 132 | 1.0f, 133 | ones_b, k_, 134 | bias_b, k_, 135 | 0.0f, 136 | output_b, n_, 137 | batch); 138 | 139 | modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), 140 | input.data_ptr(), 141 | offset.data_ptr(), 142 | mask.data_ptr(), 143 | batch, channels, height, width, 144 | height_out, width_out, kernel_h, kernel_w, 145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 146 | deformable_group, 147 | columns.data_ptr()); 148 | 149 | long m = channels_out; 150 | long n = height_out * width_out; 151 | long k = channels * kernel_h * kernel_w; 152 | THCudaBlas_SgemmBatched(state, 153 | 'n', 154 | 'n', 155 | n, 156 | m, 157 | k, 158 | 1.0f, 159 | (const float **)columns_b, n, 160 | weight_b, k, 161 | 1.0f, 162 | output_b, n, 163 | batch); 164 | 165 | THCudaFree(state, input_b); 166 | THCudaFree(state, output_b); 167 | THCudaFree(state, columns_b); 168 | THCudaFree(state, ones_b); 169 | THCudaFree(state, weight_b); 170 | THCudaFree(state, bias_b); 171 | return output; 172 | } 173 | 174 | __global__ void createBatchGemmBufferBackward( 175 | float **grad_output_b, 176 | float **columns_b, 177 | float **ones_b, 178 | float **weight_b, 179 | float **grad_weight_b, 180 | float **grad_bias_b, 181 | float *grad_output, 182 | float *columns, 183 | float *ones, 184 | float *weight, 185 | float *grad_weight, 186 | float *grad_bias, 187 | const int grad_output_stride, 188 | const int columns_stride, 189 | const int ones_stride, 190 | const int num_batches) 191 | { 192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 193 | if (idx < num_batches) 194 | { 195 | grad_output_b[idx] = grad_output + idx * grad_output_stride; 196 | columns_b[idx] = columns + idx * columns_stride; 197 | ones_b[idx] = ones + idx * ones_stride; 198 | 199 | // share weights and bias within a Mini-Batch 200 | weight_b[idx] = weight; 201 | grad_weight_b[idx] = grad_weight; 202 | grad_bias_b[idx] = grad_bias; 203 | } 204 | } 205 | 206 | std::vector dcn_v2_cuda_backward(const at::Tensor &input, 207 | const at::Tensor &weight, 208 | const at::Tensor &bias, 209 | const at::Tensor &offset, 210 | const at::Tensor &mask, 211 | const at::Tensor &grad_output, 212 | int kernel_h, int kernel_w, 213 | int stride_h, int stride_w, 214 | int pad_h, int pad_w, 215 | int dilation_h, int dilation_w, 216 | int deformable_group) 217 | { 
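// Backward outline (descriptive comment): for each sample b in the batch,
//   (1) an Sgemm maps grad_output_n back to column space through the weights
//       (effectively columns = weight^T * grad_output_n),
//   (2) modulated_deformable_col2im_coord_cuda turns those column gradients into
//       grad_offset_n and grad_mask_n,
//   (3) modulated_deformable_col2im_cuda scatters them onto grad_input_n,
//   (4) modulated_deformable_im2col_cuda recomputes the column buffer from the inputs,
//       and two further Sgemm calls accumulate grad_weight and grad_bias (beta = 1.0,
//       so the accumulation runs across the whole batch).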
218 | 219 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); 220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); 221 | 222 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 223 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); 224 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); 225 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); 226 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor"); 227 | 228 | const int batch = input.size(0); 229 | const int channels = input.size(1); 230 | const int height = input.size(2); 231 | const int width = input.size(3); 232 | 233 | const int channels_out = weight.size(0); 234 | const int channels_kernel = weight.size(1); 235 | const int kernel_h_ = weight.size(2); 236 | const int kernel_w_ = weight.size(3); 237 | 238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 239 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 240 | 241 | AT_ASSERTM(channels == channels_kernel, 242 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 243 | 244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 246 | 247 | auto ones = at::ones({height_out, width_out}, input.options()); 248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 250 | 251 | auto grad_input = at::zeros_like(input); 252 | auto grad_weight = at::zeros_like(weight); 253 | auto grad_bias = at::zeros_like(bias); 254 | auto grad_offset = at::zeros_like(offset); 255 | auto grad_mask = at::zeros_like(mask); 256 | 257 | using scalar_t = float; 258 | 259 | for (int b = 0; b < batch; b++) 260 | { 261 | auto input_n = input.select(0, b); 262 | auto offset_n = offset.select(0, b); 263 | auto mask_n = mask.select(0, b); 264 | auto grad_output_n = grad_output.select(0, b); 265 | auto grad_input_n = grad_input.select(0, b); 266 | auto grad_offset_n = grad_offset.select(0, b); 267 | auto grad_mask_n = grad_mask.select(0, b); 268 | 269 | long m = channels * kernel_h * kernel_w; 270 | long n = height_out * width_out; 271 | long k = channels_out; 272 | 273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 274 | grad_output_n.data_ptr(), n, 275 | weight.data_ptr(), m, 0.0f, 276 | columns.data_ptr(), n); 277 | 278 | // gradient w.r.t. input coordinate data 279 | modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(), 280 | columns.data_ptr(), 281 | input_n.data_ptr(), 282 | offset_n.data_ptr(), 283 | mask_n.data_ptr(), 284 | 1, channels, height, width, 285 | height_out, width_out, kernel_h, kernel_w, 286 | pad_h, pad_w, stride_h, stride_w, 287 | dilation_h, dilation_w, deformable_group, 288 | grad_offset_n.data_ptr(), 289 | grad_mask_n.data_ptr()); 290 | // gradient w.r.t. input data 291 | modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(), 292 | columns.data_ptr(), 293 | offset_n.data_ptr(), 294 | mask_n.data_ptr(), 295 | 1, channels, height, width, 296 | height_out, width_out, kernel_h, kernel_w, 297 | pad_h, pad_w, stride_h, stride_w, 298 | dilation_h, dilation_w, deformable_group, 299 | grad_input_n.data_ptr()); 300 | 301 | // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group 302 | modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), 303 | input_n.data_ptr(), 304 | offset_n.data_ptr(), 305 | mask_n.data_ptr(), 306 | 1, channels, height, width, 307 | height_out, width_out, kernel_h, kernel_w, 308 | pad_h, pad_w, stride_h, stride_w, 309 | dilation_h, dilation_w, deformable_group, 310 | columns.data_ptr()); 311 | 312 | long m_ = channels_out; 313 | long n_ = channels * kernel_h * kernel_w; 314 | long k_ = height_out * width_out; 315 | 316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 317 | columns.data_ptr(), k_, 318 | grad_output_n.data_ptr(), k_, 1.0f, 319 | grad_weight.data_ptr(), n_); 320 | 321 | // gradient w.r.t. bias 322 | // long m_ = channels_out; 323 | // long k__ = height_out * width_out; 324 | // THCudaBlas_Sgemm(state, 325 | // 't', 'n', 326 | // k_, m_, 1, 1.0f, 327 | // grad_output_n.data_ptr(), k_, 328 | // ones.data_ptr(), 1, 1.0f, 329 | // grad_bias.data_ptr(), 1); 330 | THCudaBlas_Sgemm(state, 331 | 'N', 'N', 1, m_, k_, 1.0f, 332 | ones.data_ptr(), 1, 333 | grad_output_n.data_ptr(), k_, 334 | 1.0f, 335 | grad_bias.data_ptr(), 1); 336 | } 337 | 338 | return { 339 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias 340 | }; 341 | } 342 | -------------------------------------------------------------------------------- /src/cuda/dcn_v2_im2col_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "dcn_v2_im2col_cuda.h" 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #define CUDA_KERNEL_LOOP(i, n) \ 14 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 15 | i < (n); \ 16 | i += blockDim.x * gridDim.x) 17 | 18 | const int CUDA_NUM_THREADS = 1024; 19 | inline int GET_BLOCKS(const int N) 20 | { 21 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 22 | } 23 | 24 | 25 | __device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width, 26 | const int height, const int width, float h, float w) 27 | { 28 | int h_low = floor(h); 29 | int w_low = floor(w); 30 | int h_high = h_low + 1; 31 | int w_high = w_low + 1; 32 | 33 | float lh = h - h_low; 34 | float lw = w - w_low; 35 | float hh = 1 - lh, hw = 1 - lw; 36 | 37 | float v1 = 0; 38 | if (h_low >= 0 && w_low >= 0) 39 | v1 = bottom_data[h_low * data_width + w_low]; 40 | float v2 = 0; 41 | if (h_low >= 0 && w_high <= width - 1) 42 | v2 = bottom_data[h_low * data_width + w_high]; 43 | float v3 = 0; 44 | if (h_high <= height - 1 && w_low >= 0) 45 | v3 = bottom_data[h_high * data_width + w_low]; 46 | float v4 = 0; 47 | if (h_high <= height - 1 && w_high <= width - 1) 48 | v4 = bottom_data[h_high * data_width + w_high]; 49 | 50 | float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 51 | 52 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 53 | return val; 54 | } 55 | 56 | __device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w, 57 | const int h, const int w, const int height, const int width) 58 | { 59 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 60 | { 61 | //empty 62 | return 0; 63 | } 64 | 65 | int argmax_h_low = floor(argmax_h); 66 | int argmax_w_low = floor(argmax_w); 67 | int argmax_h_high = argmax_h_low + 1; 68 | int argmax_w_high = argmax_w_low + 1; 69 | 70 | float weight = 0; 71 | if (h == argmax_h_low && w == argmax_w_low) 72 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 73 | if 
(h == argmax_h_low && w == argmax_w_high) 74 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 75 | if (h == argmax_h_high && w == argmax_w_low) 76 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 77 | if (h == argmax_h_high && w == argmax_w_high) 78 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 79 | return weight; 80 | } 81 | 82 | __device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w, 83 | const int height, const int width, const float *im_data, 84 | const int data_width, const int bp_dir) 85 | { 86 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 87 | { 88 | //empty 89 | return 0; 90 | } 91 | 92 | int argmax_h_low = floor(argmax_h); 93 | int argmax_w_low = floor(argmax_w); 94 | int argmax_h_high = argmax_h_low + 1; 95 | int argmax_w_high = argmax_w_low + 1; 96 | 97 | float weight = 0; 98 | 99 | if (bp_dir == 0) 100 | { 101 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 102 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; 103 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 104 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; 105 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 106 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; 107 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 108 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 109 | } 110 | else if (bp_dir == 1) 111 | { 112 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 113 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; 114 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 115 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; 116 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 117 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; 118 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 119 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 120 | } 121 | 122 | return weight; 123 | } 124 | 125 | __global__ void modulated_deformable_im2col_gpu_kernel(const int n, 126 | const float *data_im, const float *data_offset, const float *data_mask, 127 | const int height, const int width, const int kernel_h, const int kernel_w, 128 | const int pad_h, const int pad_w, 129 | const int stride_h, const int stride_w, 130 | const int dilation_h, const int dilation_w, 131 | const int channel_per_deformable_group, 132 | const int batch_size, const int num_channels, const int deformable_group, 133 | const int height_col, const int width_col, 134 | float *data_col) 135 | { 136 | // launch channels * batch_size * height_col * width_col cores 137 | CUDA_KERNEL_LOOP(index, n) 138 | { 139 | // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow) 140 | // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis 141 | 142 | // index index of output matrix 143 | const int w_col = index % width_col; 144 | const int h_col = (index / width_col) % height_col; 145 | // const int b_col = (index / width_col / height_col) % batch_size; 146 | const int b_col = (index / width_col / height_col / num_channels) % batch_size; 147 | // const int c_im = (index / width_col / height_col) / batch_size; 148 | const int 
c_im = (index / width_col / height_col) % num_channels; 149 | // const int c_col = c_im * kernel_h * kernel_w; 150 | const int c_col = c_im * kernel_h * kernel_w; 151 | 152 | // compute deformable group index 153 | const int deformable_group_index = c_im / channel_per_deformable_group; 154 | 155 | const int h_in = h_col * stride_h - pad_h; 156 | const int w_in = w_col * stride_w - pad_w; 157 | 158 | // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; 159 | float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col; 160 | //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; 161 | const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; 162 | const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 163 | 164 | const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 165 | 166 | for (int i = 0; i < kernel_h; ++i) 167 | { 168 | for (int j = 0; j < kernel_w; ++j) 169 | { 170 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; 171 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; 172 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; 173 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 174 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 175 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 176 | float val = static_cast(0); 177 | const float h_im = h_in + i * dilation_h + offset_h; 178 | const float w_im = w_in + j * dilation_w + offset_w; 179 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 180 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) 181 | { 182 | //const float map_h = i * dilation_h + offset_h; 183 | //const float map_w = j * dilation_w + offset_w; 184 | //const int cur_height = height - h_in; 185 | //const int cur_width = width - w_in; 186 | //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w); 187 | val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im); 188 | } 189 | *data_col_ptr = val * mask; 190 | // data_col_ptr += batch_size * height_col * width_col; 191 | data_col_ptr += height_col * width_col; 192 | } 193 | } 194 | } 195 | } 196 | 197 | __global__ void modulated_deformable_col2im_gpu_kernel(const int n, 198 | const float *data_col, const float *data_offset, const float *data_mask, 199 | const int channels, const int height, const int width, 200 | const int kernel_h, const int kernel_w, 201 | const int pad_h, const int pad_w, 202 | const int stride_h, const int stride_w, 203 | const int dilation_h, const int dilation_w, 204 | const int channel_per_deformable_group, 205 | const int batch_size, const int deformable_group, 206 | const int height_col, const int width_col, 207 | float *grad_im) 208 | { 209 | CUDA_KERNEL_LOOP(index, n) 210 | { 211 | const int j = (index / width_col / height_col / batch_size) % kernel_w; 212 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; 213 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; 214 | // compute the start and end of the output 215 
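// Each thread owns one element of the column buffer: one kernel tap (i, j) of one input
// channel c at one output location of one sample. It recomputes the fractional sampling
// position from the stored offset and, below, scatters mask * data_col[index] into grad_im
// with atomicAdd, weighting the integer neighbours of that position by their bilinear
// coefficients (the dy/dx window simply enumerates candidates within distance 1 of the
// sampling point).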
| 216 | const int deformable_group_index = c / channel_per_deformable_group; 217 | 218 | int w_out = index % width_col; 219 | int h_out = (index / width_col) % height_col; 220 | int b = (index / width_col / height_col) % batch_size; 221 | int w_in = w_out * stride_w - pad_w; 222 | int h_in = h_out * stride_h - pad_h; 223 | 224 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 225 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 226 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 227 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 228 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; 229 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 230 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 231 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 232 | const float cur_inv_h_data = h_in + i * dilation_h + offset_h; 233 | const float cur_inv_w_data = w_in + j * dilation_w + offset_w; 234 | 235 | const float cur_top_grad = data_col[index] * mask; 236 | const int cur_h = (int)cur_inv_h_data; 237 | const int cur_w = (int)cur_inv_w_data; 238 | for (int dy = -2; dy <= 2; dy++) 239 | { 240 | for (int dx = -2; dx <= 2; dx++) 241 | { 242 | if (cur_h + dy >= 0 && cur_h + dy < height && 243 | cur_w + dx >= 0 && cur_w + dx < width && 244 | abs(cur_inv_h_data - (cur_h + dy)) < 1 && 245 | abs(cur_inv_w_data - (cur_w + dx)) < 1) 246 | { 247 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; 248 | float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); 249 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); 250 | } 251 | } 252 | } 253 | } 254 | } 255 | 256 | __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, 257 | const float *data_col, const float *data_im, 258 | const float *data_offset, const float *data_mask, 259 | const int channels, const int height, const int width, 260 | const int kernel_h, const int kernel_w, 261 | const int pad_h, const int pad_w, 262 | const int stride_h, const int stride_w, 263 | const int dilation_h, const int dilation_w, 264 | const int channel_per_deformable_group, 265 | const int batch_size, const int offset_channels, const int deformable_group, 266 | const int height_col, const int width_col, 267 | float *grad_offset, float *grad_mask) 268 | { 269 | CUDA_KERNEL_LOOP(index, n) 270 | { 271 | float val = 0, mval = 0; 272 | int w = index % width_col; 273 | int h = (index / width_col) % height_col; 274 | int c = (index / width_col / height_col) % offset_channels; 275 | int b = (index / width_col / height_col) / offset_channels; 276 | // compute the start and end of the output 277 | 278 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 279 | const int col_step = kernel_h * kernel_w; 280 | int cnt = 0; 281 | const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; 282 | const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; 283 | const float *data_offset_ptr = data_offset + (b * deformable_group 
+ deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 284 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 285 | 286 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 287 | 288 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) 289 | { 290 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; 291 | const int bp_dir = offset_c % 2; 292 | 293 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w; 294 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; 295 | int w_out = col_pos % width_col; 296 | int h_out = (col_pos / width_col) % height_col; 297 | int w_in = w_out * stride_w - pad_w; 298 | int h_in = h_out * stride_h - pad_h; 299 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 300 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); 301 | const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); 302 | const float offset_h = data_offset_ptr[data_offset_h_ptr]; 303 | const float offset_w = data_offset_ptr[data_offset_w_ptr]; 304 | const float mask = data_mask_ptr[data_mask_hw_ptr]; 305 | float inv_h = h_in + i * dilation_h + offset_h; 306 | float inv_w = w_in + j * dilation_w + offset_w; 307 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) 308 | { 309 | inv_h = inv_w = -2; 310 | } 311 | else 312 | { 313 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); 314 | } 315 | const float weight = dmcn_get_coordinate_weight_cuda( 316 | inv_h, inv_w, 317 | height, width, data_im_ptr + cnt * height * width, width, bp_dir); 318 | val += weight * data_col_ptr[col_pos] * mask; 319 | cnt += 1; 320 | } 321 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val); 322 | grad_offset[index] = val; 323 | if (offset_c % 2 == 0) 324 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); 325 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; 326 | } 327 | } 328 | 329 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 330 | const float* data_im, const float* data_offset, const float* data_mask, 331 | const int batch_size, const int channels, const int height_im, const int width_im, 332 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 333 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 334 | const int dilation_h, const int dilation_w, 335 | const int deformable_group, float* data_col) { 336 | // num_axes should be smaller than block size 337 | const int channel_per_deformable_group = channels / deformable_group; 338 | const int num_kernels = channels * batch_size * height_col * width_col; 339 | modulated_deformable_im2col_gpu_kernel 340 | <<>>( 342 | num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, 343 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, 344 | batch_size, channels, deformable_group, height_col, width_col, data_col); 345 | 346 | cudaError_t err = 
cudaGetLastError(); 347 | if (err != cudaSuccess) 348 | { 349 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); 350 | } 351 | 352 | } 353 | 354 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 355 | const float* data_col, const float* data_offset, const float* data_mask, 356 | const int batch_size, const int channels, const int height_im, const int width_im, 357 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 358 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 359 | const int dilation_h, const int dilation_w, 360 | const int deformable_group, float* grad_im){ 361 | 362 | const int channel_per_deformable_group = channels / deformable_group; 363 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; 364 | modulated_deformable_col2im_gpu_kernel 365 | <<>>( 367 | num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, 368 | kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, 369 | dilation_h, dilation_w, channel_per_deformable_group, 370 | batch_size, deformable_group, height_col, width_col, grad_im); 371 | cudaError_t err = cudaGetLastError(); 372 | if (err != cudaSuccess) 373 | { 374 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); 375 | } 376 | 377 | } 378 | 379 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 380 | const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, 381 | const int batch_size, const int channels, const int height_im, const int width_im, 382 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 383 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 384 | const int dilation_h, const int dilation_w, 385 | const int deformable_group, 386 | float* grad_offset, float* grad_mask) { 387 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; 388 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; 389 | modulated_deformable_col2im_coord_gpu_kernel 390 | <<>>( 392 | num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, 393 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 394 | dilation_h, dilation_w, channel_per_deformable_group, 395 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, 396 | grad_offset, grad_mask); 397 | cudaError_t err = cudaGetLastError(); 398 | if (err != cudaSuccess) 399 | { 400 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); 401 | } 402 | } -------------------------------------------------------------------------------- /src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. 
The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 
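 * modulated_deformable_im2col_cuda gathers input values at the learned offsets
 * (weighted by the modulation mask) into a column buffer; modulated_deformable_col2im_cuda
 * scatters column-space gradients back onto the input feature map; and
 * modulated_deformable_col2im_coord_cuda accumulates the gradients with respect to
 * the offsets and masks.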
58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /src/cuda/dcn_v2_psroi_pooling_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 
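 * Deformable PS-ROI pooling: CUDA forward/backward kernels and their ATen wrappers.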
2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #define CUDA_KERNEL_LOOP(i, n) \ 23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 24 | i < (n); \ 25 | i += blockDim.x * gridDim.x) 26 | 27 | const int CUDA_NUM_THREADS = 1024; 28 | inline int GET_BLOCKS(const int N) 29 | { 30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 31 | } 32 | 33 | template 34 | __device__ T bilinear_interp_cuda( 35 | const T *data, 36 | const T x, 37 | const T y, 38 | const int width, 39 | const int height) 40 | { 41 | int x1 = floor(x); 42 | int x2 = ceil(x); 43 | int y1 = floor(y); 44 | int y2 = ceil(y); 45 | T dist_x = static_cast(x - x1); 46 | T dist_y = static_cast(y - y1); 47 | T value11 = data[y1 * width + x1]; 48 | T value12 = data[y2 * width + x1]; 49 | T value21 = data[y1 * width + x2]; 50 | T value22 = data[y2 * width + x2]; 51 | T value = (1 - dist_x) * (1 - dist_y) * value11 + 52 | (1 - dist_x) * dist_y * value12 + 53 | dist_x * (1 - dist_y) * value21 + 54 | dist_x * dist_y * value22; 55 | return value; 56 | } 57 | 58 | template 59 | __global__ void DeformablePSROIPoolForwardKernelCuda( 60 | const int count, 61 | const T *bottom_data, 62 | const T spatial_scale, 63 | const int channels, 64 | const int height, const int width, 65 | const int pooled_height, const int pooled_width, 66 | const T *bottom_rois, const T *bottom_trans, 67 | const int no_trans, 68 | const T trans_std, 69 | const int sample_per_part, 70 | const int output_dim, 71 | const int group_size, 72 | const int part_size, 73 | const int num_classes, 74 | const int channels_each_class, 75 | T *top_data, 76 | T *top_count) 77 | { 78 | CUDA_KERNEL_LOOP(index, count) 79 | { 80 | // The output is in order (n, ctop, ph, pw) 81 | int pw = index % pooled_width; 82 | int ph = (index / pooled_width) % pooled_height; 83 | int ctop = (index / pooled_width / pooled_height) % output_dim; 84 | int n = index / pooled_width / pooled_height / output_dim; 85 | 86 | // [start, end) interval for spatial sampling 87 | const T *offset_bottom_rois = bottom_rois + n * 5; 88 | int roi_batch_ind = offset_bottom_rois[0]; 89 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 90 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 91 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 92 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; 93 | 94 | // Force too small ROIs to be 1x1 95 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 96 | T roi_height = max(roi_end_h - roi_start_h, 0.1); 97 | 98 | // Compute w and h at bottom 99 | T bin_size_h = roi_height / static_cast(pooled_height); 100 | T bin_size_w = roi_width / static_cast(pooled_width); 101 | 102 | T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); 103 | T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); 104 | 105 | int part_h = floor(static_cast(ph) / pooled_height * part_size); 106 | int part_w = floor(static_cast(pw) / pooled_width * part_size); 107 | int class_id = ctop / channels_each_class; 108 | T trans_x = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 109 | T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 110 | 111 | T wstart = static_cast(pw) * bin_size_w + roi_start_w; 112 | wstart += trans_x * roi_width; 113 | T hstart = static_cast(ph) * bin_size_h + roi_start_h; 114 | hstart += trans_y * roi_height; 115 | 116 | T sum = 0; 117 | int count = 0; 118 | int gw = floor(static_cast(pw) * group_size / pooled_width); 119 | int gh = floor(static_cast(ph) * group_size / pooled_height); 120 | gw = min(max(gw, 0), group_size - 1); 121 | gh = min(max(gh, 0), group_size - 1); 122 | 123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; 124 | for (int ih = 0; ih < sample_per_part; ih++) 125 | { 126 | for (int iw = 0; iw < sample_per_part; iw++) 127 | { 128 | T w = wstart + iw * sub_bin_size_w; 129 | T h = hstart + ih * sub_bin_size_h; 130 | // bilinear interpolation 131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 132 | { 133 | continue; 134 | } 135 | w = min(max(w, 0.), width - 1.); 136 | h = min(max(h, 0.), height - 1.); 137 | int c = (ctop * group_size + gh) * group_size + gw; 138 | T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height); 139 | sum += val; 140 | count++; 141 | } 142 | } 143 | top_data[index] = count == 0 ? static_cast(0) : sum / count; 144 | top_count[index] = count; 145 | } 146 | } 147 | 148 | template 149 | __global__ void DeformablePSROIPoolBackwardAccKernelCuda( 150 | const int count, 151 | const T *top_diff, 152 | const T *top_count, 153 | const int num_rois, 154 | const T spatial_scale, 155 | const int channels, 156 | const int height, const int width, 157 | const int pooled_height, const int pooled_width, 158 | const int output_dim, 159 | T *bottom_data_diff, T *bottom_trans_diff, 160 | const T *bottom_data, 161 | const T *bottom_rois, 162 | const T *bottom_trans, 163 | const int no_trans, 164 | const T trans_std, 165 | const int sample_per_part, 166 | const int group_size, 167 | const int part_size, 168 | const int num_classes, 169 | const int channels_each_class) 170 | { 171 | CUDA_KERNEL_LOOP(index, count) 172 | { 173 | // The output is in order (n, ctop, ph, pw) 174 | int pw = index % pooled_width; 175 | int ph = (index / pooled_width) % pooled_height; 176 | int ctop = (index / pooled_width / pooled_height) % output_dim; 177 | int n = index / pooled_width / pooled_height / output_dim; 178 | 179 | // [start, end) interval for spatial sampling 180 | const T *offset_bottom_rois = bottom_rois + n * 5; 181 | int roi_batch_ind = offset_bottom_rois[0]; 182 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 183 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 184 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 185 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; 186 | 187 | // Force too small ROIs to be 1x1 188 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 189 | T roi_height = max(roi_end_h - roi_start_h, 0.1); 190 | 191 | // Compute w and h at bottom 192 | T bin_size_h = roi_height / static_cast<T>(pooled_height); 193 | T bin_size_w = roi_width / static_cast<T>(pooled_width); 194 | 195 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part); 196 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part); 197 | 198 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size); 199 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size); 200 | int class_id = ctop / channels_each_class; 201 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 202 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 203 | 204 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w; 205 | wstart += trans_x * roi_width; 206 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h; 207 | hstart += trans_y * roi_height; 208 | 209 | if (top_count[index] <= 0) 210 | { 211 | continue; 212 | } 213 | T diff_val = top_diff[index] / top_count[index]; 214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; 215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; 216 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width); 217 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height); 218 | gw = min(max(gw, 0), group_size - 1); 219 | gh = min(max(gh, 0), group_size - 1); 220 | 221 | for (int ih = 0; ih < sample_per_part; ih++) 222 | { 223 | for (int iw = 0; iw < sample_per_part; iw++) 224 | { 225 | T w = wstart + iw * sub_bin_size_w; 226 | T h = hstart + ih * sub_bin_size_h; 227 | // bilinear interpolation 228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 229 | { 230 | continue; 231 | } 232 | w = min(max(w, 0.), width - 1.); 233 | h = min(max(h, 0.), height - 1.); 234 | int c = (ctop * group_size + gh) * group_size + gw; 235 | // backward on feature 236 | int x0 = floor(w); 237 | int x1 = ceil(w); 238 | int y0 = floor(h); 239 | int y1 = ceil(h); 240 | T dist_x = w - x0, dist_y = h - y0; 241 | T q00 = (1 - dist_x) * (1 - dist_y); 242 | T q01 = (1 - dist_x) * dist_y; 243 | T q10 = dist_x * (1 - dist_y); 244 | T q11 = dist_x * dist_y; 245 | int bottom_index_base = c * height * width; 246 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); 247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); 248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); 249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); 250 | 251 | if (no_trans) 252 | { 253 | continue; 254 | } 255 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; 256 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; 257 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; 258 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; 259 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; 260 | diff_x *= roi_width; 261 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std
* diff_val; 262 | diff_y *= roi_height; 263 | 264 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); 265 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); 266 | } 267 | } 268 | } 269 | } 270 | 271 | std::tuple<at::Tensor, at::Tensor> 272 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 273 | const at::Tensor &bbox, 274 | const at::Tensor &trans, 275 | const int no_trans, 276 | const float spatial_scale, 277 | const int output_dim, 278 | const int group_size, 279 | const int pooled_size, 280 | const int part_size, 281 | const int sample_per_part, 282 | const float trans_std) 283 | { 284 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 285 | AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor"); 286 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor"); 287 | 288 | const int batch = input.size(0); 289 | const int channels = input.size(1); 290 | const int height = input.size(2); 291 | const int width = input.size(3); 292 | const int channels_trans = no_trans ? 2 : trans.size(1); 293 | const int num_bbox = bbox.size(0); 294 | 295 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 296 | auto pooled_height = pooled_size; 297 | auto pooled_width = pooled_size; 298 | 299 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 300 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 301 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 302 | 303 | const int num_classes = no_trans ? 1 : channels_trans / 2; 304 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; 305 | 306 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 307 | 308 | if (out.numel() == 0) 309 | { 310 | THCudaCheck(cudaGetLastError()); 311 | return std::make_tuple(out, top_count); 312 | } 313 | 314 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 315 | dim3 block(512); 316 | 317 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { 318 | DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>( 319 | out_size, 320 | input.contiguous().data_ptr<scalar_t>(), 321 | spatial_scale, 322 | channels, 323 | height, width, 324 | pooled_height, 325 | pooled_width, 326 | bbox.contiguous().data_ptr<scalar_t>(), 327 | trans.contiguous().data_ptr<scalar_t>(), 328 | no_trans, 329 | trans_std, 330 | sample_per_part, 331 | output_dim, 332 | group_size, 333 | part_size, 334 | num_classes, 335 | channels_each_class, 336 | out.data_ptr<scalar_t>(), 337 | top_count.data_ptr<scalar_t>()); 338 | }); 339 | THCudaCheck(cudaGetLastError()); 340 | return std::make_tuple(out, top_count); 341 | } 342 | 343 | std::tuple<at::Tensor, at::Tensor> 344 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 345 | const at::Tensor &input, 346 | const at::Tensor &bbox, 347 | const at::Tensor &trans, 348 | const at::Tensor &top_count, 349 | const int no_trans, 350 | const float spatial_scale, 351 | const int output_dim, 352 | const int group_size, 353 | const int pooled_size, 354 | const int part_size, 355 | const int sample_per_part, 356 | const float trans_std) 357 | { 358 | AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor"); 359 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); 360 | AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor"); 361 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor"); 362 | AT_ASSERTM(top_count.is_cuda(),
"top_count must be a CUDA tensor"); 363 | 364 | const int batch = input.size(0); 365 | const int channels = input.size(1); 366 | const int height = input.size(2); 367 | const int width = input.size(3); 368 | const int channels_trans = no_trans ? 2 : trans.size(1); 369 | const int num_bbox = bbox.size(0); 370 | 371 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 372 | auto pooled_height = pooled_size; 373 | auto pooled_width = pooled_size; 374 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 375 | const int num_classes = no_trans ? 1 : channels_trans / 2; 376 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; 377 | 378 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); 379 | auto trans_grad = at::zeros_like(trans); 380 | 381 | if (input_grad.numel() == 0) 382 | { 383 | THCudaCheck(cudaGetLastError()); 384 | return std::make_tuple(input_grad, trans_grad); 385 | } 386 | 387 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 388 | dim3 block(512); 389 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 390 | 391 | AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { 392 | DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>( 393 | out_size, 394 | out_grad.contiguous().data_ptr<scalar_t>(), 395 | top_count.contiguous().data_ptr<scalar_t>(), 396 | num_bbox, 397 | spatial_scale, 398 | channels, 399 | height, 400 | width, 401 | pooled_height, 402 | pooled_width, 403 | output_dim, 404 | input_grad.contiguous().data_ptr<scalar_t>(), 405 | trans_grad.contiguous().data_ptr<scalar_t>(), 406 | input.contiguous().data_ptr<scalar_t>(), 407 | bbox.contiguous().data_ptr<scalar_t>(), 408 | trans.contiguous().data_ptr<scalar_t>(), 409 | no_trans, 410 | trans_std, 411 | sample_per_part, 412 | group_size, 413 | part_size, 414 | num_classes, 415 | channels_each_class); 416 | }); 417 | THCudaCheck(cudaGetLastError()); 418 | return std::make_tuple(input_grad, trans_grad); 419 | } -------------------------------------------------------------------------------- /src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector<at::Tensor> 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple<at::Tensor, at::Tensor> 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple<at::Tensor, at::Tensor> 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor
&trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /src/dcn_v2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | at::Tensor 10 | dcn_v2_forward(const at::Tensor &input, 11 | const at::Tensor &weight, 12 | const at::Tensor &bias, 13 | const at::Tensor &offset, 14 | const at::Tensor &mask, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int deformable_group) 24 | { 25 | if (input.is_cuda()) 26 | { 27 | #ifdef WITH_CUDA 28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask, 29 | kernel_h, kernel_w, 30 | stride_h, stride_w, 31 | pad_h, pad_w, 32 | dilation_h, dilation_w, 33 | deformable_group); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | else{ 39 | return dcn_v2_cpu_forward(input, weight, bias, offset, mask, 40 | kernel_h, kernel_w, 41 | stride_h, stride_w, 42 | pad_h, pad_w, 43 | dilation_h, dilation_w, 44 | deformable_group); 45 | } 46 | } 47 | 48 | std::vector<at::Tensor> 49 | dcn_v2_backward(const at::Tensor &input, 50 | const at::Tensor &weight, 51 | const at::Tensor &bias, 52 | const at::Tensor &offset, 53 | const at::Tensor &mask, 54 | const at::Tensor &grad_output, 55 | int kernel_h, int kernel_w, 56 | int stride_h, int stride_w, 57 | int pad_h, int pad_w, 58 | int dilation_h, int dilation_w, 59 | int deformable_group) 60 | { 61 | if (input.is_cuda()) 62 | { 63 | #ifdef WITH_CUDA 64 | return dcn_v2_cuda_backward(input, 65 | weight, 66 | bias, 67 | offset, 68 | mask, 69 | grad_output, 70 | kernel_h, kernel_w, 71 | stride_h, stride_w, 72 | pad_h, pad_w, 73 | dilation_h, dilation_w, 74 | deformable_group); 75 | #else 76 | AT_ERROR("Not compiled with GPU support"); 77 | #endif 78 | } 79 | else{ 80 | return dcn_v2_cpu_backward(input, 81 | weight, 82 | bias, 83 | offset, 84 | mask, 85 | grad_output, 86 | kernel_h, kernel_w, 87 | stride_h, stride_w, 88 | pad_h, pad_w, 89 | dilation_h, dilation_w, 90 | deformable_group); 91 | } 92 | } 93 | 94 | std::tuple<at::Tensor, at::Tensor> 95 | dcn_v2_psroi_pooling_forward(const at::Tensor &input, 96 | const at::Tensor &bbox, 97 | const at::Tensor &trans, 98 | const int no_trans, 99 | const float spatial_scale, 100 | const int output_dim, 101 | const int group_size, 102 | const int pooled_size, 103 | const int part_size, 104 | const int sample_per_part, 105 | const float trans_std) 106 | { 107 | if (input.is_cuda()) 108 | { 109 | #ifdef WITH_CUDA 110 | return dcn_v2_psroi_pooling_cuda_forward(input, 111 | bbox, 112 | trans, 113 | no_trans, 114 | spatial_scale, 115 | output_dim, 116 | group_size, 117 | pooled_size, 118 | part_size, 119 | sample_per_part, 120 | trans_std); 121 | #else 122 | AT_ERROR("Not compiled with GPU support"); 123 | #endif 124 | } 125 | else{ 126 | return dcn_v2_psroi_pooling_cpu_forward(input, 127 | bbox, 128 | trans, 129 | no_trans, 130 | spatial_scale, 131 | output_dim, 132 | group_size, 133 | pooled_size, 134 | part_size, 135 | sample_per_part, 136 | trans_std); 137 | } 138 | } 139 | 140 | std::tuple<at::Tensor, at::Tensor> 141 |
dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, 142 | const at::Tensor &input, 143 | const at::Tensor &bbox, 144 | const at::Tensor &trans, 145 | const at::Tensor &top_count, 146 | const int no_trans, 147 | const float spatial_scale, 148 | const int output_dim, 149 | const int group_size, 150 | const int pooled_size, 151 | const int part_size, 152 | const int sample_per_part, 153 | const float trans_std) 154 | { 155 | if (input.is_cuda()) 156 | { 157 | #ifdef WITH_CUDA 158 | return dcn_v2_psroi_pooling_cuda_backward(out_grad, 159 | input, 160 | bbox, 161 | trans, 162 | top_count, 163 | no_trans, 164 | spatial_scale, 165 | output_dim, 166 | group_size, 167 | pooled_size, 168 | part_size, 169 | sample_per_part, 170 | trans_std); 171 | #else 172 | AT_ERROR("Not compiled with GPU support"); 173 | #endif 174 | } 175 | else{ 176 | return dcn_v2_psroi_pooling_cpu_backward(out_grad, 177 | input, 178 | bbox, 179 | trans, 180 | top_count, 181 | no_trans, 182 | spatial_scale, 183 | output_dim, 184 | group_size, 185 | pooled_size, 186 | part_size, 187 | sample_per_part, 188 | trans_std); 189 | } 190 | } -------------------------------------------------------------------------------- /src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import gradcheck 7 | 8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling 9 | 10 | deformable_groups = 1 11 | N, inC, inH, inW = 2, 2, 4, 4 12 | outC = 2 13 | kH, kW = 3, 3 14 | 15 | 16 | def conv_identify(weight, bias): 17 | weight.data.zero_() 18 | bias.data.zero_() 19 | o, i, h, w = weight.shape 20 | y = h // 2 21 | x = w // 2 22 | for p in range(i): 23 | for q in range(o): 24 | if p == q: 25 | weight.data[q, p, y, x] = 1.0 26 | 27 | 28 | def check_zero_offset(): 29 | conv_offset = nn.Conv2d( 30 | inC, 31 | deformable_groups * 2 * kH * kW, 32 | kernel_size=(kH, kW), 33 | stride=(1, 1), 34 | padding=(1, 1), 35 | bias=True, 36 | ).cuda() 37 | 38 | conv_mask = nn.Conv2d( 39 | inC, 40 | deformable_groups * 1 * kH * kW, 41 | kernel_size=(kH, kW), 42 | stride=(1, 1), 43 | padding=(1, 1), 44 | bias=True, 45 | ).cuda() 46 | 47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda() 48 | 49 | conv_offset.weight.data.zero_() 50 | conv_offset.bias.data.zero_() 51 | conv_mask.weight.data.zero_() 52 | conv_mask.bias.data.zero_() 53 | conv_identify(dcn_v2.weight, dcn_v2.bias) 54 | 55 | input = torch.randn(N, inC, inH, inW).cuda() 56 | offset = conv_offset(input) 57 | mask = conv_mask(input) 58 | mask = torch.sigmoid(mask) 59 | output = dcn_v2(input, offset, mask) 60 | output *= 2 61 | d = (input - output).abs().max() 62 | if d < 1e-10: 63 | print("Zero offset passed") 64 
| else: 65 | print("Zero offset failed") 66 | print(input) 67 | print(output) 68 | 69 | 70 | def check_gradient_dconv(): 71 | 72 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01 73 | input.requires_grad = True 74 | 75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 76 | # offset.data.zero_() 77 | # offset.data -= 0.5 78 | offset.requires_grad = True 79 | 80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() 81 | # mask.data.zero_() 82 | mask.requires_grad = True 83 | mask = torch.sigmoid(mask) 84 | 85 | weight = torch.randn(outC, inC, kH, kW).cuda() 86 | weight.requires_grad = True 87 | 88 | bias = torch.rand(outC).cuda() 89 | bias.requires_grad = True 90 | 91 | stride = 1 92 | padding = 1 93 | dilation = 1 94 | 95 | print( 96 | "check_gradient_dconv: ", 97 | gradcheck( 98 | dcn_v2_conv, 99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), 100 | eps=1e-3, 101 | atol=1e-4, 102 | rtol=1e-2, 103 | ), 104 | ) 105 | 106 | 107 | def check_pooling_zero_offset(): 108 | 109 | input = torch.randn(2, 16, 64, 64).cuda().zero_() 110 | input[0, :, 16:26, 16:26] = 1.0 111 | input[1, :, 10:20, 20:30] = 2.0 112 | rois = ( 113 | torch.tensor( 114 | [ 115 | [0, 65, 65, 103, 103], 116 | [1, 81, 41, 119, 79], 117 | ] 118 | ) 119 | .cuda() 120 | .float() 121 | ) 122 | pooling = DCNv2Pooling( 123 | spatial_scale=1.0 / 4, 124 | pooled_size=7, 125 | output_dim=16, 126 | no_trans=True, 127 | group_size=1, 128 | trans_std=0.0, 129 | ).cuda() 130 | 131 | out = pooling(input, rois, input.new()) 132 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 133 | print(s) 134 | 135 | dpooling = DCNv2Pooling( 136 | spatial_scale=1.0 / 4, 137 | pooled_size=7, 138 | output_dim=16, 139 | no_trans=False, 140 | group_size=1, 141 | trans_std=0.0, 142 | ).cuda() 143 | offset = torch.randn(20, 2, 7, 7).cuda().zero_() 144 | dout = dpooling(input, rois, offset) 145 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 146 | print(s) 147 | 148 | 149 | def check_gradient_dpooling(): 150 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01 151 | N = 4 152 | batch_inds = torch.randint(2, (N, 1)).cuda().float() 153 | x = torch.rand((N, 1)).cuda().float() * 15 154 | y = torch.rand((N, 1)).cuda().float() * 15 155 | w = torch.rand((N, 1)).cuda().float() * 10 156 | h = torch.rand((N, 1)).cuda().float() * 10 157 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 158 | offset = torch.randn(N, 2, 3, 3).cuda() 159 | input.requires_grad = True 160 | offset.requires_grad = True 161 | 162 | spatial_scale = 1.0 / 4 163 | pooled_size = 3 164 | output_dim = 3 165 | no_trans = 0 166 | group_size = 1 167 | trans_std = 0.0 168 | sample_per_part = 4 169 | part_size = pooled_size 170 | 171 | print( 172 | "check_gradient_dpooling:", 173 | gradcheck( 174 | dcn_v2_pooling, 175 | ( 176 | input, 177 | rois, 178 | offset, 179 | spatial_scale, 180 | pooled_size, 181 | output_dim, 182 | no_trans, 183 | group_size, 184 | part_size, 185 | sample_per_part, 186 | trans_std, 187 | ), 188 | eps=1e-4, 189 | ), 190 | ) 191 | 192 | 193 | def example_dconv(): 194 | input = torch.randn(2, 64, 128, 128).cuda() 195 | # wrap all things (offset and mask) in DCN 196 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda() 197 | # print(dcn.weight.shape, input.shape) 198 | output = dcn(input) 199 | targert = output.new(*output.size()) 200 | targert.data.uniform_(-0.01, 0.01) 201 | error = 
(targert - output).mean() 202 | error.backward() 203 | print(output.shape) 204 | 205 | 206 | def example_dpooling(): 207 | input = torch.randn(2, 32, 64, 64).cuda() 208 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 209 | x = torch.randint(256, (20, 1)).cuda().float() 210 | y = torch.randint(256, (20, 1)).cuda().float() 211 | w = torch.randint(64, (20, 1)).cuda().float() 212 | h = torch.randint(64, (20, 1)).cuda().float() 213 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 214 | offset = torch.randn(20, 2, 7, 7).cuda() 215 | input.requires_grad = True 216 | offset.requires_grad = True 217 | 218 | # normal roi_align 219 | pooling = DCNv2Pooling( 220 | spatial_scale=1.0 / 4, 221 | pooled_size=7, 222 | output_dim=32, 223 | no_trans=True, 224 | group_size=1, 225 | trans_std=0.1, 226 | ).cuda() 227 | 228 | # deformable pooling 229 | dpooling = DCNv2Pooling( 230 | spatial_scale=1.0 / 4, 231 | pooled_size=7, 232 | output_dim=32, 233 | no_trans=False, 234 | group_size=1, 235 | trans_std=0.1, 236 | ).cuda() 237 | 238 | out = pooling(input, rois, offset) 239 | dout = dpooling(input, rois, offset) 240 | print(out.shape) 241 | print(dout.shape) 242 | 243 | target_out = out.new(*out.size()) 244 | target_out.data.uniform_(-0.01, 0.01) 245 | target_dout = dout.new(*dout.size()) 246 | target_dout.data.uniform_(-0.01, 0.01) 247 | e = (target_out - out).mean() 248 | e.backward() 249 | e = (target_dout - dout).mean() 250 | e.backward() 251 | 252 | 253 | def example_mdpooling(): 254 | input = torch.randn(2, 32, 64, 64).cuda() 255 | input.requires_grad = True 256 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 257 | x = torch.randint(256, (20, 1)).cuda().float() 258 | y = torch.randint(256, (20, 1)).cuda().float() 259 | w = torch.randint(64, (20, 1)).cuda().float() 260 | h = torch.randint(64, (20, 1)).cuda().float() 261 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 262 | 263 | # mdformable pooling (V2) 264 | dpooling = DCNPooling( 265 | spatial_scale=1.0 / 4, 266 | pooled_size=7, 267 | output_dim=32, 268 | no_trans=False, 269 | group_size=1, 270 | trans_std=0.1, 271 | deform_fc_dim=1024, 272 | ).cuda() 273 | 274 | dout = dpooling(input, rois) 275 | target = dout.new(*dout.size()) 276 | target.data.uniform_(-0.1, 0.1) 277 | error = (target - dout).mean() 278 | error.backward() 279 | print(dout.shape) 280 | 281 | 282 | if __name__ == "__main__": 283 | 284 | example_dconv() 285 | example_dpooling() 286 | example_mdpooling() 287 | 288 | check_pooling_zero_offset() 289 | # zero offset check 290 | if inC == outC: 291 | check_zero_offset() 292 | 293 | check_gradient_dpooling() 294 | check_gradient_dconv() 295 | # """ 296 | # ****** Note: backward is not reentrant error may not be a serious problem, 297 | # ****** since the max error is less than 1e-7, 298 | # ****** Still looking for what trigger this problem 299 | # """ 300 | -------------------------------------------------------------------------------- /test/testcpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import gradcheck 7 | 8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling 9 | 10 | deformable_groups = 1 11 | N, inC, inH, inW = 2, 2, 4, 4 12 | outC = 2 13 | kH, kW = 3, 3 14 | 15 | 16 | def conv_identify(weight, bias): 17 | weight.data.zero_() 18 | bias.data.zero_() 19 | o, 
i, h, w = weight.shape 20 | y = h // 2 21 | x = w // 2 22 | for p in range(i): 23 | for q in range(o): 24 | if p == q: 25 | weight.data[q, p, y, x] = 1.0 26 | 27 | 28 | def check_zero_offset(): 29 | conv_offset = nn.Conv2d( 30 | inC, 31 | deformable_groups * 2 * kH * kW, 32 | kernel_size=(kH, kW), 33 | stride=(1, 1), 34 | padding=(1, 1), 35 | bias=True, 36 | ) 37 | 38 | conv_mask = nn.Conv2d( 39 | inC, 40 | deformable_groups * 1 * kH * kW, 41 | kernel_size=(kH, kW), 42 | stride=(1, 1), 43 | padding=(1, 1), 44 | bias=True, 45 | ) 46 | 47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) 48 | 49 | conv_offset.weight.data.zero_() 50 | conv_offset.bias.data.zero_() 51 | conv_mask.weight.data.zero_() 52 | conv_mask.bias.data.zero_() 53 | conv_identify(dcn_v2.weight, dcn_v2.bias) 54 | 55 | input = torch.randn(N, inC, inH, inW) 56 | offset = conv_offset(input) 57 | mask = conv_mask(input) 58 | mask = torch.sigmoid(mask) 59 | output = dcn_v2(input, offset, mask) 60 | output *= 2 61 | d = (input - output).abs().max() 62 | if d < 1e-10: 63 | print("Zero offset passed") 64 | else: 65 | print("Zero offset failed") 66 | print(input) 67 | print(output) 68 | 69 | 70 | def check_gradient_dconv(): 71 | 72 | input = torch.rand(N, inC, inH, inW) * 0.01 73 | input.requires_grad = True 74 | 75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW) * 2 76 | # offset.data.zero_() 77 | # offset.data -= 0.5 78 | offset.requires_grad = True 79 | 80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW) 81 | # mask.data.zero_() 82 | mask.requires_grad = True 83 | mask = torch.sigmoid(mask) 84 | 85 | weight = torch.randn(outC, inC, kH, kW) 86 | weight.requires_grad = True 87 | 88 | bias = torch.rand(outC) 89 | bias.requires_grad = True 90 | 91 | stride = 1 92 | padding = 1 93 | dilation = 1 94 | 95 | print( 96 | "check_gradient_dconv: ", 97 | gradcheck( 98 | dcn_v2_conv, 99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), 100 | eps=1e-3, 101 | atol=1e-4, 102 | rtol=1e-2, 103 | ), 104 | ) 105 | 106 | 107 | def check_pooling_zero_offset(): 108 | 109 | input = torch.randn(2, 16, 64, 64).zero_() 110 | input[0, :, 16:26, 16:26] = 1.0 111 | input[1, :, 10:20, 20:30] = 2.0 112 | rois = torch.tensor( 113 | [ 114 | [0, 65, 65, 103, 103], 115 | [1, 81, 41, 119, 79], 116 | ] 117 | ).float() 118 | pooling = DCNv2Pooling( 119 | spatial_scale=1.0 / 4, 120 | pooled_size=7, 121 | output_dim=16, 122 | no_trans=True, 123 | group_size=1, 124 | trans_std=0.0, 125 | ) 126 | 127 | out = pooling(input, rois, input.new()) 128 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 129 | print(s) 130 | 131 | dpooling = DCNv2Pooling( 132 | spatial_scale=1.0 / 4, 133 | pooled_size=7, 134 | output_dim=16, 135 | no_trans=False, 136 | group_size=1, 137 | trans_std=0.0, 138 | ) 139 | offset = torch.randn(20, 2, 7, 7).zero_() 140 | dout = dpooling(input, rois, offset) 141 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 142 | print(s) 143 | 144 | 145 | def check_gradient_dpooling(): 146 | input = torch.randn(2, 3, 5, 5) * 0.01 147 | N = 4 148 | batch_inds = torch.randint(2, (N, 1)).float() 149 | x = torch.rand((N, 1)).float() * 15 150 | y = torch.rand((N, 1)).float() * 15 151 | w = torch.rand((N, 1)).float() * 10 152 | h = torch.rand((N, 1)).float() * 10 153 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 154 | offset = torch.randn(N, 2, 3, 3) 155 | 
input.requires_grad = True 156 | offset.requires_grad = True 157 | 158 | spatial_scale = 1.0 / 4 159 | pooled_size = 3 160 | output_dim = 3 161 | no_trans = 0 162 | group_size = 1 163 | trans_std = 0.0 164 | sample_per_part = 4 165 | part_size = pooled_size 166 | 167 | print( 168 | "check_gradient_dpooling:", 169 | gradcheck( 170 | dcn_v2_pooling, 171 | ( 172 | input, 173 | rois, 174 | offset, 175 | spatial_scale, 176 | pooled_size, 177 | output_dim, 178 | no_trans, 179 | group_size, 180 | part_size, 181 | sample_per_part, 182 | trans_std, 183 | ), 184 | eps=1e-4, 185 | ), 186 | ) 187 | 188 | 189 | def example_dconv(): 190 | input = torch.randn(2, 64, 128, 128) 191 | # wrap all things (offset and mask) in DCN 192 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2) 193 | # print(dcn.weight.shape, input.shape) 194 | output = dcn(input) 195 | targert = output.new(*output.size()) 196 | targert.data.uniform_(-0.01, 0.01) 197 | error = (targert - output).mean() 198 | error.backward() 199 | print(output.shape) 200 | 201 | 202 | def example_dpooling(): 203 | input = torch.randn(2, 32, 64, 64) 204 | batch_inds = torch.randint(2, (20, 1)).float() 205 | x = torch.randint(256, (20, 1)).float() 206 | y = torch.randint(256, (20, 1)).float() 207 | w = torch.randint(64, (20, 1)).float() 208 | h = torch.randint(64, (20, 1)).float() 209 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 210 | offset = torch.randn(20, 2, 7, 7) 211 | input.requires_grad = True 212 | offset.requires_grad = True 213 | 214 | # normal roi_align 215 | pooling = DCNv2Pooling( 216 | spatial_scale=1.0 / 4, 217 | pooled_size=7, 218 | output_dim=32, 219 | no_trans=True, 220 | group_size=1, 221 | trans_std=0.1, 222 | ) 223 | 224 | # deformable pooling 225 | dpooling = DCNv2Pooling( 226 | spatial_scale=1.0 / 4, 227 | pooled_size=7, 228 | output_dim=32, 229 | no_trans=False, 230 | group_size=1, 231 | trans_std=0.1, 232 | ) 233 | 234 | out = pooling(input, rois, offset) 235 | dout = dpooling(input, rois, offset) 236 | print(out.shape) 237 | print(dout.shape) 238 | 239 | target_out = out.new(*out.size()) 240 | target_out.data.uniform_(-0.01, 0.01) 241 | target_dout = dout.new(*dout.size()) 242 | target_dout.data.uniform_(-0.01, 0.01) 243 | e = (target_out - out).mean() 244 | e.backward() 245 | e = (target_dout - dout).mean() 246 | e.backward() 247 | 248 | 249 | def example_mdpooling(): 250 | input = torch.randn(2, 32, 64, 64) 251 | input.requires_grad = True 252 | batch_inds = torch.randint(2, (20, 1)).float() 253 | x = torch.randint(256, (20, 1)).float() 254 | y = torch.randint(256, (20, 1)).float() 255 | w = torch.randint(64, (20, 1)).float() 256 | h = torch.randint(64, (20, 1)).float() 257 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 258 | 259 | # mdformable pooling (V2) 260 | dpooling = DCNPooling( 261 | spatial_scale=1.0 / 4, 262 | pooled_size=7, 263 | output_dim=32, 264 | no_trans=False, 265 | group_size=1, 266 | trans_std=0.1, 267 | deform_fc_dim=1024, 268 | ) 269 | 270 | dout = dpooling(input, rois) 271 | target = dout.new(*dout.size()) 272 | target.data.uniform_(-0.1, 0.1) 273 | error = (target - dout).mean() 274 | error.backward() 275 | print(dout.shape) 276 | 277 | 278 | if __name__ == "__main__": 279 | 280 | example_dconv() 281 | example_dpooling() 282 | example_mdpooling() 283 | 284 | check_pooling_zero_offset() 285 | # zero offset check 286 | if inC == outC: 287 | check_zero_offset() 288 | 289 | check_gradient_dpooling() 290 | check_gradient_dconv() 291 | # """ 292 | # 
****** Note: backward is not reentrant error may not be a serious problem, 293 | # ****** since the max error is less than 1e-7, 294 | # ****** Still looking for what trigger this problem 295 | # """ 296 | -------------------------------------------------------------------------------- /test/testcuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import gradcheck 7 | 8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling 9 | 10 | deformable_groups = 1 11 | N, inC, inH, inW = 2, 2, 4, 4 12 | outC = 2 13 | kH, kW = 3, 3 14 | 15 | 16 | def conv_identify(weight, bias): 17 | weight.data.zero_() 18 | bias.data.zero_() 19 | o, i, h, w = weight.shape 20 | y = h // 2 21 | x = w // 2 22 | for p in range(i): 23 | for q in range(o): 24 | if p == q: 25 | weight.data[q, p, y, x] = 1.0 26 | 27 | 28 | def check_zero_offset(): 29 | conv_offset = nn.Conv2d( 30 | inC, 31 | deformable_groups * 2 * kH * kW, 32 | kernel_size=(kH, kW), 33 | stride=(1, 1), 34 | padding=(1, 1), 35 | bias=True, 36 | ).cuda() 37 | 38 | conv_mask = nn.Conv2d( 39 | inC, 40 | deformable_groups * 1 * kH * kW, 41 | kernel_size=(kH, kW), 42 | stride=(1, 1), 43 | padding=(1, 1), 44 | bias=True, 45 | ).cuda() 46 | 47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda() 48 | 49 | conv_offset.weight.data.zero_() 50 | conv_offset.bias.data.zero_() 51 | conv_mask.weight.data.zero_() 52 | conv_mask.bias.data.zero_() 53 | conv_identify(dcn_v2.weight, dcn_v2.bias) 54 | 55 | input = torch.randn(N, inC, inH, inW).cuda() 56 | offset = conv_offset(input) 57 | mask = conv_mask(input) 58 | mask = torch.sigmoid(mask) 59 | output = dcn_v2(input, offset, mask) 60 | output *= 2 61 | d = (input - output).abs().max() 62 | if d < 1e-10: 63 | print("Zero offset passed") 64 | else: 65 | print("Zero offset failed") 66 | print(input) 67 | print(output) 68 | 69 | 70 | def check_gradient_dconv(): 71 | 72 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01 73 | input.requires_grad = True 74 | 75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 76 | # offset.data.zero_() 77 | # offset.data -= 0.5 78 | offset.requires_grad = True 79 | 80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() 81 | # mask.data.zero_() 82 | mask.requires_grad = True 83 | mask = torch.sigmoid(mask) 84 | 85 | weight = torch.randn(outC, inC, kH, kW).cuda() 86 | weight.requires_grad = True 87 | 88 | bias = torch.rand(outC).cuda() 89 | bias.requires_grad = True 90 | 91 | stride = 1 92 | padding = 1 93 | dilation = 1 94 | 95 | print( 96 | "check_gradient_dconv: ", 97 | gradcheck( 98 | dcn_v2_conv, 99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), 100 | eps=1e-3, 101 | atol=1e-4, 102 | rtol=1e-2, 103 | ), 104 | ) 105 | 106 | 107 | def check_pooling_zero_offset(): 108 | 109 | input = torch.randn(2, 16, 64, 64).cuda().zero_() 110 | input[0, :, 16:26, 16:26] = 1.0 111 | input[1, :, 10:20, 20:30] = 2.0 112 | rois = ( 113 | torch.tensor( 114 | [ 115 | [0, 65, 65, 103, 103], 116 | [1, 81, 41, 119, 79], 117 | ] 118 | ) 119 | .cuda() 120 | .float() 121 | ) 122 | pooling = DCNv2Pooling( 123 | spatial_scale=1.0 / 4, 124 | pooled_size=7, 125 | output_dim=16, 126 | no_trans=True, 127 | group_size=1, 128 | trans_std=0.0, 129 | 
).cuda() 130 | 131 | out = pooling(input, rois, input.new()) 132 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 133 | print(s) 134 | 135 | dpooling = DCNv2Pooling( 136 | spatial_scale=1.0 / 4, 137 | pooled_size=7, 138 | output_dim=16, 139 | no_trans=False, 140 | group_size=1, 141 | trans_std=0.0, 142 | ).cuda() 143 | offset = torch.randn(20, 2, 7, 7).cuda().zero_() 144 | dout = dpooling(input, rois, offset) 145 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) 146 | print(s) 147 | 148 | 149 | def check_gradient_dpooling(): 150 | input = torch.randn(2, 3, 5, 5).cuda().float() * 0.01 151 | N = 4 152 | batch_inds = torch.randint(2, (N, 1)).cuda().float() 153 | x = torch.rand((N, 1)).cuda().float() * 15 154 | y = torch.rand((N, 1)).cuda().float() * 15 155 | w = torch.rand((N, 1)).cuda().float() * 10 156 | h = torch.rand((N, 1)).cuda().float() * 10 157 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 158 | offset = torch.randn(N, 2, 3, 3).cuda() 159 | input.requires_grad = True 160 | offset.requires_grad = True 161 | 162 | spatial_scale = 1.0 / 4 163 | pooled_size = 3 164 | output_dim = 3 165 | no_trans = 0 166 | group_size = 1 167 | trans_std = 0.0 168 | sample_per_part = 4 169 | part_size = pooled_size 170 | 171 | print( 172 | "check_gradient_dpooling:", 173 | gradcheck( 174 | dcn_v2_pooling, 175 | ( 176 | input, 177 | rois, 178 | offset, 179 | spatial_scale, 180 | pooled_size, 181 | output_dim, 182 | no_trans, 183 | group_size, 184 | part_size, 185 | sample_per_part, 186 | trans_std, 187 | ), 188 | eps=1e-4, 189 | ), 190 | ) 191 | 192 | 193 | def example_dconv(): 194 | input = torch.randn(2, 64, 128, 128).cuda() 195 | # wrap all things (offset and mask) in DCN 196 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda() 197 | # print(dcn.weight.shape, input.shape) 198 | output = dcn(input) 199 | targert = output.new(*output.size()) 200 | targert.data.uniform_(-0.01, 0.01) 201 | error = (targert - output).mean() 202 | error.backward() 203 | print(output.shape) 204 | 205 | 206 | def example_dpooling(): 207 | input = torch.randn(2, 32, 64, 64).cuda() 208 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 209 | x = torch.randint(256, (20, 1)).cuda().float() 210 | y = torch.randint(256, (20, 1)).cuda().float() 211 | w = torch.randint(64, (20, 1)).cuda().float() 212 | h = torch.randint(64, (20, 1)).cuda().float() 213 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 214 | offset = torch.randn(20, 2, 7, 7).cuda() 215 | input.requires_grad = True 216 | offset.requires_grad = True 217 | 218 | # normal roi_align 219 | pooling = DCNv2Pooling( 220 | spatial_scale=1.0 / 4, 221 | pooled_size=7, 222 | output_dim=32, 223 | no_trans=True, 224 | group_size=1, 225 | trans_std=0.1, 226 | ).cuda() 227 | 228 | # deformable pooling 229 | dpooling = DCNv2Pooling( 230 | spatial_scale=1.0 / 4, 231 | pooled_size=7, 232 | output_dim=32, 233 | no_trans=False, 234 | group_size=1, 235 | trans_std=0.1, 236 | ).cuda() 237 | 238 | out = pooling(input, rois, offset) 239 | dout = dpooling(input, rois, offset) 240 | print(out.shape) 241 | print(dout.shape) 242 | 243 | target_out = out.new(*out.size()) 244 | target_out.data.uniform_(-0.01, 0.01) 245 | target_dout = dout.new(*dout.size()) 246 | target_dout.data.uniform_(-0.01, 0.01) 247 | e = (target_out - out).mean() 248 | e.backward() 249 | e = (target_dout - dout).mean() 250 | e.backward() 251 | 252 | 253 | def example_mdpooling(): 254 | input = 
torch.randn(2, 32, 64, 64).cuda() 255 | input.requires_grad = True 256 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 257 | x = torch.randint(256, (20, 1)).cuda().float() 258 | y = torch.randint(256, (20, 1)).cuda().float() 259 | w = torch.randint(64, (20, 1)).cuda().float() 260 | h = torch.randint(64, (20, 1)).cuda().float() 261 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 262 | 263 | # mdformable pooling (V2) 264 | dpooling = DCNPooling( 265 | spatial_scale=1.0 / 4, 266 | pooled_size=7, 267 | output_dim=32, 268 | no_trans=False, 269 | group_size=1, 270 | trans_std=0.1, 271 | deform_fc_dim=1024, 272 | ).cuda() 273 | 274 | dout = dpooling(input, rois) 275 | target = dout.new(*dout.size()) 276 | target.data.uniform_(-0.1, 0.1) 277 | error = (target - dout).mean() 278 | error.backward() 279 | print(dout.shape) 280 | 281 | 282 | if __name__ == "__main__": 283 | 284 | example_dconv() 285 | example_dpooling() 286 | example_mdpooling() 287 | 288 | check_pooling_zero_offset() 289 | # zero offset check 290 | if inC == outC: 291 | check_zero_offset() 292 | 293 | check_gradient_dpooling() 294 | check_gradient_dconv() 295 | # """ 296 | # ****** Note: backward is not reentrant error may not be a serious problem, 297 | # ****** since the max error is less than 1e-7, 298 | # ****** Still looking for what trigger this problem 299 | # """ 300 | --------------------------------------------------------------------------------
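The trailing comments in the three test scripts note that the `Backward is not reentrant` report from `gradcheck` corresponds to an error below `1e-7` in float32. One way to check that this is only a precision artifact is to repeat the deformable-convolution gradient check in float64. The snippet below is a minimal sketch and not a file from this repository: the function name `check_gradient_dconv_double` is hypothetical, and it simply reuses the `dcn_v2_conv` argument order exercised in `test/testcuda.py` (`input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups`) with every tensor cast to double.

```python
#!/usr/bin/env python
# Minimal double-precision gradcheck sketch (hypothetical, not part of the repository).
# Assumes the same dcn_v2_conv signature used in test/testcuda.py.
import torch
from torch.autograd import gradcheck

from dcn_v2 import dcn_v2_conv

deformable_groups = 1
N, inC, inH, inW = 2, 2, 4, 4
outC, kH, kW = 2, 3, 3


def check_gradient_dconv_double():
    # float64 keeps the numerical error far below gradcheck's default tolerances,
    # which avoids the spurious non-reentrant report seen with float32 inputs
    input = (torch.rand(N, inC, inH, inW).double().cuda() * 0.01).requires_grad_()
    offset = (torch.randn(N, deformable_groups * 2 * kH * kW, inH, inW).double().cuda() * 2).requires_grad_()
    mask = torch.rand(N, deformable_groups * 1 * kH * kW, inH, inW).double().cuda().requires_grad_()
    weight = torch.randn(outC, inC, kH, kW).double().cuda().requires_grad_()
    bias = torch.rand(outC).double().cuda().requires_grad_()

    stride, padding, dilation = 1, 1, 1
    print(
        "check_gradient_dconv (double): ",
        gradcheck(
            dcn_v2_conv,
            (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups),
        ),
    )


if __name__ == "__main__":
    check_gradient_dconv_double()
```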