├── .gitignore
├── README.md
├── build.py
├── functions
│   ├── __init__.py
│   └── deform_conv.py
├── make.sh
├── modules
│   ├── __init__.py
│   └── deform_conv.py
├── src
│   ├── deform_conv.c
│   ├── deform_conv.h
│   ├── deform_conv_cuda.c
│   ├── deform_conv_cuda.h
│   ├── deform_conv_cuda_kernel.cu
│   └── deform_conv_cuda_kernel.h
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | **/*.pyc
3 | **/_ext
4 | **/build
5 | **/dist
6 | **/*.egg-info
7 | **/.eggs
8 | .clang_complete
9 | *.o
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deformable Convolutional Networks in PyTorch
2 | This repo is a PyTorch implementation of [Deformable Convolution](https://arxiv.org/abs/1703.06211),
3 | ported from the authors' MXNet [implementation](https://github.com/msracver/Deformable-ConvNets).
4 |
5 | # Build
6 |
7 | ```
8 | sh make.sh
9 | CC=g++ python build.py
10 | ```
11 |
12 | `make.sh` compiles the CUDA kernel with `nvcc`; `build.py` then builds the FFI extension against it. See `test.py` for example usage.
13 | 
14 | ### Notice
15 | Only `torch.cuda.FloatTensor` is supported.
16 |
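17 | # Example
18 | 
19 | A minimal forward/backward pass, adapted from `test.py` (variable names here are illustrative):
20 | 
21 | ```
22 | import torch
23 | import torch.nn as nn
24 | from torch.autograd import Variable
25 | 
26 | from modules import ConvOffset2d
27 | 
28 | N, inC, inH, inW = 1, 6, 512, 512
29 | outC, kH, kW = 4, 3, 3
30 | num_deformable_groups = 2
31 | 
32 | # a plain conv layer predicts the offsets: two values (x, y) per kernel
33 | # position and deformable group
34 | conv_offset = nn.Conv2d(inC, num_deformable_groups * 2 * kH * kW,
35 |                         kernel_size=(kH, kW), padding=(1, 1), bias=False).cuda()
36 | conv_deform = ConvOffset2d(inC, outC, (kH, kW), stride=1, padding=1,
37 |                            num_deformable_groups=num_deformable_groups).cuda()
38 | 
39 | inputs = Variable(torch.randn(N, inC, inH, inW).cuda())
40 | output = conv_deform(inputs, conv_offset(inputs))
41 | output.backward(output.data)
42 | ```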
--------------------------------------------------------------------------------
/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 |
5 | this_file = os.path.dirname(os.path.realpath(__file__))
6 |
7 | sources = ['src/deform_conv.c']
8 | headers = ['src/deform_conv.h']
9 | defines = []
10 | with_cuda = False
11 |
12 | if torch.cuda.is_available():
13 | print('Including CUDA code.')
14 | sources += ['src/deform_conv_cuda.c']
15 | headers += ['src/deform_conv_cuda.h']
16 | defines += [('WITH_CUDA', None)]
17 | with_cuda = True
18 |
19 | # extra_objects must be absolute paths to the precompiled CUDA kernel
20 | # object produced by make.sh
21 | extra_objects = ['src/deform_conv_cuda_kernel.cu.o']
22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
23 |
24 | ffi = create_extension(
25 | '_ext.deform_conv',
26 | headers=headers,
27 | sources=sources,
28 | define_macros=defines,
29 | relative_to=__file__,
30 | with_cuda=with_cuda,
31 | extra_objects=extra_objects
32 | )
33 |
34 | if __name__ == '__main__':
35 | assert torch.cuda.is_available(), 'Please install CUDA for GPU support.'
36 | ffi.build()
37 |
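38 | # After a successful build the extension is importable as
39 | # `from _ext import deform_conv` (see functions/deform_conv.py).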
--------------------------------------------------------------------------------
/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import conv_offset2d
2 |
--------------------------------------------------------------------------------
/functions/deform_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.nn.modules.utils import _pair
4 |
5 | from _ext import deform_conv
6 |
7 |
8 | def conv_offset2d(input,
9 | offset,
10 | weight,
11 | stride=1,
12 | padding=0,
13 | dilation=1,
14 |                   deformable_groups=1):
15 |
16 | if input is not None and input.dim() != 4:
17 | raise ValueError(
18 | "Expected 4D tensor as input, got {}D tensor instead.".format(
19 | input.dim()))
20 |
21 | f = ConvOffset2dFunction(
22 |         _pair(stride), _pair(padding), _pair(dilation), deformable_groups)
23 | return f(input, offset, weight)
24 |
25 |
26 | class ConvOffset2dFunction(Function):
27 | def __init__(self, stride, padding, dilation, deformable_groups=1):
28 | super(ConvOffset2dFunction, self).__init__()
29 | self.stride = stride
30 | self.padding = padding
31 | self.dilation = dilation
32 | self.deformable_groups = deformable_groups
33 |
34 | def forward(self, input, offset, weight):
35 | self.save_for_backward(input, offset, weight)
36 |
37 | output = input.new(*self._output_size(input, weight))
38 |
39 | self.bufs_ = [input.new(), input.new()] # columns, ones
40 |
41 | if not input.is_cuda:
42 | raise NotImplementedError
43 | else:
44 | if isinstance(input, torch.autograd.Variable):
45 | if not isinstance(input.data, torch.cuda.FloatTensor):
46 | raise NotImplementedError
47 | else:
48 | if not isinstance(input, torch.cuda.FloatTensor):
49 | raise NotImplementedError
50 | deform_conv.deform_conv_forward_cuda(
51 | input, weight, offset, output, self.bufs_[0], self.bufs_[1],
52 | weight.size(3), weight.size(2), self.stride[1], self.stride[0],
53 | self.padding[1], self.padding[0], self.dilation[1],
54 | self.dilation[0], self.deformable_groups)
55 | return output
56 |
57 | def backward(self, grad_output):
58 | input, offset, weight = self.saved_tensors
59 |
60 | grad_input = grad_offset = grad_weight = None
61 |
62 | if not grad_output.is_cuda:
63 | raise NotImplementedError
64 | else:
65 | if isinstance(grad_output, torch.autograd.Variable):
66 | if not isinstance(grad_output.data, torch.cuda.FloatTensor):
67 | raise NotImplementedError
68 | else:
69 | if not isinstance(grad_output, torch.cuda.FloatTensor):
70 | raise NotImplementedError
71 | if self.needs_input_grad[0] or self.needs_input_grad[1]:
72 | grad_input = input.new(*input.size()).zero_()
73 | grad_offset = offset.new(*offset.size()).zero_()
74 | deform_conv.deform_conv_backward_input_cuda(
75 | input, offset, grad_output, grad_input,
76 | grad_offset, weight, self.bufs_[0], weight.size(3),
77 | weight.size(2), self.stride[1], self.stride[0],
78 | self.padding[1], self.padding[0], self.dilation[1],
79 | self.dilation[0], self.deformable_groups)
80 |
81 | if self.needs_input_grad[2]:
82 | grad_weight = weight.new(*weight.size()).zero_()
83 | deform_conv.deform_conv_backward_parameters_cuda(
84 | input, offset, grad_output,
85 | grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3),
86 | weight.size(2), self.stride[1], self.stride[0],
87 | self.padding[1], self.padding[0], self.dilation[1],
88 | self.dilation[0], self.deformable_groups, 1)
89 |
90 | return grad_input, grad_offset, grad_weight
91 |
92 | def _output_size(self, input, weight):
93 | channels = weight.size(0)
94 |
95 | output_size = (input.size(0), channels)
96 | for d in range(input.dim() - 2):
97 | in_size = input.size(d + 2)
98 | pad = self.padding[d]
99 | kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
100 | stride = self.stride[d]
101 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
102 | if not all(map(lambda s: s > 0, output_size)):
103 | raise ValueError(
104 | "convolution input is too small (output would be {})".format(
105 | 'x'.join(map(str, output_size))))
106 | return output_size
107 |
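108 | 
109 | if __name__ == '__main__':
110 |     # Smoke test (a sketch; assumes a CUDA device and the built _ext
111 |     # module). With a 3x3 kernel, stride 1 and padding 1 the spatial size
112 |     # is preserved, so the offset field matches the output resolution, and
113 |     # one deformable group needs 2 * 3 * 3 = 18 offset channels.
114 |     from torch.autograd import Variable
115 | 
116 |     x = Variable(torch.randn(1, 4, 8, 8).cuda())
117 |     offsets = Variable(torch.zeros(1, 18, 8, 8).cuda())
118 |     w = Variable(torch.randn(2, 4, 3, 3).cuda())
119 |     out = conv_offset2d(x, offsets, w, stride=1, padding=1)
120 |     assert out.size() == (1, 2, 8, 8)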
--------------------------------------------------------------------------------
/make.sh:
--------------------------------------------------------------------------------
1 | cd src
2 | nvcc -c -o deform_conv_cuda_kernel.cu.o deform_conv_cuda_kernel.cu -x cu -Xcompiler -fPIC -std=c++11
3 |
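4 | # If the default nvcc target does not match your GPU, add an explicit
5 | # -gencode flag, e.g. -gencode arch=compute_61,code=sm_61 (the value here
6 | # is an illustrative assumption; pick the one for your device).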
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import ConvOffset2d
2 |
--------------------------------------------------------------------------------
/modules/deform_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.module import Module
6 | from torch.nn.modules.utils import _pair
7 | from functions import conv_offset2d
8 |
9 |
10 | class ConvOffset2d(Module):
11 | def __init__(self,
12 | in_channels,
13 | out_channels,
14 | kernel_size,
15 | stride=1,
16 | padding=0,
17 | dilation=1,
18 | num_deformable_groups=1):
19 | super(ConvOffset2d, self).__init__()
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.kernel_size = _pair(kernel_size)
23 | self.stride = _pair(stride)
24 | self.padding = _pair(padding)
25 | self.dilation = _pair(dilation)
26 | self.num_deformable_groups = num_deformable_groups
27 |
28 | self.weight = nn.Parameter(
29 | torch.Tensor(out_channels, in_channels, *self.kernel_size))
30 |
31 | self.reset_parameters()
32 |
33 | def reset_parameters(self):
34 | n = self.in_channels
35 | for k in self.kernel_size:
36 | n *= k
37 | stdv = 1. / math.sqrt(n)
38 | self.weight.data.uniform_(-stdv, stdv)
39 |
40 | def forward(self, input, offset):
41 | return conv_offset2d(input, offset, self.weight, self.stride,
42 | self.padding, self.dilation,
43 | self.num_deformable_groups)
44 |
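45 | # Shape note: forward() expects an offset tensor with
46 | # num_deformable_groups * 2 * kH * kW channels (an x- and a y-offset per
47 | # kernel position per group) and the same spatial size as the output; it
48 | # is typically predicted by a regular nn.Conv2d, as in test.py.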
--------------------------------------------------------------------------------
/src/deform_conv.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
4 | THFloatTensor *output)
5 | {
6 | // if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | // return 0;
8 | // THFloatTensor_resizeAs(output, input);
9 | // THFloatTensor_cadd(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 |
13 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
14 | THFloatTensor *grad_offset)
15 | {
16 | // THFloatTensor_resizeAs(grad_input, grad_output);
17 | // THFloatTensor_fill(grad_input, 1);
18 | return 1;
19 | }
20 |
--------------------------------------------------------------------------------
/src/deform_conv.h:
--------------------------------------------------------------------------------
1 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
2 | THFloatTensor *output);
3 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
4 | THFloatTensor *grad_offset);
5 |
--------------------------------------------------------------------------------
/src/deform_conv_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | #include "deform_conv_cuda_kernel.h"
4 |
5 | extern THCState *state;
6 |
7 | void shape_check(THCState *state, THCudaTensor *input, THCudaTensor *offset,
8 | THCudaTensor *gradOutput, THCudaTensor *weight, int kH, int kW,
9 | int dH, int dW, int padH, int padW, int dilationH,
10 | int dilationW, int deformable_group) {
11 |
12 | THArgCheck(weight->nDimension == 4, 5,
13 | "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
14 |              "but got: %d",
15 |              weight->nDimension);
16 |
17 | THArgCheck(THCudaTensor_isContiguous(state, weight), 5,
18 | "weight tensor has to be contiguous");
19 |
20 | THArgCheck(kW > 0 && kH > 0, 9,
21 | "kernel size should be greater than zero, but got kH: %d kW: %d",
22 | kH, kW);
23 |
24 | THArgCheck((weight->size[2] == kH && weight->size[3] == kW), 9,
25 |              "kernel size should be consistent with weight, "
26 | "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
27 | kW, weight->size[2], weight->size[3]);
28 |
29 | THArgCheck(dW > 0 && dH > 0, 11,
30 | "stride should be greater than zero, but got dH: %d dW: %d", dH,
31 | dW);
32 |
33 | THArgCheck(
34 | dilationW > 0 && dilationH > 0, 14,
35 | "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
36 | dilationH, dilationW);
37 |
38 | int ndim = input->nDimension;
39 | int dimf = 0;
40 | int dimh = 1;
41 | int dimw = 2;
42 |
43 | if (ndim == 4) {
44 | dimf++;
45 | dimh++;
46 | dimw++;
47 | }
48 |
49 | THArgCheck(ndim == 3 || ndim == 4, 2,
50 |              "3D or 4D input tensor expected but got: %d", ndim);
51 |
52 | long nInputPlane = weight->size[1];
53 | long inputHeight = input->size[dimh];
54 | long inputWidth = input->size[dimw];
55 | long nOutputPlane = weight->size[0];
56 | long outputHeight =
57 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
58 | long outputWidth =
59 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
60 |
61 | THArgCheck(nInputPlane % deformable_group == 0, 2,
62 | "input channels must divide deformable group size");
63 |
64 | if (outputWidth < 1 || outputHeight < 1)
65 | THError(
66 | "Given input size: (%ld x %ld x %ld). "
67 | "Calculated output size: (%ld x %ld x %ld). Output size is too small",
68 | nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
69 | outputWidth);
70 |
71 |   THArgCheck(input->size[dimf] == nInputPlane, 2,
72 |              "invalid number of input planes, expected: %ld, but got: %ld",
73 |              nInputPlane, input->size[dimf]);
74 |
75 | THArgCheck((inputHeight >= kH && inputWidth >= kW), 2,
76 | "input image is smaller than kernel");
77 |
78 | THArgCheck(
79 | (offset->size[2] == outputHeight && offset->size[3] == outputWidth), 3,
80 | "invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", outputHeight, outputWidth,
81 | offset->size[2], offset->size[3]);
82 |
83 | THArgCheck((offset->size[1] == deformable_group * 2 * kH * kW), 3,
84 | "invalid number of channels of offset");
85 |
86 | if (gradOutput != NULL) {
87 | THArgCheck(gradOutput->size[dimf] == nOutputPlane, 4,
88 | "invalid number of gradOutput planes, expected: %d, but got: %d",
89 | nOutputPlane, gradOutput->size[dimf]);
90 |
91 | THArgCheck((gradOutput->size[dimh] == outputHeight &&
92 | gradOutput->size[dimw] == outputWidth),
93 | 4, "invalid size of gradOutput, expected height: %d width: %d , but got height: %d width: %d", outputHeight, outputWidth,
94 | gradOutput->size[dimh], gradOutput->size[dimw]);
95 | }
96 | }
97 |
98 | int deform_conv_forward_cuda(THCudaTensor *input, THCudaTensor *weight,
99 | THCudaTensor *offset, THCudaTensor *output,
100 | THCudaTensor *columns, THCudaTensor *ones, int kW,
101 | int kH, int dW, int dH, int padW, int padH,
102 | int dilationH, int dilationW,
103 | int deformable_group) {
104 |
105 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, weight, offset,
106 | output, columns, ones));
107 |
108 | shape_check(state, input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
109 | dilationH, dilationW, deformable_group);
110 |
111 | input = THCudaTensor_newContiguous(state, input);
112 | offset = THCudaTensor_newContiguous(state, offset);
113 | weight = THCudaTensor_newContiguous(state, weight);
114 |
115 | int batch = 1;
116 | if (input->nDimension == 3) {
117 | // Force batch
118 | batch = 0;
119 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1],
120 | input->size[2]);
121 | THCudaTensor_resize4d(state, offset, 1, offset->size[0], offset->size[1],
122 | offset->size[2]);
123 | }
124 |
125 | long batchSize = input->size[0];
126 | long nInputPlane = input->size[1];
127 | long inputHeight = input->size[2];
128 | long inputWidth = input->size[3];
129 |
130 | long nOutputPlane = weight->size[0];
131 |
132 | long outputWidth =
133 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
134 | long outputHeight =
135 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
136 |
137 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset");
138 |
139 | // bias = bias ? THCudaTensor_newContiguous(state, bias) : bias;
140 |
141 | THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight,
142 | outputWidth);
143 |
144 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH,
145 | outputHeight * outputWidth);
146 |
147 | if (ones->nDimension != 2 ||
148 | ones->size[0] * ones->size[1] < outputHeight * outputWidth) {
149 | THCudaTensor_resize2d(state, ones, outputHeight, outputWidth);
150 | THCudaTensor_fill(state, ones, 1);
151 | }
152 |
153 | THCudaTensor *input_n = THCudaTensor_new(state);
154 | THCudaTensor *offset_n = THCudaTensor_new(state);
155 | THCudaTensor *output_n = THCudaTensor_new(state);
156 |
157 | for (int elt = 0; elt < batchSize; elt++) {
158 |
159 | THCudaTensor_select(state, input_n, input, 0, elt);
160 | THCudaTensor_select(state, offset_n, offset, 0, elt);
161 | THCudaTensor_select(state, output_n, output, 0, elt);
162 |
163 | // long m_ = nOutputPlane;
164 | // long n_ = outputHeight * outputWidth;
165 | // long k_ = 1;
166 |
167 | // TODO(BZ) add bias term
168 | // if (bias) {
169 | // THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
170 | // THCudaTensor_data(state, ones), k_,
171 | // THCudaTensor_data(state, bias), k_, 0.0f,
172 | // THCudaTensor_data(state, output_n), n_);
173 | // } else {
174 | // THCudaTensor_zero(state, output_n);
175 | // }
176 |
177 | THCudaTensor_zero(state, output_n);
178 |
179 | deformable_im2col(
180 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n),
181 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight,
182 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
183 | deformable_group, THCudaTensor_data(state, columns));
184 |
185 | long m = nOutputPlane;
186 | long n = columns->size[1];
187 | long k = nInputPlane * kH * kW;
188 |
189 | THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f,
190 | THCudaTensor_data(state, columns), n,
191 | THCudaTensor_data(state, weight), k, 1.0f,
192 | THCudaTensor_data(state, output_n), n);
193 | }
194 |
195 | THCudaTensor_free(state, input_n);
196 | THCudaTensor_free(state, offset_n);
197 | THCudaTensor_free(state, output_n);
198 |
199 | if (batch == 0) {
200 | THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight,
201 | outputWidth);
202 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
203 | THCudaTensor_resize3d(state, offset, offset->size[1], offset->size[2],
204 | offset->size[3]);
205 | }
206 |
207 | THCudaTensor_free(state, input);
208 | THCudaTensor_free(state, offset);
209 | THCudaTensor_free(state, weight);
210 | // if (bias) THCudaTensor_free(state, bias);
211 |
212 | return 1;
213 | }
214 |
215 | int deform_conv_backward_input_cuda(
216 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
217 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight,
218 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH,
219 | int dilationH, int dilationW, int deformable_group) {
220 |
221 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 6, input, gradOutput, weight,
222 | offset, columns, gradInput));
223 |
224 | shape_check(state, input, offset, gradOutput, weight, kH, kW, dH, dW, padH,
225 | padW, dilationH, dilationW, deformable_group);
226 |
227 | input = THCudaTensor_newContiguous(state, input);
228 | offset = THCudaTensor_newContiguous(state, offset);
229 | gradOutput = THCudaTensor_newContiguous(state, gradOutput);
230 | weight = THCudaTensor_newContiguous(state, weight);
231 |
232 | int batch = 1;
233 | if (input->nDimension == 3) {
234 | // Force batch
235 | batch = 0;
236 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1],
237 | input->size[2]);
238 | THCudaTensor_resize4d(state, offset, 1, offset->size[0], offset->size[1],
239 | offset->size[2]);
240 | THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0],
241 | gradOutput->size[1], gradOutput->size[2]);
242 | }
243 |
244 | long batchSize = input->size[0];
245 | long nInputPlane = input->size[1];
246 | long inputHeight = input->size[2];
247 | long inputWidth = input->size[3];
248 |
249 | long nOutputPlane = weight->size[0];
250 |
251 | long outputWidth =
252 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
253 | long outputHeight =
254 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
255 |
256 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset");
257 |
258 | THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight,
259 | inputWidth);
260 |
261 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH,
262 | outputHeight * outputWidth);
263 |
264 | THCudaTensor *gradInput_n = THCudaTensor_new(state);
265 | THCudaTensor *gradOffset_n = THCudaTensor_new(state);
266 | THCudaTensor *input_n = THCudaTensor_new(state);
267 | THCudaTensor *offset_n = THCudaTensor_new(state);
268 | THCudaTensor *gradOutput_n = THCudaTensor_new(state);
269 |
270 | for (int elt = 0; elt < batchSize; elt++) {
271 | THCudaTensor_select(state, gradInput_n, gradInput, 0, elt);
272 | THCudaTensor_select(state, gradOffset_n, gradOffset, 0, elt);
273 | THCudaTensor_select(state, input_n, input, 0, elt);
274 | THCudaTensor_select(state, offset_n, offset, 0, elt);
275 | THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);
276 |
277 | long m = nInputPlane * kW * kH;
278 | long n = columns->size[1];
279 | long k = nOutputPlane;
280 |
281 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
282 | THCudaTensor_data(state, gradOutput_n), n,
283 | THCudaTensor_data(state, weight), m, 0.0f,
284 | THCudaTensor_data(state, columns), n);
285 |
286 | deformable_col2im_coord(
287 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns),
288 | THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n),
289 | nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
290 | dilationH, dilationW, deformable_group,
291 | THCudaTensor_data(state, gradOffset_n));
292 |
293 | deformable_col2im(
294 | THCState_getCurrentStream(state), THCudaTensor_data(state, columns),
295 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight,
296 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
297 | deformable_group, THCudaTensor_data(state, gradInput_n));
298 | }
299 |
300 | THCudaTensor_free(state, gradInput_n);
301 | THCudaTensor_free(state, gradOffset_n);
302 | THCudaTensor_free(state, input_n);
303 | THCudaTensor_free(state, offset_n);
304 | THCudaTensor_free(state, gradOutput_n);
305 |
306 | if (batch == 0) {
307 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight,
308 | outputWidth);
309 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
310 | THCudaTensor_resize3d(state, gradInput, nInputPlane, inputHeight,
311 | inputWidth);
312 | THCudaTensor_resize3d(state, offset, offset->size[1], offset->size[2],
313 | offset->size[3]);
314 | THCudaTensor_resize3d(state, gradOffset, offset->size[1], offset->size[2],
315 | offset->size[3]);
316 | }
317 |
318 | THCudaTensor_free(state, input);
319 | THCudaTensor_free(state, offset);
320 | THCudaTensor_free(state, gradOutput);
321 | THCudaTensor_free(state, weight);
322 |
323 | return 1;
324 | }
325 |
326 | int deform_conv_backward_parameters_cuda(
327 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
328 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */
329 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH,
330 | int padW, int padH, int dilationH, int dilationW, int deformable_group,
331 | float scale) {
332 |
333 | THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, offset, gradOutput,
334 | gradWeight, columns));
335 |
336 | shape_check(state, input, offset, gradOutput, gradWeight, kH, kW, dH, dW,
337 | padH, padW, dilationH, dilationW, deformable_group);
338 |
339 | input = THCudaTensor_newContiguous(state, input);
340 | offset = THCudaTensor_newContiguous(state, offset);
341 | gradOutput = THCudaTensor_newContiguous(state, gradOutput);
342 |
343 | int batch = 1;
344 | if (input->nDimension == 3) {
345 | // Force batch
346 | batch = 0;
347 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1],
348 | input->size[2]);
349 | THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0],
350 | gradOutput->size[1], gradOutput->size[2]);
351 | }
352 |
353 | long batchSize = input->size[0];
354 | long nInputPlane = input->size[1];
355 | long inputHeight = input->size[2];
356 | long inputWidth = input->size[3];
357 |
358 | long nOutputPlane = gradWeight->size[0];
359 |
360 | long outputWidth =
361 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
362 | long outputHeight =
363 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
364 |
365 | THArgCheck((offset->size[0] == batchSize), 3, "invalid batch size of offset");
366 |
367 | THCudaTensor_resize2d(state, columns, nInputPlane * kW * kH,
368 | outputHeight * outputWidth);
369 |
370 | THCudaTensor *input_n = THCudaTensor_new(state);
371 | THCudaTensor *offset_n = THCudaTensor_new(state);
372 | THCudaTensor *gradOutput_n = THCudaTensor_new(state);
373 |
374 | for (int elt = 0; elt < batchSize; elt++) {
375 | THCudaTensor_select(state, input_n, input, 0, elt);
376 | THCudaTensor_select(state, offset_n, offset, 0, elt);
377 | THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);
378 |
379 | deformable_im2col(
380 | THCState_getCurrentStream(state), THCudaTensor_data(state, input_n),
381 | THCudaTensor_data(state, offset_n), nInputPlane, inputHeight,
382 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
383 | deformable_group, THCudaTensor_data(state, columns));
384 |
385 | long m = nOutputPlane;
386 | long n = nInputPlane * kW * kH;
387 | long k = columns->size[1];
388 |
389 | THCudaBlas_Sgemm(state, 't', 'n', n, m, k, scale,
390 | THCudaTensor_data(state, columns), k,
391 | THCudaTensor_data(state, gradOutput_n), k, 1.0f,
392 | THCudaTensor_data(state, gradWeight), n);
393 | }
394 |
395 | THCudaTensor_free(state, input_n);
396 | THCudaTensor_free(state, offset_n);
397 | THCudaTensor_free(state, gradOutput_n);
398 |
399 | if (batch == 0) {
400 | THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight,
401 | outputWidth);
402 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
403 | }
404 |
405 | THCudaTensor_free(state, input);
406 | THCudaTensor_free(state, offset);
407 | THCudaTensor_free(state, gradOutput);
408 | return 1;
409 | }
410 |
--------------------------------------------------------------------------------
/src/deform_conv_cuda.h:
--------------------------------------------------------------------------------
1 | int deform_conv_forward_cuda(THCudaTensor *input,
2 | THCudaTensor *weight, /*THCudaTensor * bias, */
3 | THCudaTensor *offset, THCudaTensor *output,
4 | THCudaTensor *columns, THCudaTensor *ones, int kW,
5 | int kH, int dW, int dH, int padW, int padH,
6 | int dilationH, int dilationW,
7 | int deformable_group);
8 |
9 | int deform_conv_backward_input_cuda(
10 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
11 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight,
12 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH,
13 | int dilationH, int dilationW, int deformable_group);
14 |
15 | int deform_conv_backward_parameters_cuda(
16 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
17 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */
18 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH,
19 | int padW, int padH, int dilationH, int dilationW, int deformable_group,
20 | float scale);
21 |
--------------------------------------------------------------------------------
/src/deform_conv_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "deform_conv_cuda_kernel.h"
2 |
3 | #include <cstdio>  // printf in the error-reporting paths below
4 |
5 | #define CUDA_KERNEL_LOOP(i, n) \
6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
7 | i += blockDim.x * gridDim.x)
8 |
9 | const int CUDA_NUM_THREADS = 1024;
10 |
11 | inline int GET_BLOCKS(const int N) {
12 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
13 | }
14 |
15 | template <typename DType>
16 | __device__ DType deformable_im2col_bilinear(const DType *bottom_data,
17 | const int data_width,
18 | const int height, const int width,
19 | DType h, DType w) {
20 |
21 | int h_low = floor(h);
22 | int w_low = floor(w);
23 | int h_high;
24 | int w_high;
25 | if (h_low >= height - 1) {
26 | h_high = h_low = height - 1;
27 | h = (DType)h_low;
28 | } else {
29 | h_high = h_low + 1;
30 | }
31 |
32 | if (w_low >= width - 1) {
33 | w_high = w_low = width - 1;
34 | w = (DType)w_low;
35 | } else {
36 | w_high = w_low + 1;
37 | }
38 |
39 | DType lh = h - h_low;
40 | DType lw = w - w_low;
41 | DType hh = 1 - lh, hw = 1 - lw;
42 |
43 | DType v1 = bottom_data[h_low * data_width + w_low];
44 | DType v2 = bottom_data[h_low * data_width + w_high];
45 | DType v3 = bottom_data[h_high * data_width + w_low];
46 | DType v4 = bottom_data[h_high * data_width + w_high];
47 | DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
48 |
49 | DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
50 | return val;
51 | }
52 |
53 | template <typename DType>
54 | __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w,
55 | const int h, const int w, const int height,
56 | const int width) {
57 |
58 | if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
59 | // empty
60 | return 0;
61 | }
62 |
63 | argmax_h = max(argmax_h, (DType)0.0f);
64 | argmax_w = max(argmax_w, (DType)0.0f);
65 |
66 | int argmax_h_low = (int)argmax_h;
67 | int argmax_w_low = (int)argmax_w;
68 | int argmax_h_high;
69 | int argmax_w_high;
70 | if (argmax_h_low >= height - 1) {
71 | argmax_h_high = argmax_h_low = height - 1;
72 | argmax_h = (DType)argmax_h_low;
73 | } else {
74 | argmax_h_high = argmax_h_low + 1;
75 | }
76 | if (argmax_w_low >= width - 1) {
77 | argmax_w_high = argmax_w_low = width - 1;
78 | argmax_w = (DType)argmax_w_low;
79 | } else {
80 | argmax_w_high = argmax_w_low + 1;
81 | }
82 | DType weight = 0;
83 | if (h == argmax_h_low) {
84 | if (w == argmax_w_low) {
85 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
86 | } else if (w == argmax_w_high) {
87 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
88 | }
89 | } else if (h == argmax_h_high) {
90 | if (w == argmax_w_low) {
91 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
92 | } else if (w == argmax_w_high) {
93 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
94 | }
95 | }
96 | return weight;
97 | }
98 |
99 | template <typename DType>
100 | __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w,
101 | const int height, const int width,
102 | const DType *im_data,
103 | const int data_width, const int bp_dir) {
104 |
105 | if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
106 | // empty
107 | return 0;
108 | }
109 |
110 | if (argmax_h < 0)
111 | argmax_h = 0;
112 | if (argmax_w < 0)
113 | argmax_w = 0;
114 |
115 | int argmax_h_low = (int)argmax_h;
116 | int argmax_w_low = (int)argmax_w;
117 | int argmax_h_high;
118 | int argmax_w_high;
119 | if (argmax_h_low >= height - 1) {
120 | argmax_h_high = argmax_h_low = height - 1;
121 | argmax_h = (DType)argmax_h_low;
122 | } else {
123 | argmax_h_high = argmax_h_low + 1;
124 | }
125 | if (argmax_w_low >= width - 1) {
126 | argmax_w_high = argmax_w_low = width - 1;
127 | argmax_w = (DType)argmax_w_low;
128 | } else {
129 | argmax_w_high = argmax_w_low + 1;
130 | }
131 | DType weight = 0;
132 |
133 | if (bp_dir == 0) {
134 | weight += -1 * (argmax_w_low + 1 - argmax_w) *
135 | im_data[argmax_h_low * data_width + argmax_w_low];
136 | weight += -1 * (argmax_w - argmax_w_low) *
137 | im_data[argmax_h_low * data_width + argmax_w_high];
138 | weight += (argmax_w_low + 1 - argmax_w) *
139 | im_data[argmax_h_high * data_width + argmax_w_low];
140 | weight += (argmax_w - argmax_w_low) *
141 | im_data[argmax_h_high * data_width + argmax_w_high];
142 | } else if (bp_dir == 1) {
143 | weight += -1 * (argmax_h_low + 1 - argmax_h) *
144 | im_data[argmax_h_low * data_width + argmax_w_low];
145 | weight += (argmax_h_low + 1 - argmax_h) *
146 | im_data[argmax_h_low * data_width + argmax_w_high];
147 | weight += -1 * (argmax_h - argmax_h_low) *
148 | im_data[argmax_h_high * data_width + argmax_w_low];
149 | weight += (argmax_h - argmax_h_low) *
150 | im_data[argmax_h_high * data_width + argmax_w_high];
151 | }
152 |
153 | return weight;
154 | }
155 |
156 | template <typename DType>
157 | __global__ void deformable_im2col_gpu_kernel(
158 | const int n, const DType *data_im, const DType *data_offset,
159 | const int height, const int width, const int kernel_h, const int kernel_w,
160 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
161 | const int dilation_h, const int dilation_w,
162 | const int channel_per_deformable_group, const int height_col,
163 | const int width_col, DType *data_col) {
164 | CUDA_KERNEL_LOOP(index, n) {
165 | // index index of output matrix
166 | const int w_col = index % width_col;
167 | const int h_col = (index / width_col) % height_col;
168 | const int c_im = (index / width_col) / height_col;
169 | const int c_col = c_im * kernel_h * kernel_w;
170 |
171 | // compute deformable group index
172 | const int deformable_group_index = c_im / channel_per_deformable_group;
173 |
174 | const int h_in = h_col * stride_h - pad_h;
175 | const int w_in = w_col * stride_w - pad_w;
176 | DType *data_col_ptr =
177 | data_col + (c_col * height_col + h_col) * width_col + w_col;
178 | const DType *data_im_ptr = data_im + (c_im * height + h_in) * width + w_in;
179 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 *
180 | kernel_h * kernel_w *
181 | height_col * width_col;
182 |
183 | for (int i = 0; i < kernel_h; ++i) {
184 | for (int j = 0; j < kernel_w; ++j) {
185 | const int data_offset_h_ptr =
186 | ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
187 | const int data_offset_w_ptr =
188 | ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
189 | w_col;
190 | const DType offset_h = data_offset_ptr[data_offset_h_ptr];
191 | const DType offset_w = data_offset_ptr[data_offset_w_ptr];
192 |         DType val = static_cast<DType>(0);
193 | const DType h_im = h_in + i * dilation_h + offset_h;
194 | const DType w_im = w_in + j * dilation_w + offset_w;
195 | if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
196 | const DType map_h = i * dilation_h + offset_h;
197 | const DType map_w = j * dilation_w + offset_w;
198 | const int cur_height = height - h_in;
199 | const int cur_width = width - w_in;
200 | val = deformable_im2col_bilinear(data_im_ptr, width, cur_height,
201 | cur_width, map_h, map_w);
202 | }
203 | *data_col_ptr = val;
204 | data_col_ptr += height_col * width_col;
205 | }
206 | }
207 | }
208 | }
209 |
210 | template <typename DType>
211 | void deformable_im2col(cudaStream_t stream, const DType *data_im,
212 | const DType *data_offset, const int channels,
213 | const int height, const int width, const int ksize_h,
214 | const int ksize_w, const int pad_h, const int pad_w,
215 | const int stride_h, const int stride_w,
216 | const int dilation_h, const int dilation_w,
217 | const int deformable_group, DType *data_col) {
218 | // We are going to launch channels * height_col * width_col kernels, each
219 | // kernel responsible for copying a single-channel grid.
220 | int height_col =
221 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
222 | int width_col =
223 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
224 | int num_kernels = channels * height_col * width_col;
225 | int channel_per_deformable_group = channels / deformable_group;
226 | // Launch
227 |   deformable_im2col_gpu_kernel<DType><<<GET_BLOCKS(num_kernels),
228 |                                         CUDA_NUM_THREADS, 0, stream>>>(
229 | num_kernels, data_im, data_offset, height, width, ksize_h, ksize_w, pad_h,
230 | pad_w, stride_h, stride_w, dilation_h, dilation_w,
231 | channel_per_deformable_group, height_col, width_col, data_col);
232 |
233 | cudaError_t err = cudaGetLastError();
234 | if (err != cudaSuccess) {
235 | printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
236 | // TODO(BZ) panic
237 | }
238 | }
239 |
240 | template void deformable_im2col(
241 | cudaStream_t stream, const float *data_im, const float *data_offset,
242 | const int channels, const int height, const int width, const int ksize_h,
243 | const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
244 | const int stride_w, const int dilation_h, const int dilation_w,
245 | const int deformable_group, float *data_col);
246 |
247 | template <typename DType>
248 | __global__ void deformable_col2im_gpu_kernel(
249 | const int n, const DType *data_col, const DType *data_offset,
250 | const int channels, const int height, const int width, const int kernel_h,
251 | const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
252 | const int stride_w, const int dilation_h, const int dilation_w,
253 | const int channel_per_deformable_group, const int height_col,
254 | const int width_col, DType *grad_im) {
255 | CUDA_KERNEL_LOOP(index, n) {
256 | const int j = (index / width_col / height_col) % kernel_w;
257 | const int i = (index / width_col / height_col / kernel_w) % kernel_h;
258 | const int c = index / width_col / height_col / kernel_w / kernel_h;
259 | // compute the start and end of the output
260 |
261 | const int deformable_group_index = c / channel_per_deformable_group;
262 |
263 | int w_out = index % width_col;
264 | int h_out = (index / width_col) % height_col;
265 | int w_in = w_out * stride_w - pad_w;
266 | int h_in = h_out * stride_h - pad_h;
267 |
268 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 *
269 | kernel_h * kernel_w *
270 | height_col * width_col;
271 | const int data_offset_h_ptr =
272 | ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
273 | const int data_offset_w_ptr =
274 | ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
275 | const DType offset_h = data_offset_ptr[data_offset_h_ptr];
276 | const DType offset_w = data_offset_ptr[data_offset_w_ptr];
277 | const DType cur_inv_h_data = h_in + i * dilation_h + offset_h;
278 | const DType cur_inv_w_data = w_in + j * dilation_w + offset_w;
279 |
280 | const DType cur_top_grad = data_col[index];
281 | const int cur_h = (int)cur_inv_h_data;
282 | const int cur_w = (int)cur_inv_w_data;
283 | for (int dy = -2; dy <= 2; dy++) {
284 | for (int dx = -2; dx <= 2; dx++) {
285 | if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
286 | cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
287 | abs(cur_inv_w_data - (cur_w + dx)) < 1) {
288 | int cur_bottom_grad_pos =
289 | (c * height + cur_h + dy) * width + cur_w + dx;
290 | DType weight =
291 | get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
292 | cur_w + dx, height, width);
293 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
294 | }
295 | }
296 | }
297 | }
298 | }
299 |
300 | template <typename DType>
301 | void deformable_col2im(cudaStream_t stream, const DType *data_col,
302 | const DType *data_offset, const int channels,
303 | const int height, const int width, const int ksize_h,
304 | const int ksize_w, const int pad_h, const int pad_w,
305 | const int stride_h, const int stride_w,
306 | const int dilation_h, const int dilation_w,
307 | const int deformable_group, DType *grad_im) {
308 |
309 | int height_col =
310 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
311 | int width_col =
312 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
313 | int num_kernels = channels * ksize_h * ksize_w * height_col * width_col;
314 | int channel_per_deformable_group = channels / deformable_group;
315 |   // one thread per column element; the scatter into grad_im uses
316 |   // atomicAdd, since several column entries can map to the same pixel
317 |   deformable_col2im_gpu_kernel<DType><<<GET_BLOCKS(num_kernels),
318 |                                         CUDA_NUM_THREADS, 0, stream>>>(
319 | num_kernels, data_col, data_offset, channels, height, width, ksize_h,
320 | ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
321 | channel_per_deformable_group, height_col, width_col, grad_im);
322 |
323 | cudaError_t err = cudaGetLastError();
324 | if (err != cudaSuccess) {
325 | printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
326 | // TODO(BZ) panic
327 | }
328 | }
329 |
330 | template void deformable_col2im(
331 | cudaStream_t stream, const float *data_col, const float *data_offset,
332 | const int channels, const int height, const int width, const int ksize_h,
333 | const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
334 | const int stride_w, const int dilation_h, const int dilation_w,
335 | const int deformable_group, float *grad_im);
336 |
337 | template <typename DType>
338 | __global__ void deformable_col2im_coord_gpu_kernel(
339 | const int n, const DType *data_col, const DType *data_im,
340 | const DType *data_offset, const int channels, const int height,
341 | const int width, const int kernel_h, const int kernel_w, const int pad_h,
342 | const int pad_w, const int stride_h, const int stride_w,
343 | const int dilation_h, const int dilation_w,
344 | const int channel_per_deformable_group, const int height_col,
345 | const int width_col, DType *grad_offset) {
346 | CUDA_KERNEL_LOOP(index, n) {
347 | DType val = 0;
348 | int w = index % width_col;
349 | int h = (index / width_col) % height_col;
350 | int c = index / width_col / height_col;
351 | // compute the start and end of the output
352 |
353 | const int deformable_group_index = c / (2 * kernel_h * kernel_w);
354 | const int col_step = kernel_h * kernel_w;
355 | int cnt = 0;
356 | const DType *data_col_ptr = data_col + deformable_group_index *
357 | channel_per_deformable_group *
358 | width_col * height_col;
359 | const DType *data_im_ptr =
360 | data_im + deformable_group_index * channel_per_deformable_group /
361 | kernel_h / kernel_w * height * width;
362 | const DType *data_offset_ptr = data_offset + deformable_group_index * 2 *
363 | kernel_h * kernel_w *
364 | height_col * width_col;
365 |
366 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
367 |
368 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
369 | col_c += col_step) {
370 | const int col_pos = ((col_c * height_col) + h) * width_col + w;
371 | const int bp_dir = offset_c % 2;
372 |
373 | int j = (col_pos / width_col / height_col) % kernel_w;
374 | int i = (col_pos / width_col / height_col / kernel_w) % kernel_h;
375 | int w_out = col_pos % width_col;
376 | int h_out = (col_pos / width_col) % height_col;
377 | int w_in = w_out * stride_w - pad_w;
378 | int h_in = h_out * stride_h - pad_h;
379 | const int data_offset_h_ptr =
380 | (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
381 | const int data_offset_w_ptr =
382 | (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
383 | w_out);
384 | const DType offset_h = data_offset_ptr[data_offset_h_ptr];
385 | const DType offset_w = data_offset_ptr[data_offset_w_ptr];
386 | DType inv_h = h_in + i * dilation_h + offset_h;
387 | DType inv_w = w_in + j * dilation_w + offset_w;
388 | if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) {
389 | inv_h = inv_w = -1;
390 | }
391 | const DType weight = get_coordinate_weight(
392 | inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
393 | width, bp_dir);
394 | val += weight * data_col_ptr[col_pos];
395 | cnt += 1;
396 | }
397 |
398 | grad_offset[index] = val;
399 | }
400 | }
401 |
402 | template <typename DType>
403 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col,
404 | const DType *data_im, const DType *data_offset,
405 | const int channels, const int height,
406 | const int width, const int ksize_h,
407 | const int ksize_w, const int pad_h,
408 | const int pad_w, const int stride_h,
409 | const int stride_w, const int dilation_h,
410 | const int dilation_w, const int deformable_group,
411 | DType *grad_offset) {
412 |
413 | int height_col =
414 | (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
415 | int width_col =
416 | (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
417 | int num_kernels =
418 | height_col * width_col * 2 * ksize_h * ksize_w * deformable_group;
419 | int channel_per_deformable_group =
420 | channels * ksize_h * ksize_w / deformable_group;
421 |   // one thread per offset value; each thread accumulates and writes its
422 |   // own grad_offset entry, so no atomic operations are needed here
423 |   deformable_col2im_coord_gpu_kernel<DType><<<GET_BLOCKS(num_kernels),
424 |                                              CUDA_NUM_THREADS, 0, stream>>>(
425 | num_kernels, data_col, data_im, data_offset, channels, height, width,
426 | ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
427 | dilation_w, channel_per_deformable_group, height_col, width_col,
428 | grad_offset);
429 |
430 | cudaError_t err = cudaGetLastError();
431 | if (err != cudaSuccess) {
432 |     printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
433 | // TODO(BZ) panic
434 | }
435 | }
436 |
437 | template void
438 | deformable_col2im_coord(cudaStream_t stream, const float *data_col,
439 | const float *data_im, const float *data_offset,
440 | const int channels, const int height, const int width,
441 | const int ksize_h, const int ksize_w, const int pad_h,
442 | const int pad_w, const int stride_h, const int stride_w,
443 | const int dilation_h, const int dilation_w,
444 | const int deformable_group, float *grad_offset);
445 |
--------------------------------------------------------------------------------
/src/deform_conv_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | template <typename DType>
2 | void deformable_im2col(cudaStream_t stream, const DType *data_im,
3 | const DType *data_offset, const int channels,
4 | const int height, const int width, const int ksize_h,
5 | const int ksize_w, const int pad_h, const int pad_w,
6 | const int stride_h, const int stride_w,
7 | const int dilation_h, const int dilation_w,
8 | const int deformable_group, DType *data_col);
9 |
10 | template <typename DType>
11 | void deformable_col2im(cudaStream_t stream, const DType *data_col,
12 | const DType *data_offset, const int channels,
13 | const int height, const int width, const int ksize_h,
14 | const int ksize_w, const int pad_h, const int pad_w,
15 | const int stride_h, const int stride_w,
16 | const int dilation_h, const int dilation_w,
17 | const int deformable_group, DType *grad_im);
18 |
19 | template <typename DType>
20 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col,
21 | const DType *data_im, const DType *data_offset,
22 | const int channels, const int height,
23 | const int width, const int ksize_h,
24 | const int ksize_w, const int pad_h,
25 | const int pad_w, const int stride_h,
26 | const int stride_w, const int dilation_h,
27 | const int dilation_w, const int deformable_group,
28 | DType *grad_offset);
29 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | from modules import ConvOffset2d
7 |
8 | num_deformable_groups = 2
9 |
10 | N, inC, inH, inW = 1, 6, 512, 512
11 | outC, outH, outW = 4, 512, 512
12 | kH, kW = 3, 3
13 |
14 | conv = nn.Conv2d(
15 | inC,
16 | num_deformable_groups * 2 * kH * kW,
17 | kernel_size=(kH, kW),
18 | stride=(1, 1),
19 | padding=(1, 1),
20 | bias=False).cuda()
21 |
22 | conv_offset2d = ConvOffset2d(
23 | inC,
24 | outC, (kH, kW),
25 | stride=1,
26 | padding=1,
27 | num_deformable_groups=num_deformable_groups).cuda()
28 |
29 | inputs = Variable(torch.randn(N, inC, inH, inW).cuda())
30 | offset = conv(inputs)
31 | output = conv_offset2d(inputs, offset)
32 | output.backward(output.data)
33 | print(output.size())
34 |
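35 | # Sanity check (a sketch): with an all-zero offset field, deformable
36 | # convolution reduces to a regular convolution with the same weight,
37 | # stride and padding, so the difference below should be near zero.
38 | zero_offset = Variable(torch.zeros(N, num_deformable_groups * 2 * kH * kW,
39 |                                    inH, inW).cuda())
40 | reference = F.conv2d(inputs, conv_offset2d.weight, stride=1, padding=1)
41 | print((conv_offset2d(inputs, zero_offset) - reference).data.abs().max())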
--------------------------------------------------------------------------------