├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── dcn_v2.py
├── make.sh
├── setup.py
├── src
│ ├── cpu
│ │ ├── dcn_v2_cpu.cpp
│ │ ├── dcn_v2_im2col_cpu.cpp
│ │ ├── dcn_v2_im2col_cpu.h
│ │ ├── dcn_v2_psroi_pooling_cpu.cpp
│ │ └── vision.h
│ ├── cuda
│ │ ├── dcn_v2_cuda.cu
│ │ ├── dcn_v2_im2col_cuda.cu
│ │ ├── dcn_v2_im2col_cuda.h
│ │ ├── dcn_v2_psroi_pooling_cuda.cu
│ │ └── vision.h
│ ├── dcn_v2.h
│ └── vision.cpp
└── test
├── test.py
├── testcpu.py
└── testcuda.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .idea
3 | *.so
4 | *.o
5 | *pyc
6 | _ext
7 | build
8 | DCNv2.egg-info
9 | dist
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Charles Shang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deformable Convolutional Networks V2 with PyTorch 1.0
2 |
3 | ### Build
4 | ```bash
5 | ./make.sh # build
6 | python test/test.py # run examples and gradient check
7 | ```
8 |
9 | ### An Example
10 | - deformable conv
11 | ```python
12 | from dcn_v2 import DCN
13 | input = torch.randn(2, 64, 128, 128).cuda()
14 | # wrap all things (offset and mask) in DCN
15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
16 | output = dcn(input)
17 | print(output.shape)
18 | ```
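- deformable conv with caller-supplied offset and mask (a minimal sketch of the lower-level `DCNv2` module from `dcn_v2.py`; the offset/mask channel counts follow the assertions in its `forward`)
```python
from dcn_v2 import DCNv2
input = torch.randn(2, 64, 128, 128).cuda()
dcnv2 = DCNv2(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda()
# offset needs 2 * deformable_groups * kh * kw channels, mask needs deformable_groups * kh * kw channels,
# both at the output resolution (128 x 128 here)
offset = torch.randn(2, 2 * 2 * 3 * 3, 128, 128).cuda()
mask = torch.sigmoid(torch.randn(2, 2 * 3 * 3, 128, 128)).cuda()
output = dcnv2(input, offset, mask)
print(output.shape)
```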
19 | - deformable roi pooling
20 | ```python
21 | from dcn_v2 import DCNPooling
22 | input = torch.randn(2, 32, 64, 64).cuda()
23 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
24 | x = torch.randint(256, (20, 1)).cuda().float()
25 | y = torch.randint(256, (20, 1)).cuda().float()
26 | w = torch.randint(64, (20, 1)).cuda().float()
27 | h = torch.randint(64, (20, 1)).cuda().float()
28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
29 |
30 | # modulated deformable pooling (V2)
31 | # wrap all things (offset and mask) in DCNPooling
32 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
33 | pooled_size=7,
34 | output_dim=32,
35 | no_trans=False,
36 | group_size=1,
37 | trans_std=0.1).cuda()
38 |
39 | dout = dpooling(input, rois)
40 | ```
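The same module can also act as plain position-sensitive ROI pooling (no learned offsets) by setting `no_trans=True`; a minimal sketch, reusing `input` and `rois` from the snippet above:
```python
pooling = DCNPooling(spatial_scale=1.0 / 4,
                     pooled_size=7,
                     output_dim=32,
                     no_trans=True,
                     group_size=1,
                     trans_std=0.1).cuda()
out = pooling(input, rois)
print(out.shape)  # expected: (20, 32, 7, 7)
```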
41 | ### Note
42 | The master branch now targets PyTorch 1.0 (new ATen API). You can switch back to the PyTorch 0.4 version with:
43 | ```bash
44 | git checkout pytorch_0.4
45 | ```
46 |
47 | ### Known Issues:
48 |
49 | - [x] Gradient check w.r.t. offset (solved)
50 | - [ ] Backward is not reentrant (minor)
51 |
52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
53 |
54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Could it be due to some
56 | non-differentiable points?
57 |
58 | Update: all gradient checks pass with double precision.
59 |
60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
61 | float, `<1e-15` for double),
62 | so it may not be a serious problem.
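For reference, a double-precision gradient check can be run along these lines (a minimal sketch, assuming a CUDA build that supports double; small sizes keep the numerical Jacobian cheap):
```python
import torch
from dcn_v2 import dcn_v2_conv

N, C_in, C_out, H, W = 2, 4, 4, 8, 8
kernel, stride, padding, dilation, dg = 3, 1, 1, 1, 2

input = torch.rand(N, C_in, H, W, dtype=torch.double, device="cuda", requires_grad=True)
offset = torch.randn(N, 2 * dg * kernel * kernel, H, W, dtype=torch.double, device="cuda", requires_grad=True)
mask = torch.rand(N, dg * kernel * kernel, H, W, dtype=torch.double, device="cuda", requires_grad=True)
weight = torch.randn(C_out, C_in, kernel, kernel, dtype=torch.double, device="cuda", requires_grad=True)
bias = torch.rand(C_out, dtype=torch.double, device="cuda", requires_grad=True)

print(torch.autograd.gradcheck(
    dcn_v2_conv,
    (input, offset, mask, weight, bias, stride, padding, dilation, dg)))
```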
63 |
64 | Please post an issue or PR if you have any comments.
65 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ifzhang/DCNv2/ab4d98efc0aafeb27bb05b803820385401d9921b/__init__.py
--------------------------------------------------------------------------------
/dcn_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import math
5 |
6 | import torch
7 | from torch import nn
8 | from torch.autograd import Function
9 | from torch.autograd.function import once_differentiable
10 | from torch.nn.modules.utils import _pair
11 |
12 | import _ext as _backend
13 |
14 |
15 | class _DCNv2(Function):
16 | @staticmethod
17 | def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups):
18 | ctx.stride = _pair(stride)
19 | ctx.padding = _pair(padding)
20 | ctx.dilation = _pair(dilation)
21 | ctx.kernel_size = _pair(weight.shape[2:4])
22 | ctx.deformable_groups = deformable_groups
23 | output = _backend.dcn_v2_forward(
24 | input,
25 | weight,
26 | bias,
27 | offset,
28 | mask,
29 | ctx.kernel_size[0],
30 | ctx.kernel_size[1],
31 | ctx.stride[0],
32 | ctx.stride[1],
33 | ctx.padding[0],
34 | ctx.padding[1],
35 | ctx.dilation[0],
36 | ctx.dilation[1],
37 | ctx.deformable_groups,
38 | )
39 | ctx.save_for_backward(input, offset, mask, weight, bias)
40 | return output
41 |
42 | @staticmethod
43 | @once_differentiable
44 | def backward(ctx, grad_output):
45 | input, offset, mask, weight, bias = ctx.saved_tensors
46 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(
47 | input,
48 | weight,
49 | bias,
50 | offset,
51 | mask,
52 | grad_output,
53 | ctx.kernel_size[0],
54 | ctx.kernel_size[1],
55 | ctx.stride[0],
56 | ctx.stride[1],
57 | ctx.padding[0],
58 | ctx.padding[1],
59 | ctx.dilation[0],
60 | ctx.dilation[1],
61 | ctx.deformable_groups,
62 | )
63 |
64 | return (
65 | grad_input,
66 | grad_offset,
67 | grad_mask,
68 | grad_weight,
69 | grad_bias,
70 | None,
71 | None,
72 | None,
73 | None,
74 | )
75 |
76 |
77 | dcn_v2_conv = _DCNv2.apply
78 |
79 |
80 | class DCNv2(nn.Module):
81 | def __init__(
82 | self,
83 | in_channels,
84 | out_channels,
85 | kernel_size,
86 | stride,
87 | padding,
88 | dilation=1,
89 | deformable_groups=1,
90 | ):
91 | super(DCNv2, self).__init__()
92 | self.in_channels = in_channels
93 | self.out_channels = out_channels
94 | self.kernel_size = _pair(kernel_size)
95 | self.stride = _pair(stride)
96 | self.padding = _pair(padding)
97 | self.dilation = _pair(dilation)
98 | self.deformable_groups = deformable_groups
99 |
100 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
101 | self.bias = nn.Parameter(torch.Tensor(out_channels))
102 | self.reset_parameters()
103 |
104 | def reset_parameters(self):
105 | n = self.in_channels
106 | for k in self.kernel_size:
107 | n *= k
108 | stdv = 1.0 / math.sqrt(n)
109 | self.weight.data.uniform_(-stdv, stdv)
110 | self.bias.data.zero_()
111 |
112 | def forward(self, input, offset, mask):
113 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1]
114 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1]
115 | return dcn_v2_conv(
116 | input,
117 | offset,
118 | mask,
119 | self.weight,
120 | self.bias,
121 | self.stride,
122 | self.padding,
123 | self.dilation,
124 | self.deformable_groups,
125 | )
126 |
127 |
128 | class DCN(DCNv2):
129 | def __init__(
130 | self,
131 | in_channels,
132 | out_channels,
133 | kernel_size,
134 | stride,
135 | padding,
136 | dilation=1,
137 | deformable_groups=1,
138 | ):
139 | super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups)
140 |
141 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
142 | self.conv_offset_mask = nn.Conv2d(
143 | self.in_channels,
144 | channels_,
145 | kernel_size=self.kernel_size,
146 | stride=self.stride,
147 | padding=self.padding,
148 | bias=True,
149 | )
150 | self.init_offset()
151 |
152 | def init_offset(self):
153 | self.conv_offset_mask.weight.data.zero_()
154 | self.conv_offset_mask.bias.data.zero_()
155 |
156 | def forward(self, input):
157 | out = self.conv_offset_mask(input)
158 | o1, o2, mask = torch.chunk(out, 3, dim=1)
159 | offset = torch.cat((o1, o2), dim=1)
160 | mask = torch.sigmoid(mask)
161 | return dcn_v2_conv(
162 | input,
163 | offset,
164 | mask,
165 | self.weight,
166 | self.bias,
167 | self.stride,
168 | self.padding,
169 | self.dilation,
170 | self.deformable_groups,
171 | )
172 |
173 |
174 | class _DCNv2Pooling(Function):
175 | @staticmethod
176 | def forward(
177 | ctx,
178 | input,
179 | rois,
180 | offset,
181 | spatial_scale,
182 | pooled_size,
183 | output_dim,
184 | no_trans,
185 | group_size=1,
186 | part_size=None,
187 | sample_per_part=4,
188 | trans_std=0.0,
189 | ):
190 | ctx.spatial_scale = spatial_scale
191 | ctx.no_trans = int(no_trans)
192 | ctx.output_dim = output_dim
193 | ctx.group_size = group_size
194 | ctx.pooled_size = pooled_size
195 | ctx.part_size = pooled_size if part_size is None else part_size
196 | ctx.sample_per_part = sample_per_part
197 | ctx.trans_std = trans_std
198 |
199 | output, output_count = _backend.dcn_v2_psroi_pooling_forward(
200 | input,
201 | rois,
202 | offset,
203 | ctx.no_trans,
204 | ctx.spatial_scale,
205 | ctx.output_dim,
206 | ctx.group_size,
207 | ctx.pooled_size,
208 | ctx.part_size,
209 | ctx.sample_per_part,
210 | ctx.trans_std,
211 | )
212 | ctx.save_for_backward(input, rois, offset, output_count)
213 | return output
214 |
215 | @staticmethod
216 | @once_differentiable
217 | def backward(ctx, grad_output):
218 | input, rois, offset, output_count = ctx.saved_tensors
219 | grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(
220 | grad_output,
221 | input,
222 | rois,
223 | offset,
224 | output_count,
225 | ctx.no_trans,
226 | ctx.spatial_scale,
227 | ctx.output_dim,
228 | ctx.group_size,
229 | ctx.pooled_size,
230 | ctx.part_size,
231 | ctx.sample_per_part,
232 | ctx.trans_std,
233 | )
234 |
235 | return grad_input, None, grad_offset, None, None, None, None, None, None, None, None
236 |
237 |
238 | dcn_v2_pooling = _DCNv2Pooling.apply
239 |
240 |
241 | class DCNv2Pooling(nn.Module):
242 | def __init__(
243 | self,
244 | spatial_scale,
245 | pooled_size,
246 | output_dim,
247 | no_trans,
248 | group_size=1,
249 | part_size=None,
250 | sample_per_part=4,
251 | trans_std=0.0,
252 | ):
253 | super(DCNv2Pooling, self).__init__()
254 | self.spatial_scale = spatial_scale
255 | self.pooled_size = pooled_size
256 | self.output_dim = output_dim
257 | self.no_trans = no_trans
258 | self.group_size = group_size
259 | self.part_size = pooled_size if part_size is None else part_size
260 | self.sample_per_part = sample_per_part
261 | self.trans_std = trans_std
262 |
263 | def forward(self, input, rois, offset):
264 | assert input.shape[1] == self.output_dim
265 | if self.no_trans:
266 | offset = input.new()
267 | return dcn_v2_pooling(
268 | input,
269 | rois,
270 | offset,
271 | self.spatial_scale,
272 | self.pooled_size,
273 | self.output_dim,
274 | self.no_trans,
275 | self.group_size,
276 | self.part_size,
277 | self.sample_per_part,
278 | self.trans_std,
279 | )
280 |
281 |
282 | class DCNPooling(DCNv2Pooling):
283 | def __init__(
284 | self,
285 | spatial_scale,
286 | pooled_size,
287 | output_dim,
288 | no_trans,
289 | group_size=1,
290 | part_size=None,
291 | sample_per_part=4,
292 | trans_std=0.0,
293 | deform_fc_dim=1024,
294 | ):
295 | super(DCNPooling, self).__init__(
296 | spatial_scale,
297 | pooled_size,
298 | output_dim,
299 | no_trans,
300 | group_size,
301 | part_size,
302 | sample_per_part,
303 | trans_std,
304 | )
305 |
306 | self.deform_fc_dim = deform_fc_dim
307 |
308 | if not no_trans:
309 | self.offset_mask_fc = nn.Sequential(
310 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim),
311 | nn.ReLU(inplace=True),
312 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
313 | nn.ReLU(inplace=True),
314 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3),
315 | )
316 | self.offset_mask_fc[4].weight.data.zero_()
317 | self.offset_mask_fc[4].bias.data.zero_()
318 |
319 | def forward(self, input, rois):
320 | offset = input.new()
321 |
322 | if not self.no_trans:
323 |
324 | # do roi_align first
325 | n = rois.shape[0]
326 | roi = dcn_v2_pooling(
327 | input,
328 | rois,
329 | offset,
330 | self.spatial_scale,
331 | self.pooled_size,
332 | self.output_dim,
333 | True, # no trans
334 | self.group_size,
335 | self.part_size,
336 | self.sample_per_part,
337 | self.trans_std,
338 | )
339 |
340 | # build mask and offset
341 | offset_mask = self.offset_mask_fc(roi.view(n, -1))
342 | offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)
343 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
344 | offset = torch.cat((o1, o2), dim=1)
345 | mask = torch.sigmoid(mask)
346 |
347 | # do pooling with offset and mask
348 | return (
349 | dcn_v2_pooling(
350 | input,
351 | rois,
352 | offset,
353 | self.spatial_scale,
354 | self.pooled_size,
355 | self.output_dim,
356 | self.no_trans,
357 | self.group_size,
358 | self.part_size,
359 | self.sample_per_part,
360 | self.trans_std,
361 | )
362 | * mask
363 | )
364 | # only roi_align
365 | return dcn_v2_pooling(
366 | input,
367 | rois,
368 | offset,
369 | self.spatial_scale,
370 | self.pooled_size,
371 | self.output_dim,
372 | self.no_trans,
373 | self.group_size,
374 | self.part_size,
375 | self.sample_per_part,
376 | self.trans_std,
377 | )
378 |
--------------------------------------------------------------------------------
/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python setup.py build develop
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import glob
4 | import os
5 | import sys
6 |
7 | import torch
8 | from setuptools import find_packages, setup
9 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
10 |
11 | requirements = ["torch", "torchvision"]
12 |
13 |
14 | def get_extensions():
15 | this_dir = os.path.dirname(os.path.abspath(__file__))
16 | extensions_dir = os.path.join(this_dir, "src")
17 |
18 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
19 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
20 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
21 |
22 | os.environ["CC"] = "g++"
23 | sources = main_file + source_cpu
24 | extension = CppExtension
25 | extra_compile_args = {"cxx": []}
26 | define_macros = []
27 |
28 | if torch.cuda.is_available() and CUDA_HOME is not None:
29 | extension = CUDAExtension
30 | sources += source_cuda
31 | define_macros += [("WITH_CUDA", None)]
32 | extra_compile_args["nvcc"] = [
33 | "-DCUDA_HAS_FP16=1",
34 | "-D__CUDA_NO_HALF_OPERATORS__",
35 | "-D__CUDA_NO_HALF_CONVERSIONS__",
36 | "-D__CUDA_NO_HALF2_OPERATORS__",
37 | ]
38 | else:
39 | # raise NotImplementedError('Cuda is not available')
40 | pass
41 |
42 | extra_compile_args['cxx'].append('-fopenmp')
43 |
44 | sources = [os.path.join(extensions_dir, s) for s in sources]
45 | include_dirs = [extensions_dir]
46 | ext_modules = [
47 | extension(
48 | "_ext",
49 | sources,
50 | include_dirs=include_dirs,
51 | define_macros=define_macros,
52 | extra_compile_args=extra_compile_args,
53 | )
54 | ]
55 | return ext_modules
56 |
57 |
58 | setup(
59 | name="DCNv2",
60 | version="0.1",
61 | author="charlesshang",
62 | url="https://github.com/charlesshang/DCNv2",
63 | description="deformable convolutional networks",
64 | packages=find_packages(
65 | exclude=(
66 | "configs",
67 | "tests",
68 | )
69 | ),
70 | # install_requires=requirements,
71 | ext_modules=get_extensions(),
72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
73 | )
74 |
--------------------------------------------------------------------------------
/src/cpu/dcn_v2_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <vector>
2 | #include "cpu/dcn_v2_im2col_cpu.h"
3 | #include <iostream>
4 |
5 | #include <ATen/ATen.h>
6 | //#include <ATen/cuda/CUDAContext.h>
7 |
8 | #include <TH/TH.h>
9 | //#include <THC/THCAtomics.cuh>
10 | //#include <THC/THCDeviceUtils.cuh>
11 |
12 | //extern THCState *state;
13 |
14 | // author: Charles Shang
15 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
16 |
17 | // modified from the CUDA version for CPU use by Daniel K. Suhendro
18 |
19 | // edit by: James Bockman and Matthew Howe
20 | // modified for torch implementation to remove use of deprecated torch access to Blas
21 |
22 | at::Tensor
23 | dcn_v2_cpu_forward(const at::Tensor &input,
24 | const at::Tensor &weight,
25 | const at::Tensor &bias,
26 | const at::Tensor &offset,
27 | const at::Tensor &mask,
28 | const int kernel_h,
29 | const int kernel_w,
30 | const int stride_h,
31 | const int stride_w,
32 | const int pad_h,
33 | const int pad_w,
34 | const int dilation_h,
35 | const int dilation_w,
36 | const int deformable_group)
37 | {
38 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
39 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
40 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
41 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
42 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
43 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/
44 |
45 | const int batch = input.size(0);
46 | const int channels = input.size(1);
47 | const int height = input.size(2);
48 | const int width = input.size(3);
49 |
50 | const int channels_out = weight.size(0);
51 | const int channels_kernel = weight.size(1);
52 | const int kernel_h_ = weight.size(2);
53 | const int kernel_w_ = weight.size(3);
54 |
55 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
56 | // printf("Channels: %d %d\n", channels, channels_kernel);
57 | // printf("Channels: %d %d\n", channels_out, channels_kernel);
58 |
59 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
60 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
61 |
62 | AT_ASSERTM(channels == channels_kernel,
63 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
64 |
65 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
66 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
67 |
68 | // auto ones = at::ones({height_out, width_out}, input.options());
69 | auto ones = at::ones({bias.sizes()[0], height_out, width_out}, input.options());
70 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
71 | auto output = at::zeros({batch, channels_out, height_out, width_out}, input.options());
72 |
73 | using scalar_t = float;
74 | for (int b = 0; b < batch; b++)
75 | {
76 | auto input_n = input.select(0, b);
77 | auto offset_n = offset.select(0, b);
78 | auto mask_n = mask.select(0, b);
79 | auto output_n = output.select(0, b);
80 | // std::cout << "output_n: " << output_n << "output.select(0,b): " << output.select(0,b) << "\n";
81 |
82 | // Do Bias first:
83 | // M,N,K are dims of matrix A and B
84 | // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
85 | // (N x 1) (1 x M)
86 |
87 | // torch implementation
88 | auto ones_T = at::transpose(ones.contiguous(), 2, 0);
89 | ones_T = at::mul(ones_T, bias.contiguous());
90 | ones_T = at::transpose(ones_T, 2, 0);
91 | output_n = at::add(output_n, ones_T);
92 |
93 | modulated_deformable_im2col_cpu(input_n.data_ptr<float>(),
94 | offset_n.data_ptr<float>(),
95 | mask_n.data_ptr<float>(),
96 | 1, channels, height, width,
97 | height_out, width_out, kernel_h, kernel_w,
98 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
99 | deformable_group,
100 | columns.data_ptr<float>());
101 |
102 | //(k * m) x (m * n)
103 | // Y = WC
104 |
105 | // torch implementation
106 | auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w});
107 | auto product = at::matmul(weight_flat, columns);
108 | output.select(0, b) = at::add(output_n, product.view({channels_out, height_out, width_out}));
109 | }
110 | return output;
111 | }
112 |
113 | std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
114 | const at::Tensor &weight,
115 | const at::Tensor &bias,
116 | const at::Tensor &offset,
117 | const at::Tensor &mask,
118 | const at::Tensor &grad_output,
119 | int kernel_h, int kernel_w,
120 | int stride_h, int stride_w,
121 | int pad_h, int pad_w,
122 | int dilation_h, int dilation_w,
123 | int deformable_group)
124 | {
125 |
126 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
127 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
128 |
129 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
130 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
131 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
132 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
133 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/
134 |
135 | const int batch = input.size(0);
136 | const int channels = input.size(1);
137 | const int height = input.size(2);
138 | const int width = input.size(3);
139 |
140 | const int channels_out = weight.size(0);
141 | const int channels_kernel = weight.size(1);
142 | const int kernel_h_ = weight.size(2);
143 | const int kernel_w_ = weight.size(3);
144 |
145 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
146 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
147 |
148 | AT_ASSERTM(channels == channels_kernel,
149 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
150 |
151 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
152 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
153 |
154 | auto ones = at::ones({height_out, width_out}, input.options());
155 | auto columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
156 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
157 |
158 | auto grad_input = at::zeros_like(input);
159 | auto grad_weight = at::zeros_like(weight);
160 | auto grad_bias = at::zeros_like(bias);
161 | auto grad_offset = at::zeros_like(offset);
162 | auto grad_mask = at::zeros_like(mask);
163 |
164 | using scalar_t = float;
165 |
166 | for (int b = 0; b < batch; b++)
167 | {
168 | auto input_n = input.select(0, b);
169 | auto offset_n = offset.select(0, b);
170 | auto mask_n = mask.select(0, b);
171 | auto grad_output_n = grad_output.select(0, b);
172 | auto grad_input_n = grad_input.select(0, b);
173 | auto grad_offset_n = grad_offset.select(0, b);
174 | auto grad_mask_n = grad_mask.select(0, b);
175 |
176 |
177 |
178 | // Torch implementation
179 | auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w});
180 | weight_flat = at::transpose(weight_flat, 1, 0);
181 | auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out});
182 | columns = at::matmul(weight_flat, grad_output_n_flat);
183 |
184 | // gradient w.r.t. input coordinate data
185 | modulated_deformable_col2im_coord_cpu(columns.data_ptr<float>(),
186 | input_n.data_ptr<float>(),
187 | offset_n.data_ptr<float>(),
188 | mask_n.data_ptr<float>(),
189 | 1, channels, height, width,
190 | height_out, width_out, kernel_h, kernel_w,
191 | pad_h, pad_w, stride_h, stride_w,
192 | dilation_h, dilation_w, deformable_group,
193 | grad_offset_n.data_ptr<float>(),
194 | grad_mask_n.data_ptr<float>());
195 | // gradient w.r.t. input data
196 | modulated_deformable_col2im_cpu(columns.data_ptr<float>(),
197 | offset_n.data_ptr<float>(),
198 | mask_n.data_ptr<float>(),
199 | 1, channels, height, width,
200 | height_out, width_out, kernel_h, kernel_w,
201 | pad_h, pad_w, stride_h, stride_w,
202 | dilation_h, dilation_w, deformable_group,
203 | grad_input_n.data_ptr<float>());
204 |
205 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group
206 | modulated_deformable_im2col_cpu(input_n.data_ptr<float>(),
207 | offset_n.data_ptr<float>(),
208 | mask_n.data_ptr<float>(),
209 | 1, channels, height, width,
210 | height_out, width_out, kernel_h, kernel_w,
211 | pad_h, pad_w, stride_h, stride_w,
212 | dilation_h, dilation_w, deformable_group,
213 | columns.data_ptr<float>());
214 |
215 | // Torch implementation
216 | auto product = at::matmul(grad_output_n_flat, at::transpose(columns, 1, 0));
217 | grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w}));
218 |
219 |
220 | // Torch implementation
221 | auto ones_flat = ones.view({height_out*width_out});
222 | product = at::matmul(grad_output_n_flat, ones_flat);
223 | grad_bias = at::add(grad_bias, product);
224 | }
225 |
226 | return {
227 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias
228 | };
229 | }
230 |
--------------------------------------------------------------------------------
/src/cpu/dcn_v2_im2col_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include "dcn_v2_im2col_cpu.h"
2 | #include <cstdio>
3 | #include <algorithm>
4 | #include <cstring>
5 |
6 | #include <ATen/ATen.h>
7 | //#include <ATen/cuda/CUDAContext.h>
8 |
9 | #include <TH/TH.h>
10 | //#include <THC/THCAtomics.cuh>
11 | //#include <THC/THCDeviceUtils.cuh>
12 |
13 | // modified from the CUDA version for CPU use by Daniel K. Suhendro
14 |
15 | /*#define CUDA_KERNEL_LOOP(i, n) \
16 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
17 | i < (n); \
18 | i += blockDim.x * gridDim.x)
19 |
20 | const int CUDA_NUM_THREADS = 1024;
21 | inline int GET_BLOCKS(const int N)
22 | {
23 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
24 | }*/
25 |
26 |
27 | float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
28 | const int height, const int width, float h, float w)
29 | {
30 | int h_low = floor(h);
31 | int w_low = floor(w);
32 | int h_high = h_low + 1;
33 | int w_high = w_low + 1;
34 |
35 | float lh = h - h_low;
36 | float lw = w - w_low;
37 | float hh = 1 - lh, hw = 1 - lw;
38 |
39 | float v1 = 0;
40 | if (h_low >= 0 && w_low >= 0)
41 | v1 = bottom_data[h_low * data_width + w_low];
42 | float v2 = 0;
43 | if (h_low >= 0 && w_high <= width - 1)
44 | v2 = bottom_data[h_low * data_width + w_high];
45 | float v3 = 0;
46 | if (h_high <= height - 1 && w_low >= 0)
47 | v3 = bottom_data[h_high * data_width + w_low];
48 | float v4 = 0;
49 | if (h_high <= height - 1 && w_high <= width - 1)
50 | v4 = bottom_data[h_high * data_width + w_high];
51 |
52 | float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
53 |
54 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
55 | return val;
56 | }
57 |
58 | float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
59 | const int h, const int w, const int height, const int width)
60 | {
61 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
62 | {
63 | //empty
64 | return 0;
65 | }
66 |
67 | int argmax_h_low = floor(argmax_h);
68 | int argmax_w_low = floor(argmax_w);
69 | int argmax_h_high = argmax_h_low + 1;
70 | int argmax_w_high = argmax_w_low + 1;
71 |
72 | float weight = 0;
73 | if (h == argmax_h_low && w == argmax_w_low)
74 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
75 | if (h == argmax_h_low && w == argmax_w_high)
76 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
77 | if (h == argmax_h_high && w == argmax_w_low)
78 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
79 | if (h == argmax_h_high && w == argmax_w_high)
80 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
81 | return weight;
82 | }
83 |
84 | float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
85 | const int height, const int width, const float *im_data,
86 | const int data_width, const int bp_dir)
87 | {
88 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
89 | {
90 | //empty
91 | return 0;
92 | }
93 |
94 | int argmax_h_low = floor(argmax_h);
95 | int argmax_w_low = floor(argmax_w);
96 | int argmax_h_high = argmax_h_low + 1;
97 | int argmax_w_high = argmax_w_low + 1;
98 |
99 | float weight = 0;
100 |
101 | if (bp_dir == 0)
102 | {
103 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
104 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
105 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
106 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
107 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
108 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
109 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
110 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
111 | }
112 | else if (bp_dir == 1)
113 | {
114 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
115 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
116 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
117 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
118 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
119 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
120 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
121 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
122 | }
123 |
124 | return weight;
125 | }
126 |
127 | void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
128 | const int height, const int width, const int kernel_h, const int kernel_w,
129 | const int pad_h, const int pad_w,
130 | const int stride_h, const int stride_w,
131 | const int dilation_h, const int dilation_w,
132 | const int channel_per_deformable_group,
133 | const int batch_size, const int num_channels, const int deformable_group,
134 | const int height_col, const int width_col,
135 | float *data_col)
136 | {
137 | // launch channels * batch_size * height_col * width_col cores
138 | for(int index=0; index<n; index++)
139 | {
140 | // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
141 | // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
142 |
143 | // index index of output matrix
144 | const int w_col = index % width_col;
145 | const int h_col = (index / width_col) % height_col;
146 | // const int b_col = (index / width_col / height_col) % batch_size;
147 | const int b_col = (index / width_col / height_col / num_channels) % batch_size;
148 | // const int c_im = (index / width_col / height_col) / batch_size;
149 | const int c_im = (index / width_col / height_col) % num_channels;
150 | // const int c_col = c_im * kernel_h * kernel_w;
151 | const int c_col = c_im * kernel_h * kernel_w;
152 |
153 | // compute deformable group index
154 | const int deformable_group_index = c_im / channel_per_deformable_group;
155 |
156 | const int h_in = h_col * stride_h - pad_h;
157 | const int w_in = w_col * stride_w - pad_w;
158 |
159 | // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
160 | float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
161 | //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
162 | const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
163 | const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
164 |
165 | const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
166 |
167 | for (int i = 0; i < kernel_h; ++i)
168 | {
169 | for (int j = 0; j < kernel_w; ++j)
170 | {
171 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
172 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
173 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
174 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
175 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
176 | const float mask = data_mask_ptr[data_mask_hw_ptr];
177 | float val = static_cast<float>(0);
178 | const float h_im = h_in + i * dilation_h + offset_h;
179 | const float w_im = w_in + j * dilation_w + offset_w;
180 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
181 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
182 | {
183 | //const float map_h = i * dilation_h + offset_h;
184 | //const float map_w = j * dilation_w + offset_w;
185 | //const int cur_height = height - h_in;
186 | //const int cur_width = width - w_in;
187 | //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
188 | val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
189 | }
190 | *data_col_ptr = val * mask;
191 | // data_col_ptr += batch_size * height_col * width_col;
192 | data_col_ptr += height_col * width_col;
193 | }
194 | }
195 | }
196 | }
197 |
198 | void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
199 | const int channels, const int height, const int width,
200 | const int kernel_h, const int kernel_w,
201 | const int pad_h, const int pad_w,
202 | const int stride_h, const int stride_w,
203 | const int dilation_h, const int dilation_w,
204 | const int channel_per_deformable_group,
205 | const int batch_size, const int deformable_group,
206 | const int height_col, const int width_col,
207 | float *grad_im)
208 | {
209 | for(int index = 0; index < n; index++)
210 | {
211 | const int j = (index / width_col / height_col / batch_size) % kernel_w;
212 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
213 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
214 | // compute the start and end of the output
215 |
216 | const int deformable_group_index = c / channel_per_deformable_group;
217 |
218 | int w_out = index % width_col;
219 | int h_out = (index / width_col) % height_col;
220 | int b = (index / width_col / height_col) % batch_size;
221 | int w_in = w_out * stride_w - pad_w;
222 | int h_in = h_out * stride_h - pad_h;
223 |
224 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
225 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
226 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
227 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
228 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
229 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
230 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
231 | const float mask = data_mask_ptr[data_mask_hw_ptr];
232 | const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
233 | const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
234 |
235 | const float cur_top_grad = data_col[index] * mask;
236 | const int cur_h = (int)cur_inv_h_data;
237 | const int cur_w = (int)cur_inv_w_data;
238 |
239 | for (int dy = -2; dy <= 2; dy++)
240 | {
241 | for (int dx = -2; dx <= 2; dx++)
242 | {
243 | if (cur_h + dy >= 0 && cur_h + dy < height &&
244 | cur_w + dx >= 0 && cur_w + dx < width &&
245 | abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
246 | abs(cur_inv_w_data - (cur_w + dx)) < 1)
247 | {
248 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
249 | float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
250 | //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
251 | *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
252 |
253 | }
254 | }
255 | }
256 | }
257 | }
258 |
259 | void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,
260 | const float *data_offset, const float *data_mask,
261 | const int channels, const int height, const int width,
262 | const int kernel_h, const int kernel_w,
263 | const int pad_h, const int pad_w,
264 | const int stride_h, const int stride_w,
265 | const int dilation_h, const int dilation_w,
266 | const int channel_per_deformable_group,
267 | const int batch_size, const int offset_channels, const int deformable_group,
268 | const int height_col, const int width_col,
269 | float *grad_offset, float *grad_mask)
270 | {
271 | for(int index = 0; index < n; index++)
272 | {
273 | float val = 0, mval = 0;
274 | int w = index % width_col;
275 | int h = (index / width_col) % height_col;
276 | int c = (index / width_col / height_col) % offset_channels;
277 | int b = (index / width_col / height_col) / offset_channels;
278 | // compute the start and end of the output
279 |
280 | const int deformable_group_index = c / (2 * kernel_h * kernel_w);
281 | const int col_step = kernel_h * kernel_w;
282 | int cnt = 0;
283 | const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
284 | const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
285 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
286 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
287 |
288 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
289 |
290 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
291 | {
292 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
293 | const int bp_dir = offset_c % 2;
294 |
295 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
296 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
297 | int w_out = col_pos % width_col;
298 | int h_out = (col_pos / width_col) % height_col;
299 | int w_in = w_out * stride_w - pad_w;
300 | int h_in = h_out * stride_h - pad_h;
301 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
302 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
303 | const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
304 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
305 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
306 | const float mask = data_mask_ptr[data_mask_hw_ptr];
307 | float inv_h = h_in + i * dilation_h + offset_h;
308 | float inv_w = w_in + j * dilation_w + offset_w;
309 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
310 | {
311 | inv_h = inv_w = -2;
312 | }
313 | else
314 | {
315 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
316 | }
317 | const float weight = dmcn_get_coordinate_weight_cpu(
318 | inv_h, inv_w,
319 | height, width, data_im_ptr + cnt * height * width, width, bp_dir);
320 | val += weight * data_col_ptr[col_pos] * mask;
321 | cnt += 1;
322 | }
323 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
324 | grad_offset[index] = val;
325 | if (offset_c % 2 == 0)
326 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
327 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
328 | }
329 | }
330 |
331 | void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,
332 | const int batch_size, const int channels, const int height_im, const int width_im,
333 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
334 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
335 | const int dilation_h, const int dilation_w,
336 | const int deformable_group, float* data_col) {
337 | // num_axes should be smaller than block size
338 | const int channel_per_deformable_group = channels / deformable_group;
339 | const int num_kernels = channels * batch_size * height_col * width_col;
340 | modulated_deformable_im2col_cpu_kernel(
341 | num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
342 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
343 | batch_size, channels, deformable_group, height_col, width_col, data_col);
344 |
345 | /*cudaError_t err = cudaGetLastError();
346 | if (err != cudaSuccess)
347 | {
348 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
349 | }*/
350 |
351 | }
352 |
353 | void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,
354 | const int batch_size, const int channels, const int height_im, const int width_im,
355 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
356 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
357 | const int dilation_h, const int dilation_w,
358 | const int deformable_group, float* grad_im){
359 |
360 | const int channel_per_deformable_group = channels / deformable_group;
361 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
362 | modulated_deformable_col2im_cpu_kernel(
363 | num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
364 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
365 | dilation_h, dilation_w, channel_per_deformable_group,
366 | batch_size, deformable_group, height_col, width_col, grad_im);
367 | /*cudaError_t err = cudaGetLastError();
368 | if (err != cudaSuccess)
369 | {
370 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
371 | }*/
372 |
373 | }
374 |
375 | void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
376 | const int batch_size, const int channels, const int height_im, const int width_im,
377 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
378 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
379 | const int dilation_h, const int dilation_w,
380 | const int deformable_group,
381 | float* grad_offset, float* grad_mask) {
382 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
383 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
384 | modulated_deformable_col2im_coord_cpu_kernel(
385 | num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
386 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
387 | dilation_h, dilation_w, channel_per_deformable_group,
388 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
389 | grad_offset, grad_mask);
390 | /*cudaError_t err = cudaGetLastError();
391 | if (err != cudaSuccess)
392 | {
393 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
394 | }*/
395 | }
--------------------------------------------------------------------------------
/src/cpu/dcn_v2_im2col_cpu.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 | // modified from the CUDA version for CPU use by Daniel K. Suhendro
64 |
65 | #ifndef DCN_V2_IM2COL_CPU
66 | #define DCN_V2_IM2COL_CPU
67 |
68 | #ifdef __cplusplus
69 | extern "C"
70 | {
71 | #endif
72 |
73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,
81 | const int batch_size, const int channels, const int height_im, const int width_im,
82 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
84 | const int dilation_h, const int dilation_w,
85 | const int deformable_group, float *grad_im);
86 |
87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
88 | const int batch_size, const int channels, const int height_im, const int width_im,
89 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
91 | const int dilation_h, const int dilation_w,
92 | const int deformable_group,
93 | float *grad_offset, float *grad_mask);
94 |
95 | #ifdef __cplusplus
96 | }
97 | #endif
98 |
99 | #endif
--------------------------------------------------------------------------------
/src/cpu/dcn_v2_psroi_pooling_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 | // modified from the CUDA version for CPU use by Daniel K. Suhendro
10 |
11 | #include <vector>
12 | #include <stdio.h>
13 | #include <math.h>
14 |
15 | #include <ATen/ATen.h>
16 | //#include <ATen/cuda/CUDAContext.h>
17 |
18 | #include <TH/TH.h>
19 | //#include <THC/THCAtomics.cuh>
20 | //#include <THC/THCDeviceUtils.cuh>
21 |
22 | /*#define CUDA_KERNEL_LOOP(i, n) \
23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
24 | i < (n); \
25 | i += blockDim.x * gridDim.x)
26 |
27 | const int CUDA_NUM_THREADS = 1024;
28 | inline int GET_BLOCKS(const int N)
29 | {
30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
31 | }*/
32 |
33 | template <typename T>
34 | T bilinear_interp_cpu(
35 | const T *data,
36 | const T x,
37 | const T y,
38 | const int width,
39 | const int height)
40 | {
41 | int x1 = floor(x);
42 | int x2 = ceil(x);
43 | int y1 = floor(y);
44 | int y2 = ceil(y);
45 | T dist_x = static_cast<T>(x - x1);
46 | T dist_y = static_cast<T>(y - y1);
47 | T value11 = data[y1 * width + x1];
48 | T value12 = data[y2 * width + x1];
49 | T value21 = data[y1 * width + x2];
50 | T value22 = data[y2 * width + x2];
51 | T value = (1 - dist_x) * (1 - dist_y) * value11 +
52 | (1 - dist_x) * dist_y * value12 +
53 | dist_x * (1 - dist_y) * value21 +
54 | dist_x * dist_y * value22;
55 | return value;
56 | }
57 |
58 | template <typename T>
59 | void DeformablePSROIPoolForwardKernelCpu(
60 | const int count,
61 | const T *bottom_data,
62 | const T spatial_scale,
63 | const int channels,
64 | const int height, const int width,
65 | const int pooled_height, const int pooled_width,
66 | const T *bottom_rois, const T *bottom_trans,
67 | const int no_trans,
68 | const T trans_std,
69 | const int sample_per_part,
70 | const int output_dim,
71 | const int group_size,
72 | const int part_size,
73 | const int num_classes,
74 | const int channels_each_class,
75 | T *top_data,
76 | T *top_count)
77 | {
78 | for(int index = 0; index < count; index++)
79 | {
80 | // The output is in order (n, ctop, ph, pw)
81 | int pw = index % pooled_width;
82 | int ph = (index / pooled_width) % pooled_height;
83 | int ctop = (index / pooled_width / pooled_height) % output_dim;
84 | int n = index / pooled_width / pooled_height / output_dim;
85 |
86 | // [start, end) interval for spatial sampling
87 | const T *offset_bottom_rois = bottom_rois + n * 5;
88 | int roi_batch_ind = offset_bottom_rois[0];
89 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
90 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
91 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
92 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
93 |
94 | // Force too small ROIs to be 1x1
95 | T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0
96 | T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));
97 |
98 | // Compute w and h at bottom
99 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
100 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
101 |
102 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
103 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
104 |
105 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
106 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
107 | int class_id = ctop / channels_each_class;
108 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
109 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
110 |
111 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
112 | wstart += trans_x * roi_width;
113 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
114 | hstart += trans_y * roi_height;
115 |
116 | T sum = 0;
117 | int count = 0;
118 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
119 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
120 | gw = std::min(std::max(gw, 0), group_size - 1);
121 | gh = std::min(std::max(gh, 0), group_size - 1);
122 |
123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
124 | for (int ih = 0; ih < sample_per_part; ih++)
125 | {
126 | for (int iw = 0; iw < sample_per_part; iw++)
127 | {
128 | T w = wstart + iw * sub_bin_size_w;
129 | T h = hstart + ih * sub_bin_size_h;
130 | // bilinear interpolation
131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
132 | {
133 | continue;
134 | }
135 | w = std::min(std::max(w, T(0.)), width - T(1.));
136 | h = std::min(std::max(h, T(0.)), height - T(1.));
137 | int c = (ctop * group_size + gh) * group_size + gw;
138 | T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height);
139 | sum += val;
140 | count++;
141 | }
142 | }
143 | top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
144 | top_count[index] = count;
145 | }
146 | }
147 |
148 | template <typename T>
149 | void DeformablePSROIPoolBackwardAccKernelCpu(
150 | const int count,
151 | const T *top_diff,
152 | const T *top_count,
153 | const int num_rois,
154 | const T spatial_scale,
155 | const int channels,
156 | const int height, const int width,
157 | const int pooled_height, const int pooled_width,
158 | const int output_dim,
159 | T *bottom_data_diff, T *bottom_trans_diff,
160 | const T *bottom_data,
161 | const T *bottom_rois,
162 | const T *bottom_trans,
163 | const int no_trans,
164 | const T trans_std,
165 | const int sample_per_part,
166 | const int group_size,
167 | const int part_size,
168 | const int num_classes,
169 | const int channels_each_class)
170 | {
171 | for(int index = 0; index < count; index++)
172 | {
173 | // The output is in order (n, ctop, ph, pw)
174 | int pw = index % pooled_width;
175 | int ph = (index / pooled_width) % pooled_height;
176 | int ctop = (index / pooled_width / pooled_height) % output_dim;
177 | int n = index / pooled_width / pooled_height / output_dim;
178 |
179 | // [start, end) interval for spatial sampling
180 | const T *offset_bottom_rois = bottom_rois + n * 5;
181 | int roi_batch_ind = offset_bottom_rois[0];
182 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
183 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
184 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
185 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
186 |
187 | // Force too small ROIs to be 1x1
188 | T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0
189 | T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));
190 |
191 | // Compute w and h at bottom
192 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
193 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
194 |
195 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
196 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
197 |
198 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
199 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
200 | int class_id = ctop / channels_each_class;
201 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
202 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
203 |
204 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
205 | wstart += trans_x * roi_width;
206 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
207 | hstart += trans_y * roi_height;
208 |
209 | if (top_count[index] <= 0)
210 | {
211 | continue;
212 | }
213 | T diff_val = top_diff[index] / top_count[index];
214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
216 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
217 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
218 | gw = std::min(std::max(gw, 0), group_size - 1);
219 | gh = std::min(std::max(gh, 0), group_size - 1);
220 |
221 | for (int ih = 0; ih < sample_per_part; ih++)
222 | {
223 | for (int iw = 0; iw < sample_per_part; iw++)
224 | {
225 | T w = wstart + iw * sub_bin_size_w;
226 | T h = hstart + ih * sub_bin_size_h;
227 | // bilinear interpolation
228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
229 | {
230 | continue;
231 | }
232 | w = std::min(std::max(w, T(0.)), width - T(1.));
233 | h = std::min(std::max(h, T(0.)), height - T(1.));
234 | int c = (ctop * group_size + gh) * group_size + gw;
235 | // backward on feature
236 | int x0 = floor(w);
237 | int x1 = ceil(w);
238 | int y0 = floor(h);
239 | int y1 = ceil(h);
240 | T dist_x = w - x0, dist_y = h - y0;
241 | T q00 = (1 - dist_x) * (1 - dist_y);
242 | T q01 = (1 - dist_x) * dist_y;
243 | T q10 = dist_x * (1 - dist_y);
244 | T q11 = dist_x * dist_y;
245 | int bottom_index_base = c * height * width;
246 | /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/
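    | // The CPU kernel iterates over output elements serially, so plain += accumulation replaces the atomicAdd used in the CUDA version above.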
250 | *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;
251 | *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;
252 | *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;
253 | *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;
254 |
255 |
256 | if (no_trans)
257 | {
258 | continue;
259 | }
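    | // Gradient w.r.t. the predicted part offsets (trans): derivative of the bilinear sample w.r.t. x and y, scaled by trans_std and the ROI size below.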
260 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
261 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
262 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
263 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
264 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
265 | diff_x *= roi_width;
266 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
267 | diff_y *= roi_height;
268 |
269 | /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
270 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/
271 | *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;
272 | *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;
273 | }
274 | }
275 | }
276 | }
277 |
278 | std::tuple<at::Tensor, at::Tensor>
279 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
280 | const at::Tensor &bbox,
281 | const at::Tensor &trans,
282 | const int no_trans,
283 | const float spatial_scale,
284 | const int output_dim,
285 | const int group_size,
286 | const int pooled_size,
287 | const int part_size,
288 | const int sample_per_part,
289 | const float trans_std)
290 | {
291 | /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
292 | AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor");
293 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");*/
294 |
295 | // const int batch = input.size(0);
296 | const int channels = input.size(1);
297 | const int height = input.size(2);
298 | const int width = input.size(3);
299 | const int channels_trans = no_trans ? 2 : trans.size(1);
300 | const int num_bbox = bbox.size(0);
301 |
302 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
303 | auto pooled_height = pooled_size;
304 | auto pooled_width = pooled_size;
305 |
306 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
307 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
308 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
309 |
310 | const int num_classes = no_trans ? 1 : channels_trans / 2;
311 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
312 |
313 | //cudaStream_t stream = at::cuda::getCurrentCUDAStream();
314 |
315 | if (out.numel() == 0)
316 | {
317 | //THCudaCheck(cudaGetLastError());
318 | return std::make_tuple(out, top_count);
319 | }
320 |
321 | /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
322 | dim3 block(512);*/
323 |
324 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
325 | DeformablePSROIPoolForwardKernelCpu<scalar_t>(
326 | out_size,
327 | input.contiguous().data_ptr<scalar_t>(),
328 | spatial_scale,
329 | channels,
330 | height, width,
331 | pooled_height,
332 | pooled_width,
333 | bbox.contiguous().data_ptr<scalar_t>(),
334 | trans.contiguous().data_ptr<scalar_t>(),
335 | no_trans,
336 | trans_std,
337 | sample_per_part,
338 | output_dim,
339 | group_size,
340 | part_size,
341 | num_classes,
342 | channels_each_class,
343 | out.data_ptr<scalar_t>(),
344 | top_count.data_ptr<scalar_t>());
345 | });
346 | //THCudaCheck(cudaGetLastError());
347 | return std::make_tuple(out, top_count);
348 | }
349 |
350 | std::tuple<at::Tensor, at::Tensor>
351 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
352 | const at::Tensor &input,
353 | const at::Tensor &bbox,
354 | const at::Tensor &trans,
355 | const at::Tensor &top_count,
356 | const int no_trans,
357 | const float spatial_scale,
358 | const int output_dim,
359 | const int group_size,
360 | const int pooled_size,
361 | const int part_size,
362 | const int sample_per_part,
363 | const float trans_std)
364 | {
365 | /*AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor");
366 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
367 | AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor");
368 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
369 | AT_ASSERTM(top_count.is_cuda(), "top_count must be a CUDA tensor");*/
370 |
371 | const int batch = input.size(0);
372 | const int channels = input.size(1);
373 | const int height = input.size(2);
374 | const int width = input.size(3);
375 | const int channels_trans = no_trans ? 2 : trans.size(1);
376 | const int num_bbox = bbox.size(0);
377 |
378 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
379 | auto pooled_height = pooled_size;
380 | auto pooled_width = pooled_size;
381 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
382 | const int num_classes = no_trans ? 1 : channels_trans / 2;
383 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
384 |
385 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
386 | auto trans_grad = at::zeros_like(trans);
387 |
388 | if (input_grad.numel() == 0)
389 | {
390 | //THCudaCheck(cudaGetLastError());
391 | return std::make_tuple(input_grad, trans_grad);
392 | }
393 |
394 | /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
395 | dim3 block(512);
396 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/
397 |
398 | AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
399 | DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(
400 | out_size,
401 | out_grad.contiguous().data_ptr<scalar_t>(),
402 | top_count.contiguous().data_ptr<scalar_t>(),
403 | num_bbox,
404 | spatial_scale,
405 | channels,
406 | height,
407 | width,
408 | pooled_height,
409 | pooled_width,
410 | output_dim,
411 | input_grad.contiguous().data_ptr<scalar_t>(),
412 | trans_grad.contiguous().data_ptr<scalar_t>(),
413 | input.contiguous().data_ptr<scalar_t>(),
414 | bbox.contiguous().data_ptr<scalar_t>(),
415 | trans.contiguous().data_ptr<scalar_t>(),
416 | no_trans,
417 | trans_std,
418 | sample_per_part,
419 | group_size,
420 | part_size,
421 | num_classes,
422 | channels_each_class);
423 | });
424 | //THCudaCheck(cudaGetLastError());
425 | return std::make_tuple(input_grad, trans_grad);
426 | }
--------------------------------------------------------------------------------
/src/cpu/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cpu_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cpu_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/src/cuda/dcn_v2_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <vector>
2 | #include "cuda/dcn_v2_im2col_cuda.h"
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 |
7 | #include <THC/THC.h>
8 | #include <THC/THCAtomics.cuh>
9 | #include <THC/THCDeviceUtils.cuh>
10 |
11 | THCState *state = at::globalContext().lazyInitCUDA();
12 |
13 | // author: Charles Shang
14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
15 |
16 | // [batch gemm]
17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
18 |
19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
20 | float **columns_b, const float **ones_b,
21 | const float **weight_b, const float **bias_b,
22 | float *input, float *output,
23 | float *columns, float *ones,
24 | float *weight, float *bias,
25 | const int input_stride, const int output_stride,
26 | const int columns_stride, const int ones_stride,
27 | const int num_batches)
28 | {
29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
30 | if (idx < num_batches)
31 | {
32 | input_b[idx] = input + idx * input_stride;
33 | output_b[idx] = output + idx * output_stride;
34 | columns_b[idx] = columns + idx * columns_stride;
35 | ones_b[idx] = ones + idx * ones_stride;
36 | // share weights and bias within a Mini-Batch
37 | weight_b[idx] = weight;
38 | bias_b[idx] = bias;
39 | }
40 | }
41 |
42 | at::Tensor
43 | dcn_v2_cuda_forward(const at::Tensor &input,
44 | const at::Tensor &weight,
45 | const at::Tensor &bias,
46 | const at::Tensor &offset,
47 | const at::Tensor &mask,
48 | const int kernel_h,
49 | const int kernel_w,
50 | const int stride_h,
51 | const int stride_w,
52 | const int pad_h,
53 | const int pad_w,
54 | const int dilation_h,
55 | const int dilation_w,
56 | const int deformable_group)
57 | {
58 | using scalar_t = float;
59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
60 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
61 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
62 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
63 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
64 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
65 |
66 | const int batch = input.size(0);
67 | const int channels = input.size(1);
68 | const int height = input.size(2);
69 | const int width = input.size(3);
70 |
71 | const int channels_out = weight.size(0);
72 | const int channels_kernel = weight.size(1);
73 | const int kernel_h_ = weight.size(2);
74 | const int kernel_w_ = weight.size(3);
75 |
76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
77 | // printf("Channels: %d %d\n", channels, channels_kernel);
78 | // printf("Channels: %d %d\n", channels_out, channels_kernel);
79 |
80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
82 |
83 | AT_ASSERTM(channels == channels_kernel,
84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
85 |
86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
88 |
89 | auto ones = at::ones({batch, height_out, width_out}, input.options());
90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
91 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
92 |
93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing
94 | // when batch size is large.
95 | // launch batch threads
96 | int matrices_size = batch * sizeof(float *);
97 | auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
98 | auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
99 | auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
100 | auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
101 | auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
102 | auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
103 |
104 | const int block = 128;
105 | const int grid = (batch + block - 1) / block;
106 |
107 | createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
108 | input_b, output_b,
109 | columns_b, ones_b,
110 | weight_b, bias_b,
111 | input.data_ptr<scalar_t>(),
112 | output.data_ptr<scalar_t>(),
113 | columns.data_ptr<scalar_t>(),
114 | ones.data_ptr<scalar_t>(),
115 | weight.data_ptr<scalar_t>(),
116 | bias.data_ptr<scalar_t>(),
117 | channels * width * height,
118 | channels_out * width_out * height_out,
119 | channels * kernel_h * kernel_w * height_out * width_out,
120 | height_out * width_out,
121 | batch);
122 |
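    | // Initialize the output with the bias: a batched rank-1 GEMM multiplies a ones vector (height_out*width_out x 1) by the bias (1 x channels_out), broadcasting bias[c] over every spatial position of channel c.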
123 | long m_ = channels_out;
124 | long n_ = height_out * width_out;
125 | long k_ = 1;
126 | THCudaBlas_SgemmBatched(state,
127 | 't',
128 | 'n',
129 | n_,
130 | m_,
131 | k_,
132 | 1.0f,
133 | ones_b, k_,
134 | bias_b, k_,
135 | 0.0f,
136 | output_b, n_,
137 | batch);
138 |
139 | modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
140 | input.data_ptr<scalar_t>(),
141 | offset.data_ptr<scalar_t>(),
142 | mask.data_ptr<scalar_t>(),
143 | batch, channels, height, width,
144 | height_out, width_out, kernel_h, kernel_w,
145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
146 | deformable_group,
147 | columns.data_ptr<scalar_t>());
148 |
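    | // Main convolution as a GEMM: per sample, output += weight (channels_out x k) * columns (k x height_out*width_out) with k = channels*kernel_h*kernel_w, accumulated on top of the bias written above.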
149 | long m = channels_out;
150 | long n = height_out * width_out;
151 | long k = channels * kernel_h * kernel_w;
152 | THCudaBlas_SgemmBatched(state,
153 | 'n',
154 | 'n',
155 | n,
156 | m,
157 | k,
158 | 1.0f,
159 | (const float **)columns_b, n,
160 | weight_b, k,
161 | 1.0f,
162 | output_b, n,
163 | batch);
164 |
165 | THCudaFree(state, input_b);
166 | THCudaFree(state, output_b);
167 | THCudaFree(state, columns_b);
168 | THCudaFree(state, ones_b);
169 | THCudaFree(state, weight_b);
170 | THCudaFree(state, bias_b);
171 | return output;
172 | }
173 |
174 | __global__ void createBatchGemmBufferBackward(
175 | float **grad_output_b,
176 | float **columns_b,
177 | float **ones_b,
178 | float **weight_b,
179 | float **grad_weight_b,
180 | float **grad_bias_b,
181 | float *grad_output,
182 | float *columns,
183 | float *ones,
184 | float *weight,
185 | float *grad_weight,
186 | float *grad_bias,
187 | const int grad_output_stride,
188 | const int columns_stride,
189 | const int ones_stride,
190 | const int num_batches)
191 | {
192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
193 | if (idx < num_batches)
194 | {
195 | grad_output_b[idx] = grad_output + idx * grad_output_stride;
196 | columns_b[idx] = columns + idx * columns_stride;
197 | ones_b[idx] = ones + idx * ones_stride;
198 |
199 | // share weights and bias within a Mini-Batch
200 | weight_b[idx] = weight;
201 | grad_weight_b[idx] = grad_weight;
202 | grad_bias_b[idx] = grad_bias;
203 | }
204 | }
205 |
206 | std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
207 | const at::Tensor &weight,
208 | const at::Tensor &bias,
209 | const at::Tensor &offset,
210 | const at::Tensor &mask,
211 | const at::Tensor &grad_output,
212 | int kernel_h, int kernel_w,
213 | int stride_h, int stride_w,
214 | int pad_h, int pad_w,
215 | int dilation_h, int dilation_w,
216 | int deformable_group)
217 | {
218 |
219 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
221 |
222 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
223 | AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
224 | AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
225 | AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
226 | AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
227 |
228 | const int batch = input.size(0);
229 | const int channels = input.size(1);
230 | const int height = input.size(2);
231 | const int width = input.size(3);
232 |
233 | const int channels_out = weight.size(0);
234 | const int channels_kernel = weight.size(1);
235 | const int kernel_h_ = weight.size(2);
236 | const int kernel_w_ = weight.size(3);
237 |
238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
239 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
240 |
241 | AT_ASSERTM(channels == channels_kernel,
242 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
243 |
244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
246 |
247 | auto ones = at::ones({height_out, width_out}, input.options());
248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
250 |
251 | auto grad_input = at::zeros_like(input);
252 | auto grad_weight = at::zeros_like(weight);
253 | auto grad_bias = at::zeros_like(bias);
254 | auto grad_offset = at::zeros_like(offset);
255 | auto grad_mask = at::zeros_like(mask);
256 |
257 | using scalar_t = float;
258 |
259 | for (int b = 0; b < batch; b++)
260 | {
261 | auto input_n = input.select(0, b);
262 | auto offset_n = offset.select(0, b);
263 | auto mask_n = mask.select(0, b);
264 | auto grad_output_n = grad_output.select(0, b);
265 | auto grad_input_n = grad_input.select(0, b);
266 | auto grad_offset_n = grad_offset.select(0, b);
267 | auto grad_mask_n = grad_mask.select(0, b);
268 |
269 | long m = channels * kernel_h * kernel_w;
270 | long n = height_out * width_out;
271 | long k = channels_out;
272 |
273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
274 | grad_output_n.data_ptr<scalar_t>(), n,
275 | weight.data_ptr<scalar_t>(), m, 0.0f,
276 | columns.data_ptr<scalar_t>(), n);
277 |
278 | // gradient w.r.t. input coordinate data
279 | modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
280 | columns.data_ptr<scalar_t>(),
281 | input_n.data_ptr<scalar_t>(),
282 | offset_n.data_ptr<scalar_t>(),
283 | mask_n.data_ptr<scalar_t>(),
284 | 1, channels, height, width,
285 | height_out, width_out, kernel_h, kernel_w,
286 | pad_h, pad_w, stride_h, stride_w,
287 | dilation_h, dilation_w, deformable_group,
288 | grad_offset_n.data_ptr<scalar_t>(),
289 | grad_mask_n.data_ptr<scalar_t>());
290 | // gradient w.r.t. input data
291 | modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
292 | columns.data_ptr<scalar_t>(),
293 | offset_n.data_ptr<scalar_t>(),
294 | mask_n.data_ptr<scalar_t>(),
295 | 1, channels, height, width,
296 | height_out, width_out, kernel_h, kernel_w,
297 | pad_h, pad_w, stride_h, stride_w,
298 | dilation_h, dilation_w, deformable_group,
299 | grad_input_n.data_ptr<scalar_t>());
300 |
301 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group
302 | modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
303 | input_n.data_ptr<scalar_t>(),
304 | offset_n.data_ptr<scalar_t>(),
305 | mask_n.data_ptr<scalar_t>(),
306 | 1, channels, height, width,
307 | height_out, width_out, kernel_h, kernel_w,
308 | pad_h, pad_w, stride_h, stride_w,
309 | dilation_h, dilation_w, deformable_group,
310 | columns.data_ptr<scalar_t>());
311 |
312 | long m_ = channels_out;
313 | long n_ = channels * kernel_h * kernel_w;
314 | long k_ = height_out * width_out;
315 |
316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
317 | columns.data_ptr<scalar_t>(), k_,
318 | grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,
319 | grad_weight.data_ptr<scalar_t>(), n_);
320 |
321 | // gradient w.r.t. bias
322 | // long m_ = channels_out;
323 | // long k__ = height_out * width_out;
324 | // THCudaBlas_Sgemm(state,
325 | // 't', 'n',
326 | // k_, m_, 1, 1.0f,
327 | // grad_output_n.data_ptr(), k_,
328 | // ones.data_ptr(), 1, 1.0f,
329 | // grad_bias.data_ptr(), 1);
330 | THCudaBlas_Sgemm(state,
331 | 'N', 'N', 1, m_, k_, 1.0f,
332 | ones.data_ptr<scalar_t>(), 1,
333 | grad_output_n.data_ptr<scalar_t>(), k_,
334 | 1.0f,
335 | grad_bias.data_ptr<scalar_t>(), 1);
336 | }
337 |
338 | return {
339 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias
340 | };
341 | }
342 |
--------------------------------------------------------------------------------
/src/cuda/dcn_v2_im2col_cuda.cu:
--------------------------------------------------------------------------------
1 | #include "dcn_v2_im2col_cuda.h"
2 | #include <cstdio>
3 | #include <algorithm>
4 | #include <cstring>
5 |
6 | #include <ATen/ATen.h>
7 | #include <ATen/cuda/CUDAContext.h>
8 |
9 | #include <THC/THC.h>
10 | #include <THC/THCAtomics.cuh>
11 | #include <THC/THCDeviceUtils.cuh>
12 |
13 | #define CUDA_KERNEL_LOOP(i, n) \
14 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
15 | i < (n); \
16 | i += blockDim.x * gridDim.x)
17 |
18 | const int CUDA_NUM_THREADS = 1024;
19 | inline int GET_BLOCKS(const int N)
20 | {
21 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
22 | }
23 |
24 |
25 | __device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,
26 | const int height, const int width, float h, float w)
27 | {
28 | int h_low = floor(h);
29 | int w_low = floor(w);
30 | int h_high = h_low + 1;
31 | int w_high = w_low + 1;
32 |
33 | float lh = h - h_low;
34 | float lw = w - w_low;
35 | float hh = 1 - lh, hw = 1 - lw;
36 |
37 | float v1 = 0;
38 | if (h_low >= 0 && w_low >= 0)
39 | v1 = bottom_data[h_low * data_width + w_low];
40 | float v2 = 0;
41 | if (h_low >= 0 && w_high <= width - 1)
42 | v2 = bottom_data[h_low * data_width + w_high];
43 | float v3 = 0;
44 | if (h_high <= height - 1 && w_low >= 0)
45 | v3 = bottom_data[h_high * data_width + w_low];
46 | float v4 = 0;
47 | if (h_high <= height - 1 && w_high <= width - 1)
48 | v4 = bottom_data[h_high * data_width + w_high];
49 |
50 | float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
51 |
52 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
53 | return val;
54 | }
55 |
56 | __device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,
57 | const int h, const int w, const int height, const int width)
58 | {
59 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
60 | {
61 | //empty
62 | return 0;
63 | }
64 |
65 | int argmax_h_low = floor(argmax_h);
66 | int argmax_w_low = floor(argmax_w);
67 | int argmax_h_high = argmax_h_low + 1;
68 | int argmax_w_high = argmax_w_low + 1;
69 |
70 | float weight = 0;
71 | if (h == argmax_h_low && w == argmax_w_low)
72 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
73 | if (h == argmax_h_low && w == argmax_w_high)
74 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
75 | if (h == argmax_h_high && w == argmax_w_low)
76 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
77 | if (h == argmax_h_high && w == argmax_w_high)
78 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
79 | return weight;
80 | }
81 |
82 | __device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,
83 | const int height, const int width, const float *im_data,
84 | const int data_width, const int bp_dir)
85 | {
86 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
87 | {
88 | //empty
89 | return 0;
90 | }
91 |
92 | int argmax_h_low = floor(argmax_h);
93 | int argmax_w_low = floor(argmax_w);
94 | int argmax_h_high = argmax_h_low + 1;
95 | int argmax_w_high = argmax_w_low + 1;
96 |
97 | float weight = 0;
98 |
99 | if (bp_dir == 0)
100 | {
101 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
102 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
103 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
104 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
105 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
106 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
107 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
108 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
109 | }
110 | else if (bp_dir == 1)
111 | {
112 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
113 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
114 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
115 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
116 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
117 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
118 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
119 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
120 | }
121 |
122 | return weight;
123 | }
124 |
125 | __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
126 | const float *data_im, const float *data_offset, const float *data_mask,
127 | const int height, const int width, const int kernel_h, const int kernel_w,
128 | const int pad_h, const int pad_w,
129 | const int stride_h, const int stride_w,
130 | const int dilation_h, const int dilation_w,
131 | const int channel_per_deformable_group,
132 | const int batch_size, const int num_channels, const int deformable_group,
133 | const int height_col, const int width_col,
134 | float *data_col)
135 | {
136 | // launch channels * batch_size * height_col * width_col cores
137 | CUDA_KERNEL_LOOP(index, n)
138 | {
139 | // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
140 | // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
141 |
142 | // index index of output matrix
143 | const int w_col = index % width_col;
144 | const int h_col = (index / width_col) % height_col;
145 | // const int b_col = (index / width_col / height_col) % batch_size;
146 | const int b_col = (index / width_col / height_col / num_channels) % batch_size;
147 | // const int c_im = (index / width_col / height_col) / batch_size;
148 | const int c_im = (index / width_col / height_col) % num_channels;
149 | // const int c_col = c_im * kernel_h * kernel_w;
150 | const int c_col = c_im * kernel_h * kernel_w;
151 |
152 | // compute deformable group index
153 | const int deformable_group_index = c_im / channel_per_deformable_group;
154 |
155 | const int h_in = h_col * stride_h - pad_h;
156 | const int w_in = w_col * stride_w - pad_w;
157 |
158 | // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
159 | float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
160 | //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
161 | const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
162 | const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
163 |
164 | const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
165 |
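    | // For each kernel tap (i, j): read its learned (dy, dx) offset and modulation mask, bilinearly sample the input at the shifted location, and write one entry of the column buffer.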
166 | for (int i = 0; i < kernel_h; ++i)
167 | {
168 | for (int j = 0; j < kernel_w; ++j)
169 | {
170 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
171 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
172 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
173 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
174 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
175 | const float mask = data_mask_ptr[data_mask_hw_ptr];
176 | float val = static_cast<float>(0);
177 | const float h_im = h_in + i * dilation_h + offset_h;
178 | const float w_im = w_in + j * dilation_w + offset_w;
179 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
180 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
181 | {
182 | //const float map_h = i * dilation_h + offset_h;
183 | //const float map_w = j * dilation_w + offset_w;
184 | //const int cur_height = height - h_in;
185 | //const int cur_width = width - w_in;
186 | //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
187 | val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);
188 | }
189 | *data_col_ptr = val * mask;
190 | // data_col_ptr += batch_size * height_col * width_col;
191 | data_col_ptr += height_col * width_col;
192 | }
193 | }
194 | }
195 | }
196 |
197 | __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
198 | const float *data_col, const float *data_offset, const float *data_mask,
199 | const int channels, const int height, const int width,
200 | const int kernel_h, const int kernel_w,
201 | const int pad_h, const int pad_w,
202 | const int stride_h, const int stride_w,
203 | const int dilation_h, const int dilation_w,
204 | const int channel_per_deformable_group,
205 | const int batch_size, const int deformable_group,
206 | const int height_col, const int width_col,
207 | float *grad_im)
208 | {
209 | CUDA_KERNEL_LOOP(index, n)
210 | {
211 | const int j = (index / width_col / height_col / batch_size) % kernel_w;
212 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
213 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
214 | // compute the start and end of the output
215 |
216 | const int deformable_group_index = c / channel_per_deformable_group;
217 |
218 | int w_out = index % width_col;
219 | int h_out = (index / width_col) % height_col;
220 | int b = (index / width_col / height_col) % batch_size;
221 | int w_in = w_out * stride_w - pad_w;
222 | int h_in = h_out * stride_h - pad_h;
223 |
224 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
225 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
226 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
227 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
228 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
229 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
230 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
231 | const float mask = data_mask_ptr[data_mask_hw_ptr];
232 | const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
233 | const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
234 |
235 | const float cur_top_grad = data_col[index] * mask;
236 | const int cur_h = (int)cur_inv_h_data;
237 | const int cur_w = (int)cur_inv_w_data;
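    | // Scatter this column gradient back to the input pixels around the fractional sampling location; the +/-2 window plus the distance-<1 test selects the (at most four) bilinear neighbours.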
238 | for (int dy = -2; dy <= 2; dy++)
239 | {
240 | for (int dx = -2; dx <= 2; dx++)
241 | {
242 | if (cur_h + dy >= 0 && cur_h + dy < height &&
243 | cur_w + dx >= 0 && cur_w + dx < width &&
244 | abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
245 | abs(cur_inv_w_data - (cur_w + dx)) < 1)
246 | {
247 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
248 | float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
249 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
250 | }
251 | }
252 | }
253 | }
254 | }
255 |
256 | __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
257 | const float *data_col, const float *data_im,
258 | const float *data_offset, const float *data_mask,
259 | const int channels, const int height, const int width,
260 | const int kernel_h, const int kernel_w,
261 | const int pad_h, const int pad_w,
262 | const int stride_h, const int stride_w,
263 | const int dilation_h, const int dilation_w,
264 | const int channel_per_deformable_group,
265 | const int batch_size, const int offset_channels, const int deformable_group,
266 | const int height_col, const int width_col,
267 | float *grad_offset, float *grad_mask)
268 | {
269 | CUDA_KERNEL_LOOP(index, n)
270 | {
271 | float val = 0, mval = 0;
272 | int w = index % width_col;
273 | int h = (index / width_col) % height_col;
274 | int c = (index / width_col / height_col) % offset_channels;
275 | int b = (index / width_col / height_col) / offset_channels;
276 | // compute the start and end of the output
277 |
278 | const int deformable_group_index = c / (2 * kernel_h * kernel_w);
279 | const int col_step = kernel_h * kernel_w;
280 | int cnt = 0;
281 | const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
282 | const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
283 | const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
284 | const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
285 |
286 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
287 |
288 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
289 | {
290 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
291 | const int bp_dir = offset_c % 2;
292 |
293 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
294 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
295 | int w_out = col_pos % width_col;
296 | int h_out = (col_pos / width_col) % height_col;
297 | int w_in = w_out * stride_w - pad_w;
298 | int h_in = h_out * stride_h - pad_h;
299 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
300 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
301 | const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
302 | const float offset_h = data_offset_ptr[data_offset_h_ptr];
303 | const float offset_w = data_offset_ptr[data_offset_w_ptr];
304 | const float mask = data_mask_ptr[data_mask_hw_ptr];
305 | float inv_h = h_in + i * dilation_h + offset_h;
306 | float inv_w = w_in + j * dilation_w + offset_w;
307 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
308 | {
309 | inv_h = inv_w = -2;
310 | }
311 | else
312 | {
313 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
314 | }
315 | const float weight = dmcn_get_coordinate_weight_cuda(
316 | inv_h, inv_w,
317 | height, width, data_im_ptr + cnt * height * width, width, bp_dir);
318 | val += weight * data_col_ptr[col_pos] * mask;
319 | cnt += 1;
320 | }
321 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
322 | grad_offset[index] = val;
323 | if (offset_c % 2 == 0)
324 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
325 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
326 | }
327 | }
328 |
329 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
330 | const float* data_im, const float* data_offset, const float* data_mask,
331 | const int batch_size, const int channels, const int height_im, const int width_im,
332 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
333 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
334 | const int dilation_h, const int dilation_w,
335 | const int deformable_group, float* data_col) {
336 | // num_axes should be smaller than block size
337 | const int channel_per_deformable_group = channels / deformable_group;
338 | const int num_kernels = channels * batch_size * height_col * width_col;
339 | modulated_deformable_im2col_gpu_kernel
340 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,
341 | 0, stream>>>(
342 | num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
343 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
344 | batch_size, channels, deformable_group, height_col, width_col, data_col);
345 |
346 | cudaError_t err = cudaGetLastError();
347 | if (err != cudaSuccess)
348 | {
349 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
350 | }
351 |
352 | }
353 |
354 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
355 | const float* data_col, const float* data_offset, const float* data_mask,
356 | const int batch_size, const int channels, const int height_im, const int width_im,
357 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
358 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
359 | const int dilation_h, const int dilation_w,
360 | const int deformable_group, float* grad_im){
361 |
362 | const int channel_per_deformable_group = channels / deformable_group;
363 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
364 | modulated_deformable_col2im_gpu_kernel
365 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,
366 | 0, stream>>>(
367 | num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
368 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
369 | dilation_h, dilation_w, channel_per_deformable_group,
370 | batch_size, deformable_group, height_col, width_col, grad_im);
371 | cudaError_t err = cudaGetLastError();
372 | if (err != cudaSuccess)
373 | {
374 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
375 | }
376 |
377 | }
378 |
379 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
380 | const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
381 | const int batch_size, const int channels, const int height_im, const int width_im,
382 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
383 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
384 | const int dilation_h, const int dilation_w,
385 | const int deformable_group,
386 | float* grad_offset, float* grad_mask) {
387 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
388 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
389 | modulated_deformable_col2im_coord_gpu_kernel
390 | <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,
391 | 0, stream>>>(
392 | num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
393 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
394 | dilation_h, dilation_w, channel_per_deformable_group,
395 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
396 | grad_offset, grad_mask);
397 | cudaError_t err = cudaGetLastError();
398 | if (err != cudaSuccess)
399 | {
400 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
401 | }
402 | }
--------------------------------------------------------------------------------
/src/cuda/dcn_v2_im2col_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 |
64 | #ifndef DCN_V2_IM2COL_CUDA
65 | #define DCN_V2_IM2COL_CUDA
66 |
67 | #ifdef __cplusplus
68 | extern "C"
69 | {
70 | #endif
71 |
72 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
73 | const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
81 | const float *data_col, const float *data_offset, const float *data_mask,
82 | const int batch_size, const int channels, const int height_im, const int width_im,
83 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
85 | const int dilation_h, const int dilation_w,
86 | const int deformable_group, float *grad_im);
87 |
88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
90 | const int batch_size, const int channels, const int height_im, const int width_im,
91 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
93 | const int dilation_h, const int dilation_w,
94 | const int deformable_group,
95 | float *grad_offset, float *grad_mask);
96 |
97 | #ifdef __cplusplus
98 | }
99 | #endif
100 |
101 | #endif
--------------------------------------------------------------------------------
/src/cuda/dcn_v2_psroi_pooling_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 |
10 | #include <cstdio>
11 | #include <algorithm>
12 | #include <cstring>
13 | #include <vector>
14 |
15 | #include <ATen/ATen.h>
16 | #include <ATen/cuda/CUDAContext.h>
17 |
18 | #include <THC/THC.h>
19 | #include <THC/THCAtomics.cuh>
20 | #include <THC/THCDeviceUtils.cuh>
21 |
22 | #define CUDA_KERNEL_LOOP(i, n) \
23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
24 | i < (n); \
25 | i += blockDim.x * gridDim.x)
26 |
27 | const int CUDA_NUM_THREADS = 1024;
28 | inline int GET_BLOCKS(const int N)
29 | {
30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
31 | }
32 |
33 | template <typename T>
34 | __device__ T bilinear_interp_cuda(
35 | const T *data,
36 | const T x,
37 | const T y,
38 | const int width,
39 | const int height)
40 | {
41 | int x1 = floor(x);
42 | int x2 = ceil(x);
43 | int y1 = floor(y);
44 | int y2 = ceil(y);
45 | T dist_x = static_cast<T>(x - x1);
46 | T dist_y = static_cast<T>(y - y1);
47 | T value11 = data[y1 * width + x1];
48 | T value12 = data[y2 * width + x1];
49 | T value21 = data[y1 * width + x2];
50 | T value22 = data[y2 * width + x2];
51 | T value = (1 - dist_x) * (1 - dist_y) * value11 +
52 | (1 - dist_x) * dist_y * value12 +
53 | dist_x * (1 - dist_y) * value21 +
54 | dist_x * dist_y * value22;
55 | return value;
56 | }
57 |
58 | template <typename T>
59 | __global__ void DeformablePSROIPoolForwardKernelCuda(
60 | const int count,
61 | const T *bottom_data,
62 | const T spatial_scale,
63 | const int channels,
64 | const int height, const int width,
65 | const int pooled_height, const int pooled_width,
66 | const T *bottom_rois, const T *bottom_trans,
67 | const int no_trans,
68 | const T trans_std,
69 | const int sample_per_part,
70 | const int output_dim,
71 | const int group_size,
72 | const int part_size,
73 | const int num_classes,
74 | const int channels_each_class,
75 | T *top_data,
76 | T *top_count)
77 | {
78 | CUDA_KERNEL_LOOP(index, count)
79 | {
80 | // The output is in order (n, ctop, ph, pw)
81 | int pw = index % pooled_width;
82 | int ph = (index / pooled_width) % pooled_height;
83 | int ctop = (index / pooled_width / pooled_height) % output_dim;
84 | int n = index / pooled_width / pooled_height / output_dim;
85 |
86 | // [start, end) interval for spatial sampling
87 | const T *offset_bottom_rois = bottom_rois + n * 5;
88 | int roi_batch_ind = offset_bottom_rois[0];
89 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
90 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
91 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
92 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
93 |
94 | // Force too small ROIs to be 1x1
95 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
96 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
97 |
98 | // Compute w and h at bottom
99 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
100 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
101 |
102 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
103 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
104 |
105 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
106 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
107 | int class_id = ctop / channels_each_class;
108 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
109 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
110 |
111 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
112 | wstart += trans_x * roi_width;
113 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
114 | hstart += trans_y * roi_height;
115 |
116 | T sum = 0;
117 | int count = 0;
118 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
119 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
120 | gw = min(max(gw, 0), group_size - 1);
121 | gh = min(max(gh, 0), group_size - 1);
122 |
123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
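    | // Average sample_per_part x sample_per_part bilinear samples inside this output bin, shifted by the learned (trans_x, trans_y) part offset.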
124 | for (int ih = 0; ih < sample_per_part; ih++)
125 | {
126 | for (int iw = 0; iw < sample_per_part; iw++)
127 | {
128 | T w = wstart + iw * sub_bin_size_w;
129 | T h = hstart + ih * sub_bin_size_h;
130 | // bilinear interpolation
131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
132 | {
133 | continue;
134 | }
135 | w = min(max(w, 0.), width - 1.);
136 | h = min(max(h, 0.), height - 1.);
137 | int c = (ctop * group_size + gh) * group_size + gw;
138 | T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height);
139 | sum += val;
140 | count++;
141 | }
142 | }
143 | top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
144 | top_count[index] = count;
145 | }
146 | }
147 |
148 | template <typename T>
149 | __global__ void DeformablePSROIPoolBackwardAccKernelCuda(
150 | const int count,
151 | const T *top_diff,
152 | const T *top_count,
153 | const int num_rois,
154 | const T spatial_scale,
155 | const int channels,
156 | const int height, const int width,
157 | const int pooled_height, const int pooled_width,
158 | const int output_dim,
159 | T *bottom_data_diff, T *bottom_trans_diff,
160 | const T *bottom_data,
161 | const T *bottom_rois,
162 | const T *bottom_trans,
163 | const int no_trans,
164 | const T trans_std,
165 | const int sample_per_part,
166 | const int group_size,
167 | const int part_size,
168 | const int num_classes,
169 | const int channels_each_class)
170 | {
171 | CUDA_KERNEL_LOOP(index, count)
172 | {
173 | // The output is in order (n, ctop, ph, pw)
174 | int pw = index % pooled_width;
175 | int ph = (index / pooled_width) % pooled_height;
176 | int ctop = (index / pooled_width / pooled_height) % output_dim;
177 | int n = index / pooled_width / pooled_height / output_dim;
178 |
179 | // [start, end) interval for spatial sampling
180 | const T *offset_bottom_rois = bottom_rois + n * 5;
181 | int roi_batch_ind = offset_bottom_rois[0];
182 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
183 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
184 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
185 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
186 |
187 | // Force too small ROIs to be 1x1
188 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
189 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
190 |
191 | // Compute w and h at bottom
192 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
193 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
194 |
195 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
196 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
197 |
198 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
199 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
200 | int class_id = ctop / channels_each_class;
201 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
202 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
203 |
204 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
205 | wstart += trans_x * roi_width;
206 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
207 | hstart += trans_y * roi_height;
208 |
209 | if (top_count[index] <= 0)
210 | {
211 | continue;
212 | }
213 | T diff_val = top_diff[index] / top_count[index];
214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
216 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
217 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
218 | gw = min(max(gw, 0), group_size - 1);
219 | gh = min(max(gh, 0), group_size - 1);
220 |
221 | for (int ih = 0; ih < sample_per_part; ih++)
222 | {
223 | for (int iw = 0; iw < sample_per_part; iw++)
224 | {
225 | T w = wstart + iw * sub_bin_size_w;
226 | T h = hstart + ih * sub_bin_size_h;
227 | // bilinear interpolation
228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
229 | {
230 | continue;
231 | }
232 | w = min(max(w, 0.), width - 1.);
233 | h = min(max(h, 0.), height - 1.);
234 | int c = (ctop * group_size + gh) * group_size + gw;
235 | // backward on feature
236 | int x0 = floor(w);
237 | int x1 = ceil(w);
238 | int y0 = floor(h);
239 | int y1 = ceil(h);
240 | T dist_x = w - x0, dist_y = h - y0;
241 | T q00 = (1 - dist_x) * (1 - dist_y);
242 | T q01 = (1 - dist_x) * dist_y;
243 | T q10 = dist_x * (1 - dist_y);
244 | T q11 = dist_x * dist_y;
245 | int bottom_index_base = c * height * width;
246 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
250 |
251 | if (no_trans)
252 | {
253 | continue;
254 | }
255 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
256 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
257 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
258 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
259 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
260 | diff_x *= roi_width;
261 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
262 | diff_y *= roi_height;
263 |
264 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
265 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
266 | }
267 | }
268 | }
269 | }
270 |
271 | std::tuple<at::Tensor, at::Tensor>
272 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
273 | const at::Tensor &bbox,
274 | const at::Tensor &trans,
275 | const int no_trans,
276 | const float spatial_scale,
277 | const int output_dim,
278 | const int group_size,
279 | const int pooled_size,
280 | const int part_size,
281 | const int sample_per_part,
282 | const float trans_std)
283 | {
284 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
285 | AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor");
286 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
287 |
288 | const int batch = input.size(0);
289 | const int channels = input.size(1);
290 | const int height = input.size(2);
291 | const int width = input.size(3);
292 | const int channels_trans = no_trans ? 2 : trans.size(1);
293 | const int num_bbox = bbox.size(0);
294 |
295 | AT_ASSERTM(channels == output_dim, "input channels and output channels must be equal");
296 | auto pooled_height = pooled_size;
297 | auto pooled_width = pooled_size;
298 |
299 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
300 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
301 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
302 |
303 | const int num_classes = no_trans ? 1 : channels_trans / 2;
304 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
305 |
306 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
307 |
308 | if (out.numel() == 0)
309 | {
310 | THCudaCheck(cudaGetLastError());
311 | return std::make_tuple(out, top_count);
312 | }
313 |
314 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
315 | dim3 block(512);
316 |
317 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
318 | DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
319 | out_size,
320 | input.contiguous().data_ptr<scalar_t>(),
321 | spatial_scale,
322 | channels,
323 | height, width,
324 | pooled_height,
325 | pooled_width,
326 | bbox.contiguous().data_ptr<scalar_t>(),
327 | trans.contiguous().data_ptr<scalar_t>(),
328 | no_trans,
329 | trans_std,
330 | sample_per_part,
331 | output_dim,
332 | group_size,
333 | part_size,
334 | num_classes,
335 | channels_each_class,
336 | out.data_ptr<scalar_t>(),
337 | top_count.data_ptr<scalar_t>());
338 | });
339 | THCudaCheck(cudaGetLastError());
340 | return std::make_tuple(out, top_count);
341 | }
342 |
343 | std::tuple<at::Tensor, at::Tensor>
344 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
345 | const at::Tensor &input,
346 | const at::Tensor &bbox,
347 | const at::Tensor &trans,
348 | const at::Tensor &top_count,
349 | const int no_trans,
350 | const float spatial_scale,
351 | const int output_dim,
352 | const int group_size,
353 | const int pooled_size,
354 | const int part_size,
355 | const int sample_per_part,
356 | const float trans_std)
357 | {
358 | AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor");
359 | AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
360 | AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor");
361 | AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
362 | AT_ASSERTM(top_count.is_cuda(), "top_count must be a CUDA tensor");
363 |
364 | const int batch = input.size(0);
365 | const int channels = input.size(1);
366 | const int height = input.size(2);
367 | const int width = input.size(3);
368 | const int channels_trans = no_trans ? 2 : trans.size(1);
369 | const int num_bbox = bbox.size(0);
370 |
371 | AT_ASSERTM(channels == output_dim, "input channels and output channels must be equal");
372 | auto pooled_height = pooled_size;
373 | auto pooled_width = pooled_size;
374 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
375 | const int num_classes = no_trans ? 1 : channels_trans / 2;
376 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
377 |
378 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
379 | auto trans_grad = at::zeros_like(trans);
380 |
381 | if (input_grad.numel() == 0)
382 | {
383 | THCudaCheck(cudaGetLastError());
384 | return std::make_tuple(input_grad, trans_grad);
385 | }
386 |
387 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
388 | dim3 block(512);
389 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
390 |
391 | AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
392 | DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
393 | out_size,
394 | out_grad.contiguous().data_ptr<scalar_t>(),
395 | top_count.contiguous().data_ptr<scalar_t>(),
396 | num_bbox,
397 | spatial_scale,
398 | channels,
399 | height,
400 | width,
401 | pooled_height,
402 | pooled_width,
403 | output_dim,
404 | input_grad.contiguous().data_ptr<scalar_t>(),
405 | trans_grad.contiguous().data_ptr<scalar_t>(),
406 | input.contiguous().data_ptr<scalar_t>(),
407 | bbox.contiguous().data_ptr<scalar_t>(),
408 | trans.contiguous().data_ptr<scalar_t>(),
409 | no_trans,
410 | trans_std,
411 | sample_per_part,
412 | group_size,
413 | part_size,
414 | num_classes,
415 | channels_each_class);
416 | });
417 | THCudaCheck(cudaGetLastError());
418 | return std::make_tuple(input_grad, trans_grad);
419 | }
--------------------------------------------------------------------------------
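The backward kernel above spreads each pooled gradient over the four bilinear neighbours with the weights `q00..q11`, and accumulates the analytic expressions `diff_x` / `diff_y` into `bottom_trans_diff`. A minimal sketch (not part of the extension, plain PyTorch only) that checks those expressions against `torch.autograd` for a single sampling point:

```python
import torch

# values at the four neighbours (y0,x0), (y1,x0), (y0,x1), (y1,x1)
U00, U01, U10, U11 = torch.rand(4)
# (dist_x, dist_y): fractional position inside the unit cell
dist = torch.rand(2, requires_grad=True)
dist_x, dist_y = dist[0], dist[1]

# forward bilinear sample, mirroring the q00..q11 weights in the kernel
val = (U00 * (1 - dist_x) * (1 - dist_y) + U01 * (1 - dist_x) * dist_y
       + U10 * dist_x * (1 - dist_y) + U11 * dist_x * dist_y)
val.backward()

# analytic gradients accumulated into bottom_trans_diff
# (before the trans_std and roi_width / roi_height scaling)
diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)
diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)
print(torch.allclose(dist.grad, torch.stack([diff_x, diff_y]).detach()))  # True
```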
/src/cuda/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cuda_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cuda_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/src/dcn_v2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "cpu/vision.h"
4 |
5 | #ifdef WITH_CUDA
6 | #include "cuda/vision.h"
7 | #endif
8 |
9 | at::Tensor
10 | dcn_v2_forward(const at::Tensor &input,
11 | const at::Tensor &weight,
12 | const at::Tensor &bias,
13 | const at::Tensor &offset,
14 | const at::Tensor &mask,
15 | const int kernel_h,
16 | const int kernel_w,
17 | const int stride_h,
18 | const int stride_w,
19 | const int pad_h,
20 | const int pad_w,
21 | const int dilation_h,
22 | const int dilation_w,
23 | const int deformable_group)
24 | {
25 | if (input.is_cuda())
26 | {
27 | #ifdef WITH_CUDA
28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
29 | kernel_h, kernel_w,
30 | stride_h, stride_w,
31 | pad_h, pad_w,
32 | dilation_h, dilation_w,
33 | deformable_group);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | else{
39 | return dcn_v2_cpu_forward(input, weight, bias, offset, mask,
40 | kernel_h, kernel_w,
41 | stride_h, stride_w,
42 | pad_h, pad_w,
43 | dilation_h, dilation_w,
44 | deformable_group);
45 | }
46 | }
47 |
48 | std::vector<at::Tensor>
49 | dcn_v2_backward(const at::Tensor &input,
50 | const at::Tensor &weight,
51 | const at::Tensor &bias,
52 | const at::Tensor &offset,
53 | const at::Tensor &mask,
54 | const at::Tensor &grad_output,
55 | int kernel_h, int kernel_w,
56 | int stride_h, int stride_w,
57 | int pad_h, int pad_w,
58 | int dilation_h, int dilation_w,
59 | int deformable_group)
60 | {
61 | if (input.is_cuda())
62 | {
63 | #ifdef WITH_CUDA
64 | return dcn_v2_cuda_backward(input,
65 | weight,
66 | bias,
67 | offset,
68 | mask,
69 | grad_output,
70 | kernel_h, kernel_w,
71 | stride_h, stride_w,
72 | pad_h, pad_w,
73 | dilation_h, dilation_w,
74 | deformable_group);
75 | #else
76 | AT_ERROR("Not compiled with GPU support");
77 | #endif
78 | }
79 | else{
80 | return dcn_v2_cpu_backward(input,
81 | weight,
82 | bias,
83 | offset,
84 | mask,
85 | grad_output,
86 | kernel_h, kernel_w,
87 | stride_h, stride_w,
88 | pad_h, pad_w,
89 | dilation_h, dilation_w,
90 | deformable_group);
91 | }
92 | }
93 |
94 | std::tuple<at::Tensor, at::Tensor>
95 | dcn_v2_psroi_pooling_forward(const at::Tensor &input,
96 | const at::Tensor &bbox,
97 | const at::Tensor &trans,
98 | const int no_trans,
99 | const float spatial_scale,
100 | const int output_dim,
101 | const int group_size,
102 | const int pooled_size,
103 | const int part_size,
104 | const int sample_per_part,
105 | const float trans_std)
106 | {
107 | if (input.is_cuda())
108 | {
109 | #ifdef WITH_CUDA
110 | return dcn_v2_psroi_pooling_cuda_forward(input,
111 | bbox,
112 | trans,
113 | no_trans,
114 | spatial_scale,
115 | output_dim,
116 | group_size,
117 | pooled_size,
118 | part_size,
119 | sample_per_part,
120 | trans_std);
121 | #else
122 | AT_ERROR("Not compiled with GPU support");
123 | #endif
124 | }
125 | else{
126 | return dcn_v2_psroi_pooling_cpu_forward(input,
127 | bbox,
128 | trans,
129 | no_trans,
130 | spatial_scale,
131 | output_dim,
132 | group_size,
133 | pooled_size,
134 | part_size,
135 | sample_per_part,
136 | trans_std);
137 | }
138 | }
139 |
140 | std::tuple<at::Tensor, at::Tensor>
141 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
142 | const at::Tensor &input,
143 | const at::Tensor &bbox,
144 | const at::Tensor &trans,
145 | const at::Tensor &top_count,
146 | const int no_trans,
147 | const float spatial_scale,
148 | const int output_dim,
149 | const int group_size,
150 | const int pooled_size,
151 | const int part_size,
152 | const int sample_per_part,
153 | const float trans_std)
154 | {
155 | if (input.is_cuda())
156 | {
157 | #ifdef WITH_CUDA
158 | return dcn_v2_psroi_pooling_cuda_backward(out_grad,
159 | input,
160 | bbox,
161 | trans,
162 | top_count,
163 | no_trans,
164 | spatial_scale,
165 | output_dim,
166 | group_size,
167 | pooled_size,
168 | part_size,
169 | sample_per_part,
170 | trans_std);
171 | #else
172 | AT_ERROR("Not compiled with GPU support");
173 | #endif
174 | }
175 | else{
176 | return dcn_v2_psroi_pooling_cpu_backward(out_grad,
177 | input,
178 | bbox,
179 | trans,
180 | top_count,
181 | no_trans,
182 | spatial_scale,
183 | output_dim,
184 | group_size,
185 | pooled_size,
186 | part_size,
187 | sample_per_part,
188 | trans_std);
189 | }
190 | }
--------------------------------------------------------------------------------
/src/vision.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "dcn_v2.h"
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
9 | }
10 |
--------------------------------------------------------------------------------
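vision.cpp only registers the four wrappers from dcn_v2.h; in normal use they are reached through the `DCN` / `DCNPooling` modules in dcn_v2.py, but the bound functions can also be called directly. A hypothetical sketch, assuming the compiled extension is importable as `_ext` (the actual module name is set in setup.py, which is not reproduced here):

```python
import torch
import _ext  # module name is an assumption; see setup.py

# a tiny 3x3 modulated deformable convolution, one deformable group
input  = torch.randn(1, 4, 8, 8).cuda()
weight = torch.randn(4, 4, 3, 3).cuda()
bias   = torch.zeros(4).cuda()
offset = torch.zeros(1, 2 * 3 * 3, 8, 8).cuda()  # 2 * kH * kW offset channels
mask   = torch.ones(1, 1 * 3 * 3, 8, 8).cuda()   # kH * kW mask channels

out = _ext.dcn_v2_forward(input, weight, bias, offset, mask,
                          3, 3,   # kernel_h, kernel_w
                          1, 1,   # stride_h, stride_w
                          1, 1,   # pad_h, pad_w
                          1, 1,   # dilation_h, dilation_w
                          1)      # deformable_group
print(out.shape)  # torch.Size([1, 4, 8, 8])
```

With zero offsets and an all-ones mask this reduces to an ordinary 3x3 convolution over the regular sampling grid, which is what the `check_zero_offset` tests below rely on.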
/test/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import gradcheck
7 |
8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling
9 |
10 | deformable_groups = 1
11 | N, inC, inH, inW = 2, 2, 4, 4
12 | outC = 2
13 | kH, kW = 3, 3
14 |
15 |
16 | def conv_identify(weight, bias):
17 | weight.data.zero_()
18 | bias.data.zero_()
19 | o, i, h, w = weight.shape
20 | y = h // 2
21 | x = w // 2
22 | for p in range(i):
23 | for q in range(o):
24 | if p == q:
25 | weight.data[q, p, y, x] = 1.0
26 |
27 |
28 | def check_zero_offset():
29 | conv_offset = nn.Conv2d(
30 | inC,
31 | deformable_groups * 2 * kH * kW,
32 | kernel_size=(kH, kW),
33 | stride=(1, 1),
34 | padding=(1, 1),
35 | bias=True,
36 | ).cuda()
37 |
38 | conv_mask = nn.Conv2d(
39 | inC,
40 | deformable_groups * 1 * kH * kW,
41 | kernel_size=(kH, kW),
42 | stride=(1, 1),
43 | padding=(1, 1),
44 | bias=True,
45 | ).cuda()
46 |
47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda()
48 |
49 | conv_offset.weight.data.zero_()
50 | conv_offset.bias.data.zero_()
51 | conv_mask.weight.data.zero_()
52 | conv_mask.bias.data.zero_()
53 | conv_identify(dcn_v2.weight, dcn_v2.bias)
54 |
55 | input = torch.randn(N, inC, inH, inW).cuda()
56 | offset = conv_offset(input)
57 | mask = conv_mask(input)
58 | mask = torch.sigmoid(mask)
59 | output = dcn_v2(input, offset, mask)
60 | output *= 2  # conv_mask is zeroed, so mask = sigmoid(0) = 0.5; scale back before comparing to input
61 | d = (input - output).abs().max()
62 | if d < 1e-10:
63 | print("Zero offset passed")
64 | else:
65 | print("Zero offset failed")
66 | print(input)
67 | print(output)
68 |
69 |
70 | def check_gradient_dconv():
71 |
72 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01
73 | input.requires_grad = True
74 |
75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
76 | # offset.data.zero_()
77 | # offset.data -= 0.5
78 | offset.requires_grad = True
79 |
80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
81 | # mask.data.zero_()
82 | mask.requires_grad = True
83 | mask = torch.sigmoid(mask)
84 |
85 | weight = torch.randn(outC, inC, kH, kW).cuda()
86 | weight.requires_grad = True
87 |
88 | bias = torch.rand(outC).cuda()
89 | bias.requires_grad = True
90 |
91 | stride = 1
92 | padding = 1
93 | dilation = 1
94 |
95 | print(
96 | "check_gradient_dconv: ",
97 | gradcheck(
98 | dcn_v2_conv,
99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups),
100 | eps=1e-3,
101 | atol=1e-4,
102 | rtol=1e-2,
103 | ),
104 | )
105 |
106 |
107 | def check_pooling_zero_offset():
108 |
109 | input = torch.randn(2, 16, 64, 64).cuda().zero_()
110 | input[0, :, 16:26, 16:26] = 1.0
111 | input[1, :, 10:20, 20:30] = 2.0
112 | rois = (
113 | torch.tensor(
114 | [
115 | [0, 65, 65, 103, 103],
116 | [1, 81, 41, 119, 79],
117 | ]
118 | )
119 | .cuda()
120 | .float()
121 | )
122 | pooling = DCNv2Pooling(
123 | spatial_scale=1.0 / 4,
124 | pooled_size=7,
125 | output_dim=16,
126 | no_trans=True,
127 | group_size=1,
128 | trans_std=0.0,
129 | ).cuda()
130 |
131 | out = pooling(input, rois, input.new())
132 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])])
133 | print(s)
134 |
135 | dpooling = DCNv2Pooling(
136 | spatial_scale=1.0 / 4,
137 | pooled_size=7,
138 | output_dim=16,
139 | no_trans=False,
140 | group_size=1,
141 | trans_std=0.0,
142 | ).cuda()
143 | offset = torch.randn(20, 2, 7, 7).cuda().zero_()
144 | dout = dpooling(input, rois, offset)
145 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])])
146 | print(s)
147 |
148 |
149 | def check_gradient_dpooling():
150 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01
151 | N = 4
152 | batch_inds = torch.randint(2, (N, 1)).cuda().float()
153 | x = torch.rand((N, 1)).cuda().float() * 15
154 | y = torch.rand((N, 1)).cuda().float() * 15
155 | w = torch.rand((N, 1)).cuda().float() * 10
156 | h = torch.rand((N, 1)).cuda().float() * 10
157 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
158 | offset = torch.randn(N, 2, 3, 3).cuda()
159 | input.requires_grad = True
160 | offset.requires_grad = True
161 |
162 | spatial_scale = 1.0 / 4
163 | pooled_size = 3
164 | output_dim = 3
165 | no_trans = 0
166 | group_size = 1
167 | trans_std = 0.0
168 | sample_per_part = 4
169 | part_size = pooled_size
170 |
171 | print(
172 | "check_gradient_dpooling:",
173 | gradcheck(
174 | dcn_v2_pooling,
175 | (
176 | input,
177 | rois,
178 | offset,
179 | spatial_scale,
180 | pooled_size,
181 | output_dim,
182 | no_trans,
183 | group_size,
184 | part_size,
185 | sample_per_part,
186 | trans_std,
187 | ),
188 | eps=1e-4,
189 | ),
190 | )
191 |
192 |
193 | def example_dconv():
194 | input = torch.randn(2, 64, 128, 128).cuda()
195 | # wrap all things (offset and mask) in DCN
196 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda()
197 | # print(dcn.weight.shape, input.shape)
198 | output = dcn(input)
199 | target = output.new(*output.size())
200 | target.data.uniform_(-0.01, 0.01)
201 | error = (target - output).mean()
202 | error.backward()
203 | print(output.shape)
204 |
205 |
206 | def example_dpooling():
207 | input = torch.randn(2, 32, 64, 64).cuda()
208 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
209 | x = torch.randint(256, (20, 1)).cuda().float()
210 | y = torch.randint(256, (20, 1)).cuda().float()
211 | w = torch.randint(64, (20, 1)).cuda().float()
212 | h = torch.randint(64, (20, 1)).cuda().float()
213 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
214 | offset = torch.randn(20, 2, 7, 7).cuda()
215 | input.requires_grad = True
216 | offset.requires_grad = True
217 |
218 | # normal roi_align
219 | pooling = DCNv2Pooling(
220 | spatial_scale=1.0 / 4,
221 | pooled_size=7,
222 | output_dim=32,
223 | no_trans=True,
224 | group_size=1,
225 | trans_std=0.1,
226 | ).cuda()
227 |
228 | # deformable pooling
229 | dpooling = DCNv2Pooling(
230 | spatial_scale=1.0 / 4,
231 | pooled_size=7,
232 | output_dim=32,
233 | no_trans=False,
234 | group_size=1,
235 | trans_std=0.1,
236 | ).cuda()
237 |
238 | out = pooling(input, rois, offset)
239 | dout = dpooling(input, rois, offset)
240 | print(out.shape)
241 | print(dout.shape)
242 |
243 | target_out = out.new(*out.size())
244 | target_out.data.uniform_(-0.01, 0.01)
245 | target_dout = dout.new(*dout.size())
246 | target_dout.data.uniform_(-0.01, 0.01)
247 | e = (target_out - out).mean()
248 | e.backward()
249 | e = (target_dout - dout).mean()
250 | e.backward()
251 |
252 |
253 | def example_mdpooling():
254 | input = torch.randn(2, 32, 64, 64).cuda()
255 | input.requires_grad = True
256 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
257 | x = torch.randint(256, (20, 1)).cuda().float()
258 | y = torch.randint(256, (20, 1)).cuda().float()
259 | w = torch.randint(64, (20, 1)).cuda().float()
260 | h = torch.randint(64, (20, 1)).cuda().float()
261 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
262 |
263 | # modulated deformable pooling (V2)
264 | dpooling = DCNPooling(
265 | spatial_scale=1.0 / 4,
266 | pooled_size=7,
267 | output_dim=32,
268 | no_trans=False,
269 | group_size=1,
270 | trans_std=0.1,
271 | deform_fc_dim=1024,
272 | ).cuda()
273 |
274 | dout = dpooling(input, rois)
275 | target = dout.new(*dout.size())
276 | target.data.uniform_(-0.1, 0.1)
277 | error = (target - dout).mean()
278 | error.backward()
279 | print(dout.shape)
280 |
281 |
282 | if __name__ == "__main__":
283 |
284 | example_dconv()
285 | example_dpooling()
286 | example_mdpooling()
287 |
288 | check_pooling_zero_offset()
289 | # zero offset check
290 | if inC == outC:
291 | check_zero_offset()
292 |
293 | check_gradient_dpooling()
294 | check_gradient_dconv()
295 | # """
296 | # ****** Note: the "backward is not reentrant" error may not be a serious problem,
297 | # ****** since the max error is less than 1e-7;
298 | # ****** still looking for what triggers this problem.
299 | # """
300 |
--------------------------------------------------------------------------------
/test/testcpu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import gradcheck
7 |
8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling
9 |
10 | deformable_groups = 1
11 | N, inC, inH, inW = 2, 2, 4, 4
12 | outC = 2
13 | kH, kW = 3, 3
14 |
15 |
16 | def conv_identify(weight, bias):
17 | weight.data.zero_()
18 | bias.data.zero_()
19 | o, i, h, w = weight.shape
20 | y = h // 2
21 | x = w // 2
22 | for p in range(i):
23 | for q in range(o):
24 | if p == q:
25 | weight.data[q, p, y, x] = 1.0
26 |
27 |
28 | def check_zero_offset():
29 | conv_offset = nn.Conv2d(
30 | inC,
31 | deformable_groups * 2 * kH * kW,
32 | kernel_size=(kH, kW),
33 | stride=(1, 1),
34 | padding=(1, 1),
35 | bias=True,
36 | )
37 |
38 | conv_mask = nn.Conv2d(
39 | inC,
40 | deformable_groups * 1 * kH * kW,
41 | kernel_size=(kH, kW),
42 | stride=(1, 1),
43 | padding=(1, 1),
44 | bias=True,
45 | )
46 |
47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups)
48 |
49 | conv_offset.weight.data.zero_()
50 | conv_offset.bias.data.zero_()
51 | conv_mask.weight.data.zero_()
52 | conv_mask.bias.data.zero_()
53 | conv_identify(dcn_v2.weight, dcn_v2.bias)
54 |
55 | input = torch.randn(N, inC, inH, inW)
56 | offset = conv_offset(input)
57 | mask = conv_mask(input)
58 | mask = torch.sigmoid(mask)
59 | output = dcn_v2(input, offset, mask)
60 | output *= 2  # conv_mask is zeroed, so mask = sigmoid(0) = 0.5; scale back before comparing to input
61 | d = (input - output).abs().max()
62 | if d < 1e-10:
63 | print("Zero offset passed")
64 | else:
65 | print("Zero offset failed")
66 | print(input)
67 | print(output)
68 |
69 |
70 | def check_gradient_dconv():
71 |
72 | input = torch.rand(N, inC, inH, inW) * 0.01
73 | input.requires_grad = True
74 |
75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW) * 2
76 | # offset.data.zero_()
77 | # offset.data -= 0.5
78 | offset.requires_grad = True
79 |
80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW)
81 | # mask.data.zero_()
82 | mask.requires_grad = True
83 | mask = torch.sigmoid(mask)
84 |
85 | weight = torch.randn(outC, inC, kH, kW)
86 | weight.requires_grad = True
87 |
88 | bias = torch.rand(outC)
89 | bias.requires_grad = True
90 |
91 | stride = 1
92 | padding = 1
93 | dilation = 1
94 |
95 | print(
96 | "check_gradient_dconv: ",
97 | gradcheck(
98 | dcn_v2_conv,
99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups),
100 | eps=1e-3,
101 | atol=1e-4,
102 | rtol=1e-2,
103 | ),
104 | )
105 |
106 |
107 | def check_pooling_zero_offset():
108 |
109 | input = torch.randn(2, 16, 64, 64).zero_()
110 | input[0, :, 16:26, 16:26] = 1.0
111 | input[1, :, 10:20, 20:30] = 2.0
112 | rois = torch.tensor(
113 | [
114 | [0, 65, 65, 103, 103],
115 | [1, 81, 41, 119, 79],
116 | ]
117 | ).float()
118 | pooling = DCNv2Pooling(
119 | spatial_scale=1.0 / 4,
120 | pooled_size=7,
121 | output_dim=16,
122 | no_trans=True,
123 | group_size=1,
124 | trans_std=0.0,
125 | )
126 |
127 | out = pooling(input, rois, input.new())
128 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])])
129 | print(s)
130 |
131 | dpooling = DCNv2Pooling(
132 | spatial_scale=1.0 / 4,
133 | pooled_size=7,
134 | output_dim=16,
135 | no_trans=False,
136 | group_size=1,
137 | trans_std=0.0,
138 | )
139 | offset = torch.randn(20, 2, 7, 7).zero_()
140 | dout = dpooling(input, rois, offset)
141 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])])
142 | print(s)
143 |
144 |
145 | def check_gradient_dpooling():
146 | input = torch.randn(2, 3, 5, 5) * 0.01
147 | N = 4
148 | batch_inds = torch.randint(2, (N, 1)).float()
149 | x = torch.rand((N, 1)).float() * 15
150 | y = torch.rand((N, 1)).float() * 15
151 | w = torch.rand((N, 1)).float() * 10
152 | h = torch.rand((N, 1)).float() * 10
153 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
154 | offset = torch.randn(N, 2, 3, 3)
155 | input.requires_grad = True
156 | offset.requires_grad = True
157 |
158 | spatial_scale = 1.0 / 4
159 | pooled_size = 3
160 | output_dim = 3
161 | no_trans = 0
162 | group_size = 1
163 | trans_std = 0.0
164 | sample_per_part = 4
165 | part_size = pooled_size
166 |
167 | print(
168 | "check_gradient_dpooling:",
169 | gradcheck(
170 | dcn_v2_pooling,
171 | (
172 | input,
173 | rois,
174 | offset,
175 | spatial_scale,
176 | pooled_size,
177 | output_dim,
178 | no_trans,
179 | group_size,
180 | part_size,
181 | sample_per_part,
182 | trans_std,
183 | ),
184 | eps=1e-4,
185 | ),
186 | )
187 |
188 |
189 | def example_dconv():
190 | input = torch.randn(2, 64, 128, 128)
191 | # wrap all things (offset and mask) in DCN
192 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2)
193 | # print(dcn.weight.shape, input.shape)
194 | output = dcn(input)
195 | target = output.new(*output.size())
196 | target.data.uniform_(-0.01, 0.01)
197 | error = (target - output).mean()
198 | error.backward()
199 | print(output.shape)
200 |
201 |
202 | def example_dpooling():
203 | input = torch.randn(2, 32, 64, 64)
204 | batch_inds = torch.randint(2, (20, 1)).float()
205 | x = torch.randint(256, (20, 1)).float()
206 | y = torch.randint(256, (20, 1)).float()
207 | w = torch.randint(64, (20, 1)).float()
208 | h = torch.randint(64, (20, 1)).float()
209 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
210 | offset = torch.randn(20, 2, 7, 7)
211 | input.requires_grad = True
212 | offset.requires_grad = True
213 |
214 | # normal roi_align
215 | pooling = DCNv2Pooling(
216 | spatial_scale=1.0 / 4,
217 | pooled_size=7,
218 | output_dim=32,
219 | no_trans=True,
220 | group_size=1,
221 | trans_std=0.1,
222 | )
223 |
224 | # deformable pooling
225 | dpooling = DCNv2Pooling(
226 | spatial_scale=1.0 / 4,
227 | pooled_size=7,
228 | output_dim=32,
229 | no_trans=False,
230 | group_size=1,
231 | trans_std=0.1,
232 | )
233 |
234 | out = pooling(input, rois, offset)
235 | dout = dpooling(input, rois, offset)
236 | print(out.shape)
237 | print(dout.shape)
238 |
239 | target_out = out.new(*out.size())
240 | target_out.data.uniform_(-0.01, 0.01)
241 | target_dout = dout.new(*dout.size())
242 | target_dout.data.uniform_(-0.01, 0.01)
243 | e = (target_out - out).mean()
244 | e.backward()
245 | e = (target_dout - dout).mean()
246 | e.backward()
247 |
248 |
249 | def example_mdpooling():
250 | input = torch.randn(2, 32, 64, 64)
251 | input.requires_grad = True
252 | batch_inds = torch.randint(2, (20, 1)).float()
253 | x = torch.randint(256, (20, 1)).float()
254 | y = torch.randint(256, (20, 1)).float()
255 | w = torch.randint(64, (20, 1)).float()
256 | h = torch.randint(64, (20, 1)).float()
257 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
258 |
259 | # modulated deformable pooling (V2)
260 | dpooling = DCNPooling(
261 | spatial_scale=1.0 / 4,
262 | pooled_size=7,
263 | output_dim=32,
264 | no_trans=False,
265 | group_size=1,
266 | trans_std=0.1,
267 | deform_fc_dim=1024,
268 | )
269 |
270 | dout = dpooling(input, rois)
271 | target = dout.new(*dout.size())
272 | target.data.uniform_(-0.1, 0.1)
273 | error = (target - dout).mean()
274 | error.backward()
275 | print(dout.shape)
276 |
277 |
278 | if __name__ == "__main__":
279 |
280 | example_dconv()
281 | example_dpooling()
282 | example_mdpooling()
283 |
284 | check_pooling_zero_offset()
285 | # zero offset check
286 | if inC == outC:
287 | check_zero_offset()
288 |
289 | check_gradient_dpooling()
290 | check_gradient_dconv()
291 | # """
292 | # ****** Note: backward is not reentrant error may not be a serious problem,
293 | # ****** since the max error is less than 1e-7,
294 | # ****** Still looking for what trigger this problem
295 | # """
296 |
--------------------------------------------------------------------------------
/test/testcuda.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import gradcheck
7 |
8 | from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling
9 |
10 | deformable_groups = 1
11 | N, inC, inH, inW = 2, 2, 4, 4
12 | outC = 2
13 | kH, kW = 3, 3
14 |
15 |
16 | def conv_identify(weight, bias):
17 | weight.data.zero_()
18 | bias.data.zero_()
19 | o, i, h, w = weight.shape
20 | y = h // 2
21 | x = w // 2
22 | for p in range(i):
23 | for q in range(o):
24 | if p == q:
25 | weight.data[q, p, y, x] = 1.0
26 |
27 |
28 | def check_zero_offset():
29 | conv_offset = nn.Conv2d(
30 | inC,
31 | deformable_groups * 2 * kH * kW,
32 | kernel_size=(kH, kW),
33 | stride=(1, 1),
34 | padding=(1, 1),
35 | bias=True,
36 | ).cuda()
37 |
38 | conv_mask = nn.Conv2d(
39 | inC,
40 | deformable_groups * 1 * kH * kW,
41 | kernel_size=(kH, kW),
42 | stride=(1, 1),
43 | padding=(1, 1),
44 | bias=True,
45 | ).cuda()
46 |
47 | dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda()
48 |
49 | conv_offset.weight.data.zero_()
50 | conv_offset.bias.data.zero_()
51 | conv_mask.weight.data.zero_()
52 | conv_mask.bias.data.zero_()
53 | conv_identify(dcn_v2.weight, dcn_v2.bias)
54 |
55 | input = torch.randn(N, inC, inH, inW).cuda()
56 | offset = conv_offset(input)
57 | mask = conv_mask(input)
58 | mask = torch.sigmoid(mask)
59 | output = dcn_v2(input, offset, mask)
60 | output *= 2  # conv_mask is zeroed, so mask = sigmoid(0) = 0.5; scale back before comparing to input
61 | d = (input - output).abs().max()
62 | if d < 1e-10:
63 | print("Zero offset passed")
64 | else:
65 | print("Zero offset failed")
66 | print(input)
67 | print(output)
68 |
69 |
70 | def check_gradient_dconv():
71 |
72 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01
73 | input.requires_grad = True
74 |
75 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
76 | # offset.data.zero_()
77 | # offset.data -= 0.5
78 | offset.requires_grad = True
79 |
80 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
81 | # mask.data.zero_()
82 | mask.requires_grad = True
83 | mask = torch.sigmoid(mask)
84 |
85 | weight = torch.randn(outC, inC, kH, kW).cuda()
86 | weight.requires_grad = True
87 |
88 | bias = torch.rand(outC).cuda()
89 | bias.requires_grad = True
90 |
91 | stride = 1
92 | padding = 1
93 | dilation = 1
94 |
95 | print(
96 | "check_gradient_dconv: ",
97 | gradcheck(
98 | dcn_v2_conv,
99 | (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups),
100 | eps=1e-3,
101 | atol=1e-4,
102 | rtol=1e-2,
103 | ),
104 | )
105 |
106 |
107 | def check_pooling_zero_offset():
108 |
109 | input = torch.randn(2, 16, 64, 64).cuda().zero_()
110 | input[0, :, 16:26, 16:26] = 1.0
111 | input[1, :, 10:20, 20:30] = 2.0
112 | rois = (
113 | torch.tensor(
114 | [
115 | [0, 65, 65, 103, 103],
116 | [1, 81, 41, 119, 79],
117 | ]
118 | )
119 | .cuda()
120 | .float()
121 | )
122 | pooling = DCNv2Pooling(
123 | spatial_scale=1.0 / 4,
124 | pooled_size=7,
125 | output_dim=16,
126 | no_trans=True,
127 | group_size=1,
128 | trans_std=0.0,
129 | ).cuda()
130 |
131 | out = pooling(input, rois, input.new())
132 | s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])])
133 | print(s)
134 |
135 | dpooling = DCNv2Pooling(
136 | spatial_scale=1.0 / 4,
137 | pooled_size=7,
138 | output_dim=16,
139 | no_trans=False,
140 | group_size=1,
141 | trans_std=0.0,
142 | ).cuda()
143 | offset = torch.randn(20, 2, 7, 7).cuda().zero_()
144 | dout = dpooling(input, rois, offset)
145 | s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])])
146 | print(s)
147 |
148 |
149 | def check_gradient_dpooling():
150 | input = torch.randn(2, 3, 5, 5).cuda().float() * 0.01
151 | N = 4
152 | batch_inds = torch.randint(2, (N, 1)).cuda().float()
153 | x = torch.rand((N, 1)).cuda().float() * 15
154 | y = torch.rand((N, 1)).cuda().float() * 15
155 | w = torch.rand((N, 1)).cuda().float() * 10
156 | h = torch.rand((N, 1)).cuda().float() * 10
157 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
158 | offset = torch.randn(N, 2, 3, 3).cuda()
159 | input.requires_grad = True
160 | offset.requires_grad = True
161 |
162 | spatial_scale = 1.0 / 4
163 | pooled_size = 3
164 | output_dim = 3
165 | no_trans = 0
166 | group_size = 1
167 | trans_std = 0.0
168 | sample_per_part = 4
169 | part_size = pooled_size
170 |
171 | print(
172 | "check_gradient_dpooling:",
173 | gradcheck(
174 | dcn_v2_pooling,
175 | (
176 | input,
177 | rois,
178 | offset,
179 | spatial_scale,
180 | pooled_size,
181 | output_dim,
182 | no_trans,
183 | group_size,
184 | part_size,
185 | sample_per_part,
186 | trans_std,
187 | ),
188 | eps=1e-4,
189 | ),
190 | )
191 |
192 |
193 | def example_dconv():
194 | input = torch.randn(2, 64, 128, 128).cuda()
195 | # wrap all things (offset and mask) in DCN
196 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda()
197 | # print(dcn.weight.shape, input.shape)
198 | output = dcn(input)
199 | target = output.new(*output.size())
200 | target.data.uniform_(-0.01, 0.01)
201 | error = (target - output).mean()
202 | error.backward()
203 | print(output.shape)
204 |
205 |
206 | def example_dpooling():
207 | input = torch.randn(2, 32, 64, 64).cuda()
208 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
209 | x = torch.randint(256, (20, 1)).cuda().float()
210 | y = torch.randint(256, (20, 1)).cuda().float()
211 | w = torch.randint(64, (20, 1)).cuda().float()
212 | h = torch.randint(64, (20, 1)).cuda().float()
213 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
214 | offset = torch.randn(20, 2, 7, 7).cuda()
215 | input.requires_grad = True
216 | offset.requires_grad = True
217 |
218 | # normal roi_align
219 | pooling = DCNv2Pooling(
220 | spatial_scale=1.0 / 4,
221 | pooled_size=7,
222 | output_dim=32,
223 | no_trans=True,
224 | group_size=1,
225 | trans_std=0.1,
226 | ).cuda()
227 |
228 | # deformable pooling
229 | dpooling = DCNv2Pooling(
230 | spatial_scale=1.0 / 4,
231 | pooled_size=7,
232 | output_dim=32,
233 | no_trans=False,
234 | group_size=1,
235 | trans_std=0.1,
236 | ).cuda()
237 |
238 | out = pooling(input, rois, offset)
239 | dout = dpooling(input, rois, offset)
240 | print(out.shape)
241 | print(dout.shape)
242 |
243 | target_out = out.new(*out.size())
244 | target_out.data.uniform_(-0.01, 0.01)
245 | target_dout = dout.new(*dout.size())
246 | target_dout.data.uniform_(-0.01, 0.01)
247 | e = (target_out - out).mean()
248 | e.backward()
249 | e = (target_dout - dout).mean()
250 | e.backward()
251 |
252 |
253 | def example_mdpooling():
254 | input = torch.randn(2, 32, 64, 64).cuda()
255 | input.requires_grad = True
256 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
257 | x = torch.randint(256, (20, 1)).cuda().float()
258 | y = torch.randint(256, (20, 1)).cuda().float()
259 | w = torch.randint(64, (20, 1)).cuda().float()
260 | h = torch.randint(64, (20, 1)).cuda().float()
261 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
262 |
263 | # modulated deformable pooling (V2)
264 | dpooling = DCNPooling(
265 | spatial_scale=1.0 / 4,
266 | pooled_size=7,
267 | output_dim=32,
268 | no_trans=False,
269 | group_size=1,
270 | trans_std=0.1,
271 | deform_fc_dim=1024,
272 | ).cuda()
273 |
274 | dout = dpooling(input, rois)
275 | target = dout.new(*dout.size())
276 | target.data.uniform_(-0.1, 0.1)
277 | error = (target - dout).mean()
278 | error.backward()
279 | print(dout.shape)
280 |
281 |
282 | if __name__ == "__main__":
283 |
284 | example_dconv()
285 | example_dpooling()
286 | example_mdpooling()
287 |
288 | check_pooling_zero_offset()
289 | # zero offset check
290 | if inC == outC:
291 | check_zero_offset()
292 |
293 | check_gradient_dpooling()
294 | check_gradient_dconv()
295 | # """
296 | # ****** Note: the "backward is not reentrant" error may not be a serious problem,
297 | # ****** since the max error is less than 1e-7;
298 | # ****** still looking for what triggers this problem.
299 | # """
300 |
--------------------------------------------------------------------------------