├── .gitignore
├── LICENSE
├── README.md
├── deformable_kernels
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── cond_conv.py
│   │   ├── deform_conv.py
│   │   └── deform_kernel.py
│   └── ops
│       ├── __init__.py
│       └── deform_kernel
│           ├── __init__.py
│           ├── csrc
│           │   ├── filter_sample_depthwise_cuda.cpp
│           │   ├── filter_sample_depthwise_cuda.h
│           │   ├── filter_sample_depthwise_cuda_kernel.cu
│           │   ├── nd_linear_sample_cuda.cpp
│           │   ├── nd_linear_sample_cuda.h
│           │   └── nd_linear_sample_cuda_kernel.cu
│           ├── functions
│           │   ├── __init__.py
│           │   ├── filter_sample_depthwise.py
│           │   └── nd_linear_sample.py
│           ├── modules
│           │   ├── __init__.py
│           │   └── filter_sample_depthwise.py
│           └── setup.py
└── environment.yml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Global gitignore file
2 | # Adapted from https://github.com/cowboy/dotfiles/blob/master/link/.gitignore_global
3 | # Adapted from https://github.com/wookayin/dotfiles/blob/master/git/gitignore#L14
4 |
5 | # Direnv stuffs #
6 | .direnv
7 | .envrc
8 |
9 | # Compiled source #
10 | *.class
11 | *.dll
12 | *.exe
13 | *.o
14 | *.so
15 | *.pyc
16 | **/__pycache__/
17 |
18 | # Packages #
19 | # It's better to unpack these files and commit the raw source because
20 | # git has its own built in compression methods.
21 | *.7z
22 | *.jar
23 | *.rar
24 | *.zip
25 | *.gz
26 | *.bzip
27 | *.xz
28 | *.lzma
29 |
30 | # packing-only formats
31 | *.iso
32 | *.tar
33 |
34 | # package management formats
35 | *.dmg
36 | *.xpi
37 | *.gem
38 | *.egg
39 | *.egg-info
40 | *.deb
41 | *.rpm
42 |
43 | # Logs and databases #
44 | *.log
45 | *.sqlite
46 |
47 | # OS generated files #
48 | .DS_Store
49 | .Spotlight-V100
50 | .Trashes
51 | ._*
52 |
53 | # Linux
54 | .fuse_hidden*
55 | .nfs*
56 |
57 | # Windows image file caches
58 | Thumbs.db
59 |
60 | # Folder config file
61 | Desktop.ini
62 |
63 | # Vim
64 | .*.s[a-w][a-z]
65 |
66 | # IDE stuffs
67 | .idea/
68 | *.iml
69 | .project
70 | .classpath
71 | .settings/
72 | .ipynb_checkpoints/
73 |
74 | # Tex
75 | *.aux
76 | *.bbl
77 | *.blg
78 | *.brf
79 | *.fdb_latexmk
80 | *.fls
81 | *.log
82 | *.out
83 | *.pdf
84 | *.synctex.gz
85 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hang Gao, Xizhou Zhu, Steve Lin, Jifeng Dai.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deformable Kernels [[ICLR 2020]](https://arxiv.org/abs/1910.02940) [[Website]](https://people.eecs.berkeley.edu/~hangg/deformable-kernels/)
2 |
3 |
4 |
5 | **Deformable Kernels: Adapting Effective Receptive Fields for Object Deformation**
6 | [Hang Gao*](http://people.eecs.berkeley.edu/~hangg/), [Xizhou
7 | Zhu*](https://scholar.google.com/citations?user=02RXI00AAAAJ&hl=en&oi=ao), [Steve Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en&oi=ao), [Jifeng Dai](https://jifengdai.org/).
8 | In [ICLR, 2020](https://arxiv.org/abs/1910.02940).
9 |
10 | This repository contains the official implementation of Deformable Kernels.
11 |
12 | **Table of contents**
13 | 1. [Customized operators for deformable kernels, along with their variants.](#1-customized-operators)
14 | 2. [Instructions to use our operators.](#2-quickstart)
15 | 3. [Results on ImageNet & COCO benchmarks, with pretrained models for
16 | reproduction.](#3-results--pretrained-models)
17 | 4. [Training and evaluation code.](#4-training--evaluation-code)
18 |
19 | ## (0) Getting started
20 |
21 | ### PyTorch
22 | - Get [CUDA 10.1](https://developer.nvidia.com/cuda-10.1-download-archive-base)
23 | installed on your machine.
24 | - Install PyTorch ([pytorch.org](http://pytorch.org)).
25 | - `conda env create -f environment.yml`.
26 |
27 | ### Apex
28 | - Install [Apex](https://github.com/NVIDIA/apex/) from its official repo. This
29 |   requires CUDA 10.1 to work with the latest PyTorch version (we tested
30 |   against `pytorch=1.3.1`). It is used for fast mixed-precision inference
31 |   and should work out of the box.
32 |
33 | ### Compile our operators
34 | ```bash
35 | # assume at project root
36 | (
37 | cd deformable_kernels/ops/deform_kernel;
38 | pip install -e .;
39 | )
40 | ```
41 |
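To verify the build, the Python wrappers should be importable from the project root. A minimal sanity check (just a sketch; the names below come from `deformable_kernels/modules/__init__.py`):

```python
# Importing the modules also pulls in apex and the compiled CUDA extension,
# so a clean import is a good smoke test for the whole setup.
from deformable_kernels.modules import (
    CondConv2d,
    DeformConv2d,
    DeformKernel2d,
    GlobalDeformKernel2d,
)
print("deformable_kernels operators compiled and importable")
```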
42 |
43 | ## (1) Customized operators
44 |
45 |
46 |
47 | This repo includes all deformable kernel variants described in our paper, namely:
48 |
49 | - Global Deformable Kernels;
50 | - Local Deformable Kernels;
51 | - Local Deformable Kernels integrating with Deformable Convolutions;
52 |
53 | Instead of learning offsets on image space, we propose to deform and resample
54 | on kernel space. This enables powerful dynamic inference capacity. For more
55 | technical details, please refer to their
56 | [definitions](deformable_kernels/modules/deform_kernel.py).
57 |
58 | We also provide implementations of the methods we compare against, namely:
59 |
60 | - [Deformable Convolutions](https://arxiv.org/abs/1703.06211);
61 | - [Soft Conditional Computation](https://arxiv.org/abs/1904.04971);
62 |
63 | Please refer to their module definitions under the `deformable_kernels/modules` folder; a construction sketch follows below.
64 |
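For reference, both baselines follow the same depthwise convention as our deformable kernels. A construction sketch (illustrative only, with `inplanes = 64` standing in for your channel count; signatures follow the module definitions in this repo):

```python
from deformable_kernels.modules import CondConv2d, DeformConv2d

inplanes = 64  # example channel count

# Soft Conditional Computation: 8 experts mixed by a per-image gate,
# applied as a 3x3 depthwise convolution.
cond = CondConv2d(8, inplanes, inplanes, 3, stride=1, padding=1,
                  groups=inplanes, bias=False)

# Depthwise deformable convolution (image-space offsets only).
dcn = DeformConv2d(inplanes, inplanes, 3, stride=1, padding=1, groups=inplanes)
```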
65 |
66 | ## (2) Quickstart
67 | The following snippet constructs the deformable kernels used in our experiments:
68 |
69 | ```python
70 | from deformable_kernels.modules import (
71 | GlobalDeformKernel2d,
72 | DeformKernel2d,
73 | DeformKernelConv2d,
74 | )
75 |
76 | # global DK with scope size 2, kernel size 1, stride 1, padding 0, depthwise convolution.
77 | gdk = GlobalDeformKernel2d((2, 2), [inplanes], [inplanes], groups=[inplanes])
78 | # (local) DK with scope size 4, kernel size 3, stride 1, padding 1, depthwise convolution.
79 | dk = DeformKernel2d((4, 4), [inplanes], [inplanes], 3, 1, 1, groups=[inplanes])
80 | # (local) DK integrating with dcn, with kernel & image offsets separately learnt.
81 | dkc = DeformKernelConv2d((4, 4), [inplanes], [inplanes], 3, 1, 1, groups=[inplanes])
82 | ```
83 |
84 | Note that all of our customized operators currently support only depthwise
85 | convolutions, mainly because efficiently resampling kernels at runtime is
86 | extremely slow if we compute over each channel independently. We are trying to
87 | relax this requirement by iterating on our CUDA implementation. Any
88 | contributions are welcome!
89 |
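As a concrete illustration of the constraint above, a DK can stand in for the depthwise stage of a separable block as long as `groups` equals the channel count. A sketch (assuming the operators are compiled and a GPU is available):

```python
import torch
from torch import nn
from deformable_kernels.modules import DeformKernel2d

inplanes = 64

# Depthwise-separable block with a (local) DK as the depthwise stage.
block = nn.Sequential(
    nn.Conv2d(inplanes, inplanes, 1, bias=False),   # pointwise
    DeformKernel2d((4, 4), inplanes, inplanes, 3, 1, 1, groups=inplanes),
    nn.Conv2d(inplanes, inplanes, 1, bias=False),   # pointwise
).cuda()

out = block(torch.randn(2, inplanes, 56, 56, device="cuda"))
print(out.shape)  # torch.Size([2, 64, 56, 56])
```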
90 |
91 | ## (3) Results & pretrained models
92 | Under construction.
93 |
94 |
95 | ## (4) Training & evaluation code
96 | Under construction.
97 |
98 |
99 | ## (A) License
100 | This project is released under the [MIT license](LICENSE).
101 |
102 |
103 | ## (B) Citation & Contact
104 | If you find this repo useful for your research, please consider citing this
105 | bibtex:
106 |
107 | ```tex
108 | @article{gao2019deformable,
109 | title={Deformable Kernels: Adapting Effective Receptive Fields for Object Deformation},
110 | author={Gao, Hang and Zhu, Xizhou and Lin, Steve and Dai, Jifeng},
111 | journal={arXiv preprint arXiv:1910.02940},
112 | year={2019}
113 | }
114 | ```
115 |
116 | Please contact Hang Gao `` and Xizhou Zhu
117 | `` with any comments or feedback.
118 |
--------------------------------------------------------------------------------
/deformable_kernels/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/15/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
--------------------------------------------------------------------------------
/deformable_kernels/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/26/2019
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | from .cond_conv import CondConv2d
11 | from .deform_conv import DeformConv2d
12 | from .deform_kernel import (
13 | GlobalDeformKernel2d,
14 | LocalDeformKernel2d,
15 | DeformKernel2d,
16 | DeformKernelConv2d,
17 | )
18 |
19 | __all__ = [
20 | 'CondConv2d',
21 | 'DeformConv2d',
22 | 'GlobalDeformKernel2d',
23 | 'LocalDeformKernel2d',
24 | 'DeformKernel2d',
25 | 'DeformKernelConv2d',
26 | ]
27 |
--------------------------------------------------------------------------------
/deformable_kernels/modules/cond_conv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : cond_conv.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/25/2019
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import torch
11 |
12 | from apex import amp
13 | from torch import nn
14 |
15 |
16 | class CondConv2d(nn.Module):
17 |
18 | def __init__(self, num_experts, in_channels, out_channels, kernel_size,
19 | stride=1, padding=0, dilation=1, groups=1, bias=True,
20 | padding_mode='zeros'):
21 | super().__init__()
22 | self.num_experts = num_experts
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.kernel_size = kernel_size
26 | self.stride = stride
27 |         self.padding = padding
28 |         self.dilation = dilation
28 | self.groups = groups
29 | self.padding_mode = padding_mode
30 | assert not bias
31 |
32 | self.weight = nn.Parameter(
33 |             torch.Tensor(
34 | num_experts * out_channels,
35 | in_channels // self.groups,
36 | kernel_size,
37 | kernel_size,
38 | )
39 | )
40 | self.fc = nn.Linear(in_channels, num_experts)
41 | self.fc.zero_init = True
42 |
43 | @amp.float_function
44 |     def dynamic_inference(self, x, weight):
45 | # TODO(Hang Gao @ 12/26): make sure passing weight to amp is necessary.
46 | n = x.shape[0]
47 |
48 | avg_x = x.mean((2, 3))
49 | gate_x = torch.sigmoid(self.fc(avg_x))
50 |
51 | weight = torch.mm(
52 | gate_x,
53 | self.weight.reshape(self.num_experts, -1)
54 | ).reshape(
55 | n * self.out_channels,
56 | self.in_channels // self.groups,
57 | self.kernel_size,
58 | self.kernel_size,
59 | )
60 | return weight
61 |
62 | def forward(self, x):
63 | n, _, h, w = x.shape
64 |         weight = self.dynamic_inference(x, self.weight)
65 |
66 | out = nn.functional.conv2d(
67 | x.reshape(1, n * self.in_channels, h, w),
68 | weight,
69 | stride=self.stride,
70 | padding=self.padding,
71 | dilation=self.dilation,
72 |             groups=n * self.groups,
73 |         )
75 | out = out.reshape(n, self.out_channels, *out.shape[2:])
76 | return out
77 |
78 | def extra_repr(self):
79 |         s = ('{num_experts}, {in_channels}, {out_channels}'
80 |              ', kernel_size={kernel_size}, stride={stride}')
81 |         if self.padding != 0:
82 |             s += ', padding={padding}'
83 |         if self.dilation != 1:
84 |             s += ', dilation={dilation}'
85 |         if self.groups != 1:
86 |             s += ', groups={groups}'
87 |         return s.format(**self.__dict__)
89 |
--------------------------------------------------------------------------------
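
A note on `cond_conv.py` above: the forward pass folds the batch dimension into the channel dimension so that each sample is convolved with its own expert-mixed weight through a single grouped convolution. A standalone sketch of that data flow in plain PyTorch (illustrative shapes only, not part of the repo):

```python
import torch
import torch.nn.functional as F

n, c_in, c_out, k, e = 2, 8, 8, 3, 4              # batch, channels, kernel, experts
x = torch.randn(n, c_in, 16, 16)
experts = torch.randn(e, c_out * c_in * k * k)    # flattened expert filters

gate = torch.sigmoid(torch.randn(n, e))           # stands in for the gating fc
weight = (gate @ experts).reshape(n * c_out, c_in, k, k)

# Fold the batch into channels and set groups=n so each sample sees its own weight.
out = F.conv2d(x.reshape(1, n * c_in, 16, 16), weight, padding=1, groups=n)
out = out.reshape(n, c_out, 16, 16)
print(out.shape)  # torch.Size([2, 8, 16, 16])
```
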
/deformable_kernels/modules/deform_conv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : deform_conv.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | from .deform_kernel import DeformKernelConv2d
11 |
12 | __all__ = ['DeformConv2d']
13 |
14 |
15 | class DeformConv2d(DeformKernelConv2d):
16 | """
17 | Depthwise deformable convolution.
18 | """
19 | def __init__(
20 | self,
21 | in_planes,
22 | out_planes,
23 | kernel_size=1,
24 | stride=1,
25 | padding=0,
26 | dilation=1,
27 | groups=1,
28 | bias=False,
29 | offset_clip=None,
30 | ):
31 | super().__init__(
32 | (kernel_size, kernel_size),
33 | in_planes,
34 | out_planes,
35 | kernel_size,
36 | stride,
37 | padding,
38 | dilation,
39 | groups,
40 | bias,
41 | offset_clip,
42 | )
43 |
--------------------------------------------------------------------------------
/deformable_kernels/modules/deform_kernel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : deform_kernel.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import torch
11 | from torch import nn
12 | from apex import amp
13 |
14 | from ..ops.deform_kernel.functions import nd_linear_sample
15 | from ..ops.deform_kernel.modules import (
16 | SampleDepthwise,
17 | DeformableSampleDepthwise,
18 | )
19 |
20 | __all__ = [
21 | 'GlobalDeformKernel2d',
22 | 'LocalDeformKernel2d',
23 | 'DeformKernel2d',
24 | 'DeformKernelConv2d',
25 | ]
26 |
27 |
28 | class GlobalDeformKernel2d(nn.Module):
29 |
30 | def __init__(
31 | self,
32 | weight_shape,
33 | in_planes,
34 | out_planes,
35 | kernel_size=1,
36 | stride=1,
37 | padding=0,
38 | dilation=1,
39 | groups=1,
40 | bias=False,
41 | ):
42 | super().__init__()
43 | self.kernel_size = kernel_size
44 | self.weight_shape = weight_shape
45 | self.weight_dilate = 1
46 | self.stride = stride
47 | self.padding = padding
48 | self.dilation = dilation
49 | self.out_planes = out_planes
50 | self.in_planes = in_planes
51 | self.group = groups
52 | assert not bias
53 |
54 | self.weight = nn.Parameter(
55 | torch.Tensor(out_planes, in_planes // self.group, *self.weight_shape)
56 | )
57 | self.fc = nn.Linear(
58 | in_planes, kernel_size * kernel_size * len(self.weight_shape)
59 | )
60 | self.fc.zero_init = True
61 |
62 | assert len(self.weight_shape) >= 2
63 |
64 | start_h = (weight_shape[0] - (kernel_size - 1) * self.weight_dilate - 1) / 2.0
65 | start_w = (weight_shape[1] - (kernel_size - 1) * self.weight_dilate - 1) / 2.0
66 | self.fc_bias = []
67 | for h in range(kernel_size):
68 | for w in range(kernel_size):
69 | self.fc_bias += [
70 | start_h + h * self.weight_dilate,
71 | start_w + w * self.weight_dilate,
72 | ]
73 | for i in range(len(self.weight_shape) - 2):
74 | self.fc_bias += [(self.weight_shape[i + 2] - 1) / 2.0]
75 |
76 | @amp.float_function
77 | def dynamic_weight(self, x, weight):
78 | n, c, h, w = x.shape
79 | avg_x = x.view(n, c, -1).mean(2)
80 | coord = self.fc(avg_x) * self.weight_dilate + torch.tensor(
81 | self.fc_bias, dtype=x.dtype, device=x.device
82 | ).unsqueeze(0)
83 | coord = torch.clamp(coord, 0, self.weight_shape[0] - 1)
84 |
85 | weight = weight.view(
86 | self.out_planes * self.in_planes // self.group, *self.weight_shape
87 | )
88 | coord = coord.view(
89 | n * self.kernel_size * self.kernel_size, len(self.weight_shape)
90 | )
91 |
92 | weight_sample = nd_linear_sample(weight, coord).view(
93 | n,
94 | self.kernel_size * self.kernel_size,
95 | self.out_planes * self.in_planes // self.group,
96 | )
97 | weight = weight_sample.transpose(1, 2).reshape(
98 | n * self.out_planes,
99 | self.in_planes // self.group,
100 | self.kernel_size,
101 | self.kernel_size,
102 | )
103 | return weight
104 |
105 | def forward(self, x):
106 | n, c, h, w = x.shape
107 | weight = self.dynamic_weight(x, self.weight)
108 |
109 | out = nn.functional.conv2d(
110 | x.view(1, n * c, h, w),
111 | weight,
112 | stride=self.stride,
113 | padding=self.padding,
114 | dilation=self.dilation,
115 | groups=n * self.group,
116 | )
117 | out = out.view(n, self.out_planes, out.shape[2], out.shape[3])
118 | return out
119 |
120 | def extra_repr(self):
121 | s = (
122 | "{in_planes}, {out_planes}, weight_shape={weight_shape}, "
123 | "kernel_size={kernel_size}, stride={stride}, "
124 | "weight_dilate={weight_dilate}"
125 | )
126 | if self.padding != 0:
127 | s += ", padding={padding}"
128 | if self.dilation != 1:
129 | s += ", dilation={dilation}"
130 | if self.group != 1:
131 | s += ", group={group}"
132 | return s.format(**self.__dict__)
133 |
134 |
135 | class LocalDeformKernel2d(nn.Module):
136 |
137 | def __init__(
138 | self,
139 | weight_shape,
140 | in_planes,
141 | out_planes,
142 | kernel_size=1,
143 | stride=1,
144 | padding=0,
145 | dilation=1,
146 | groups=1,
147 | rotation_groups=1,
148 | bias=False,
149 | rotation_clip=None,
150 | ):
151 | super().__init__()
152 | self.kernel_size = kernel_size
153 | self.weight_shape = weight_shape
154 | self.weight_dilate = 1
155 | self.stride = stride
156 | self.padding = padding
157 | self.dilation = dilation
158 | self.out_planes = out_planes
159 | self.in_planes = in_planes
160 | self.rotation_groups = rotation_groups
161 | self.group = groups
162 | assert not bias
163 | assert len(self.weight_shape) >= 2
164 |
165 | self.rotation_conv = nn.Conv2d(
166 | in_planes, rotation_groups * kernel_size * kernel_size * 2,
167 | kernel_size, stride, padding, dilation, bias=True
168 | )
169 | self.rotation_conv.zero_init = True
170 | self.rotation_clip = rotation_clip
171 |
172 | self.inner_conv = SampleDepthwise(
173 | weight_shape,
174 | in_planes,
175 | out_planes,
176 | kernel_size=kernel_size,
177 | stride=stride,
178 | padding=padding,
179 | dilation=dilation,
180 | groups=groups,
181 | rotation_groups=rotation_groups,
182 | bias=bias,
183 | )
184 |
185 | def _clip_rotation(self, rotation):
186 |         if isinstance(self.rotation_clip, dict):
187 | return rotation.clamp(**self.rotation_clip)
188 | elif self.rotation_clip == 'scope':
189 | if not hasattr(self, 'fc_bias'):
190 | start_h = (self.weight_shape[0] - (self.kernel_size - 1) *
191 | self.weight_dilate - 1) / 2.0
192 | start_w = (self.weight_shape[1] - (self.kernel_size - 1) *
193 | self.weight_dilate - 1) / 2.0
194 | fc_bias = []
195 | for h in range(self.kernel_size):
196 | for w in range(self.kernel_size):
197 | fc_bias += [
198 | start_h + h * self.weight_dilate,
199 | start_w + w * self.weight_dilate,
200 | ]
201 | for i in range(len(self.weight_shape) - 2):
202 | fc_bias += [(self.weight_shape[i + 2] - 1) / 2]
203 | self.fc_bias = rotation.new_tensor(fc_bias) \
204 | .repeat(self.rotation_groups)[None, :, None, None]
205 | coord = (rotation * self.weight_dilate + self.fc_bias).clamp(
206 | 0, self.weight_shape[0] - 1)
207 | return (coord - self.fc_bias) / self.weight_dilate
208 | else:
209 | raise NotImplementedError(
210 | f'Expect rotation_clip to be tuple or "scope", '
211 | f'but get {self.rotation_clip}'
212 | )
213 |
214 | def forward(self, x):
215 | rotation = self.rotation_conv(x)
216 | if self.rotation_clip is not None:
217 | rotation = self._clip_rotation(rotation)
218 | rotation *= self.weight_dilate
219 | out = self.inner_conv(x, rotation)
220 |
221 | return out
222 |
223 |
224 | # refer to local deformable kernel as the default.
225 | DeformKernel2d = LocalDeformKernel2d
226 |
227 |
228 | class DeformKernelConv2d(nn.Module):
229 |
230 | def __init__(
231 | self,
232 | weight_shape,
233 | in_planes,
234 | out_planes,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | dilation=1,
239 | groups=1,
240 | bias=False,
241 | offset_clip=None,
242 | ):
243 | super().__init__()
244 | self.kernel_size = kernel_size
245 | self.stride = stride
246 | self.padding = padding
247 | self.dilation = dilation
248 | self.out_planes = out_planes
249 | self.in_planes = in_planes
250 | self.group = groups
251 | assert not bias
252 |
253 | self.offset_conv = nn.Conv2d(
254 | in_planes, kernel_size * kernel_size * 2,
255 | kernel_size, stride, padding, dilation, bias=True
256 | )
257 | self.offset_conv.zero_init = True
258 | self.offset_clip = offset_clip
259 |
260 | self.inner_conv = DeformableSampleDepthwise(
261 | weight_shape,
262 | in_planes,
263 | out_planes,
264 | kernel_size=kernel_size,
265 | stride=stride,
266 | padding=padding,
267 | dilation=dilation,
268 | groups=groups,
269 | bias=bias,
270 | )
271 |
272 | def forward(self, x):
273 | offset = self.offset_conv(x)
274 | if self.offset_clip is not None:
275 | offset = offset.clamp(**self.offset_clip)
276 | offset *= self.dilation
277 |
278 | rotation = None
279 | out = self.inner_conv(x, offset, rotation)
280 |
281 | return out
282 |
--------------------------------------------------------------------------------
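
A note on `deform_kernel.py` above: every variant samples the kernel from a larger weight scope, and the learned rotations are applied on top of a centered base grid (the `fc_bias` computation). A minimal sketch of that grid using the same formula with example sizes (not part of the repo):

```python
def base_grid(scope, kernel_size, dilate=1):
    """Centered kernel-space sampling coordinates, as built in fc_bias."""
    start = (scope - (kernel_size - 1) * dilate - 1) / 2.0
    return [start + i * dilate for i in range(kernel_size)]

# 3x3 kernel sampled from a 4x4 scope: bilinear taps at {0.5, 1.5, 2.5} per axis.
print(base_grid(4, 3))  # [0.5, 1.5, 2.5]
# 1x1 kernel sampled from a 2x2 scope (the global DK example in the README).
print(base_grid(2, 1))  # [0.5]
```
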
/deformable_kernels/ops/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/25/2019
7 | #
8 | # Distributed under terms of the MIT license.
9 |
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <cmath>
4 | #include <vector>
5 |
6 | #include "filter_sample_depthwise_cuda.h"
7 |
8 | int sample_depthwise_forward_cuda(
9 | at::Tensor input,
10 | at::Tensor rotation,
11 | at::Tensor filter,
12 | at::Tensor output,
13 | int kH,
14 | int kW,
15 | int dH,
16 | int dW,
17 | int padH,
18 | int padW,
19 | int dilationH,
20 | int dilationW,
21 | int scopeH,
22 | int scopeW,
23 | int groupS) {
24 |
25 | input = input.contiguous();
26 | rotation = rotation.contiguous();
27 | filter = filter.contiguous();
28 |
29 | SampleDepthwiseArgs sdw_args;
30 | sdw_args.batch = input.size(0);
31 | sdw_args.channel = input.size(1);
32 | sdw_args.in_height = input.size(2);
33 | sdw_args.in_width = input.size(3);
34 | sdw_args.filter_height = kH;
35 | sdw_args.filter_width = kW;
36 | sdw_args.stride_height = dH;
37 | sdw_args.stride_width = dW;
38 | sdw_args.pad_height = padH;
39 | sdw_args.pad_width = padW;
40 | sdw_args.dilation_height = dilationH;
41 | sdw_args.dilation_width = dilationW;
42 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
43 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
44 | sdw_args.scope_height = scopeH;
45 | sdw_args.scope_width = scopeW;
46 | sdw_args.sampling_group = groupS;
47 |
48 | output = output.view({
49 | sdw_args.batch,
50 | sdw_args.channel,
51 | sdw_args.out_height,
52 | sdw_args.out_width});
53 |
54 | SampleDepthwiseConv2dForward(
55 | input,
56 | rotation,
57 | filter,
58 | sdw_args,
59 | output);
60 |
61 | return 1;
62 | }
63 |
64 | int sample_depthwise_backward_data_cuda(
65 | at::Tensor gradOutput,
66 | at::Tensor input,
67 | at::Tensor rotation,
68 | at::Tensor filter,
69 | at::Tensor gradInput,
70 | int kH,
71 | int kW,
72 | int dH,
73 | int dW,
74 | int padH,
75 | int padW,
76 | int dilationH,
77 | int dilationW,
78 | int scopeH,
79 | int scopeW,
80 | int groupS) {
81 |
82 | gradOutput = gradOutput.contiguous();
83 | rotation = rotation.contiguous();
84 | filter = filter.contiguous();
85 |
86 | SampleDepthwiseArgs sdw_args;
87 | sdw_args.batch = input.size(0);
88 | sdw_args.channel = input.size(1);
89 | sdw_args.in_height = input.size(2);
90 | sdw_args.in_width = input.size(3);
91 | sdw_args.filter_height = kH;
92 | sdw_args.filter_width = kW;
93 | sdw_args.stride_height = dH;
94 | sdw_args.stride_width = dW;
95 | sdw_args.pad_height = padH;
96 | sdw_args.pad_width = padW;
97 | sdw_args.dilation_height = dilationH;
98 | sdw_args.dilation_width = dilationW;
99 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
100 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
101 | sdw_args.scope_height = scopeH;
102 | sdw_args.scope_width = scopeW;
103 | sdw_args.sampling_group = groupS;
104 |
105 | gradInput = gradInput.view({
106 | sdw_args.batch,
107 | sdw_args.channel,
108 | sdw_args.in_height,
109 | sdw_args.in_width});
110 |
111 | SampleDepthwiseConv2dBackwardData(
112 | gradOutput,
113 | rotation,
114 | filter,
115 | sdw_args,
116 | gradInput);
117 |
118 | return 1;
119 | }
120 |
121 | int sample_depthwise_backward_filter_cuda(
122 | at::Tensor gradOutput,
123 | at::Tensor input,
124 | at::Tensor rotation,
125 | at::Tensor filter,
126 | at::Tensor gradFilter,
127 | int kH,
128 | int kW,
129 | int dH,
130 | int dW,
131 | int padH,
132 | int padW,
133 | int dilationH,
134 | int dilationW,
135 | int scopeH,
136 | int scopeW,
137 | int groupS) {
138 |
139 | gradOutput = gradOutput.contiguous();
140 | input = input.contiguous();
141 | rotation = rotation.contiguous();
142 |
143 | SampleDepthwiseArgs sdw_args;
144 | sdw_args.batch = input.size(0);
145 | sdw_args.channel = input.size(1);
146 | sdw_args.in_height = input.size(2);
147 | sdw_args.in_width = input.size(3);
148 | sdw_args.filter_height = kH;
149 | sdw_args.filter_width = kW;
150 | sdw_args.stride_height = dH;
151 | sdw_args.stride_width = dW;
152 | sdw_args.pad_height = padH;
153 | sdw_args.pad_width = padW;
154 | sdw_args.dilation_height = dilationH;
155 | sdw_args.dilation_width = dilationW;
156 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
157 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
158 | sdw_args.scope_height = scopeH;
159 | sdw_args.scope_width = scopeW;
160 | sdw_args.sampling_group = groupS;
161 |
162 | gradFilter = gradFilter.view({
163 | sdw_args.channel,
164 | 1,
165 | sdw_args.scope_height,
166 | sdw_args.scope_width});
167 |
168 | SampleDepthwiseConv2dBackwardFilter(
169 | gradOutput,
170 | input,
171 | rotation,
172 | sdw_args,
173 | gradFilter);
174 |
175 | return 1;
176 | }
177 |
178 | int sample_depthwise_backward_rotation_cuda(
179 | at::Tensor gradOutput,
180 | at::Tensor input,
181 | at::Tensor rotation,
182 | at::Tensor filter,
183 | at::Tensor gradRotation,
184 | int kH,
185 | int kW,
186 | int dH,
187 | int dW,
188 | int padH,
189 | int padW,
190 | int dilationH,
191 | int dilationW,
192 | int scopeH,
193 | int scopeW,
194 | int groupS) {
195 |
196 | gradOutput = gradOutput.contiguous();
197 | input = input.contiguous();
198 | rotation = rotation.contiguous();
199 | filter = filter.contiguous();
200 |
201 | SampleDepthwiseArgs sdw_args;
202 | sdw_args.batch = input.size(0);
203 | sdw_args.channel = input.size(1);
204 | sdw_args.in_height = input.size(2);
205 | sdw_args.in_width = input.size(3);
206 | sdw_args.filter_height = kH;
207 | sdw_args.filter_width = kW;
208 | sdw_args.stride_height = dH;
209 | sdw_args.stride_width = dW;
210 | sdw_args.pad_height = padH;
211 | sdw_args.pad_width = padW;
212 | sdw_args.dilation_height = dilationH;
213 | sdw_args.dilation_width = dilationW;
214 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
215 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
216 | sdw_args.scope_height = scopeH;
217 | sdw_args.scope_width = scopeW;
218 | sdw_args.sampling_group = groupS;
219 |
220 | gradRotation = gradRotation.view({
221 | sdw_args.batch,
222 | groupS * kH * kW * 2,
223 | sdw_args.out_height,
224 | sdw_args.out_width});
225 |
226 | SampleDepthwiseConv2dBackwardRotation(
227 | gradOutput,
228 | input,
229 | rotation,
230 | filter,
231 | sdw_args,
232 | gradRotation);
233 |
234 | return 1;
235 | }
236 |
237 | int deformable_sample_depthwise_forward_cuda(
238 | at::Tensor input,
239 | at::Tensor offset,
240 | at::Tensor rotation,
241 | at::Tensor filter,
242 | at::Tensor output,
243 | int kH,
244 | int kW,
245 | int dH,
246 | int dW,
247 | int padH,
248 | int padW,
249 | int dilationH,
250 | int dilationW,
251 | int scopeH,
252 | int scopeW,
253 | int groupS) {
254 |
255 | input = input.contiguous();
256 | offset = offset.contiguous();
257 | rotation = rotation.contiguous();
258 | filter = filter.contiguous();
259 |
260 | SampleDepthwiseArgs sdw_args;
261 | sdw_args.batch = input.size(0);
262 | sdw_args.channel = input.size(1);
263 | sdw_args.in_height = input.size(2);
264 | sdw_args.in_width = input.size(3);
265 | sdw_args.filter_height = kH;
266 | sdw_args.filter_width = kW;
267 | sdw_args.stride_height = dH;
268 | sdw_args.stride_width = dW;
269 | sdw_args.pad_height = padH;
270 | sdw_args.pad_width = padW;
271 | sdw_args.dilation_height = dilationH;
272 | sdw_args.dilation_width = dilationW;
273 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
274 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
275 | sdw_args.scope_height = scopeH;
276 | sdw_args.scope_width = scopeW;
277 | sdw_args.sampling_group = groupS;
278 |
279 | output = output.view({
280 | sdw_args.batch,
281 | sdw_args.channel,
282 | sdw_args.out_height,
283 | sdw_args.out_width});
284 |
285 | DeformableSampleDepthwiseConv2dForward(
286 | input,
287 | offset,
288 | rotation,
289 | filter,
290 | sdw_args,
291 | output);
292 |
293 | return 1;
294 | }
295 |
296 | int deformable_sample_depthwise_backward_data_cuda(
297 | at::Tensor gradOutput,
298 | at::Tensor input,
299 | at::Tensor offset,
300 | at::Tensor rotation,
301 | at::Tensor filter,
302 | at::Tensor gradInput,
303 | int kH,
304 | int kW,
305 | int dH,
306 | int dW,
307 | int padH,
308 | int padW,
309 | int dilationH,
310 | int dilationW,
311 | int scopeH,
312 | int scopeW,
313 | int groupS) {
314 |
315 | gradOutput = gradOutput.contiguous();
316 | offset = offset.contiguous();
317 | rotation = rotation.contiguous();
318 | filter = filter.contiguous();
319 |
320 | SampleDepthwiseArgs sdw_args;
321 | sdw_args.batch = input.size(0);
322 | sdw_args.channel = input.size(1);
323 | sdw_args.in_height = input.size(2);
324 | sdw_args.in_width = input.size(3);
325 | sdw_args.filter_height = kH;
326 | sdw_args.filter_width = kW;
327 | sdw_args.stride_height = dH;
328 | sdw_args.stride_width = dW;
329 | sdw_args.pad_height = padH;
330 | sdw_args.pad_width = padW;
331 | sdw_args.dilation_height = dilationH;
332 | sdw_args.dilation_width = dilationW;
333 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
334 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
335 | sdw_args.scope_height = scopeH;
336 | sdw_args.scope_width = scopeW;
337 | sdw_args.sampling_group = groupS;
338 |
339 | gradInput = gradInput.view({
340 | sdw_args.batch,
341 | sdw_args.channel,
342 | sdw_args.in_height,
343 | sdw_args.in_width});
344 |
345 | DeformableSampleDepthwiseConv2dBackwardData(
346 | gradOutput,
347 | offset,
348 | rotation,
349 | filter,
350 | sdw_args,
351 | gradInput);
352 |
353 | return 1;
354 | }
355 |
356 | int deformable_sample_depthwise_backward_filter_cuda(
357 | at::Tensor gradOutput,
358 | at::Tensor input,
359 | at::Tensor offset,
360 | at::Tensor rotation,
361 | at::Tensor filter,
362 | at::Tensor gradFilter,
363 | int kH,
364 | int kW,
365 | int dH,
366 | int dW,
367 | int padH,
368 | int padW,
369 | int dilationH,
370 | int dilationW,
371 | int scopeH,
372 | int scopeW,
373 | int groupS) {
374 |
375 | gradOutput = gradOutput.contiguous();
376 | input = input.contiguous();
377 | offset = offset.contiguous();
378 | rotation = rotation.contiguous();
379 |
380 | SampleDepthwiseArgs sdw_args;
381 | sdw_args.batch = input.size(0);
382 | sdw_args.channel = input.size(1);
383 | sdw_args.in_height = input.size(2);
384 | sdw_args.in_width = input.size(3);
385 | sdw_args.filter_height = kH;
386 | sdw_args.filter_width = kW;
387 | sdw_args.stride_height = dH;
388 | sdw_args.stride_width = dW;
389 | sdw_args.pad_height = padH;
390 | sdw_args.pad_width = padW;
391 | sdw_args.dilation_height = dilationH;
392 | sdw_args.dilation_width = dilationW;
393 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
394 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
395 | sdw_args.scope_height = scopeH;
396 | sdw_args.scope_width = scopeW;
397 | sdw_args.sampling_group = groupS;
398 |
399 | gradFilter = gradFilter.view({
400 | sdw_args.channel,
401 | 1,
402 | sdw_args.scope_height,
403 | sdw_args.scope_width});
404 |
405 | DeformableSampleDepthwiseConv2dBackwardFilter(
406 | gradOutput,
407 | input,
408 | offset,
409 | rotation,
410 | sdw_args,
411 | gradFilter);
412 |
413 | return 1;
414 | }
415 |
416 | int deformable_sample_depthwise_backward_offset_cuda(
417 | at::Tensor gradOutput,
418 | at::Tensor input,
419 | at::Tensor offset,
420 | at::Tensor rotation,
421 | at::Tensor filter,
422 | at::Tensor gradOffset,
423 | int kH,
424 | int kW,
425 | int dH,
426 | int dW,
427 | int padH,
428 | int padW,
429 | int dilationH,
430 | int dilationW,
431 | int scopeH,
432 | int scopeW,
433 | int groupS) {
434 |
435 | gradOutput = gradOutput.contiguous();
436 | input = input.contiguous();
437 | offset = offset.contiguous();
438 | rotation = rotation.contiguous();
439 | filter = filter.contiguous();
440 |
441 | SampleDepthwiseArgs sdw_args;
442 | sdw_args.batch = input.size(0);
443 | sdw_args.channel = input.size(1);
444 | sdw_args.in_height = input.size(2);
445 | sdw_args.in_width = input.size(3);
446 | sdw_args.filter_height = kH;
447 | sdw_args.filter_width = kW;
448 | sdw_args.stride_height = dH;
449 | sdw_args.stride_width = dW;
450 | sdw_args.pad_height = padH;
451 | sdw_args.pad_width = padW;
452 | sdw_args.dilation_height = dilationH;
453 | sdw_args.dilation_width = dilationW;
454 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
455 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
456 | sdw_args.scope_height = scopeH;
457 | sdw_args.scope_width = scopeW;
458 | sdw_args.sampling_group = groupS;
459 |
460 | gradOffset = gradOffset.view({
461 | sdw_args.batch,
462 | kH * kW * 2,
463 | sdw_args.out_height,
464 | sdw_args.out_width});
465 |
466 | DeformableSampleDepthwiseConv2dBackwardOffset(
467 | gradOutput,
468 | input,
469 | offset,
470 | rotation,
471 | filter,
472 | sdw_args,
473 | gradOffset);
474 |
475 | return 1;
476 | }
477 |
478 | int deformable_sample_depthwise_backward_rotation_cuda(
479 | at::Tensor gradOutput,
480 | at::Tensor input,
481 | at::Tensor offset,
482 | at::Tensor rotation,
483 | at::Tensor filter,
484 | at::Tensor gradRotation,
485 | int kH,
486 | int kW,
487 | int dH,
488 | int dW,
489 | int padH,
490 | int padW,
491 | int dilationH,
492 | int dilationW,
493 | int scopeH,
494 | int scopeW,
495 | int groupS) {
496 |
497 | gradOutput = gradOutput.contiguous();
498 | input = input.contiguous();
499 | offset = offset.contiguous();
500 | rotation = rotation.contiguous();
501 | filter = filter.contiguous();
502 |
503 | SampleDepthwiseArgs sdw_args;
504 | sdw_args.batch = input.size(0);
505 | sdw_args.channel = input.size(1);
506 | sdw_args.in_height = input.size(2);
507 | sdw_args.in_width = input.size(3);
508 | sdw_args.filter_height = kH;
509 | sdw_args.filter_width = kW;
510 | sdw_args.stride_height = dH;
511 | sdw_args.stride_width = dW;
512 | sdw_args.pad_height = padH;
513 | sdw_args.pad_width = padW;
514 | sdw_args.dilation_height = dilationH;
515 | sdw_args.dilation_width = dilationW;
516 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
517 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
518 | sdw_args.scope_height = scopeH;
519 | sdw_args.scope_width = scopeW;
520 | sdw_args.sampling_group = groupS;
521 |
522 | gradRotation = gradRotation.view({
523 | sdw_args.batch,
524 | groupS * kH * kW * 2,
525 | sdw_args.out_height,
526 | sdw_args.out_width});
527 |
528 | DeformableSampleDepthwiseConv2dBackwardRotation(
529 | gradOutput,
530 | input,
531 | offset,
532 | rotation,
533 | filter,
534 | sdw_args,
535 | gradRotation);
536 |
537 | return 1;
538 | }
539 |
540 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
541 | m.def("sample_depthwise_forward_cuda",
542 | &sample_depthwise_forward_cuda,
543 | "sample_depthwise_forward (CUDA)");
544 | m.def("sample_depthwise_backward_data_cuda",
545 | &sample_depthwise_backward_data_cuda,
546 | "sample_depthwise_backward_data (CUDA)");
547 | m.def("sample_depthwise_backward_filter_cuda",
548 | &sample_depthwise_backward_filter_cuda,
549 | "sample_depthwise_backward_filter (CUDA)");
550 | m.def("sample_depthwise_backward_rotation_cuda",
551 | &sample_depthwise_backward_rotation_cuda,
552 | "sample_depthwise_backward_rotation (CUDA)");
553 |
554 | m.def("deformable_sample_depthwise_forward_cuda",
555 | &deformable_sample_depthwise_forward_cuda,
556 | "deformable_sample_depthwise_forward (CUDA)");
557 | m.def("deformable_sample_depthwise_backward_data_cuda",
558 | &deformable_sample_depthwise_backward_data_cuda,
559 | "deformable_sample_depthwise_backward_data (CUDA)");
560 | m.def("deformable_sample_depthwise_backward_filter_cuda",
561 | &deformable_sample_depthwise_backward_filter_cuda,
562 | "deformable_sample_depthwise_backward_filter (CUDA)");
563 | m.def("deformable_sample_depthwise_backward_offset_cuda",
564 | &deformable_sample_depthwise_backward_offset_cuda,
565 | "deformable_sample_depthwise_backward_offset (CUDA)");
566 | m.def("deformable_sample_depthwise_backward_rotation_cuda",
567 | &deformable_sample_depthwise_backward_rotation_cuda,
568 | "deformable_sample_depthwise_backward_rotation (CUDA)");
569 | }
570 |
--------------------------------------------------------------------------------
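
A note on the bindings above: each wrapper recomputes the output resolution with the standard convolution arithmetic before launching the kernel. A quick spot check of that formula (illustrative Python, not part of the repo):

```python
def conv_out(in_size, k, stride, pad, dilation):
    # out = (in + 2*pad - (dilation*(k-1) + 1)) // stride + 1, as in the bindings.
    return (in_size + 2 * pad - (dilation * (k - 1) + 1)) // stride + 1

print(conv_out(56, 3, 1, 1, 1))  # 56 -> 56 (3x3, stride 1, pad 1)
print(conv_out(56, 3, 2, 1, 1))  # 56 -> 28 (3x3, stride 2, pad 1)
```
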
/deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda.h:
--------------------------------------------------------------------------------
1 | struct DepthwiseArgs {
2 | // Input layer dimensions
3 | int batch;
4 | int channel;
5 | int in_height;
6 | int in_width;
7 |
8 | // Weight layer dimensions
9 | int filter_height;
10 | int filter_width;
11 | int stride_height;
12 | int stride_width;
13 | int pad_height;
14 | int pad_width;
15 | int dilation_height;
16 | int dilation_width;
17 |
18 | // Output layer dimensions
19 | int out_height;
20 | int out_width;
21 | };
22 |
23 | struct SampleDepthwiseArgs : DepthwiseArgs {
24 | // Weight layer dimensions
25 | int scope_height;
26 | int scope_width;
27 | int sampling_group;
28 | };
29 |
30 | void SampleDepthwiseConv2dForward(
31 | const at::Tensor input,
32 | const at::Tensor rotation_ratio,
33 | const at::Tensor filter,
34 | const SampleDepthwiseArgs args,
35 | at::Tensor output);
36 |
37 | void DeformableSampleDepthwiseConv2dForward(
38 | const at::Tensor input,
39 | const at::Tensor offset,
40 | const at::Tensor rotation_ratio,
41 | const at::Tensor filter,
42 | const SampleDepthwiseArgs args,
43 | at::Tensor output);
44 |
45 | void SampleDepthwiseConv2dBackwardData(
46 | const at::Tensor out_grad,
47 | const at::Tensor rotation_ratio,
48 | const at::Tensor filter,
49 | const SampleDepthwiseArgs args,
50 | at::Tensor in_grad);
51 |
52 | void DeformableSampleDepthwiseConv2dBackwardData(
53 | const at::Tensor out_grad,
54 | const at::Tensor offset,
55 | const at::Tensor rotation_ratio,
56 | const at::Tensor filter,
57 | const SampleDepthwiseArgs args,
58 | at::Tensor in_grad);
59 |
60 | void SampleDepthwiseConv2dBackwardFilter(
61 | const at::Tensor out_grad,
62 | const at::Tensor input,
63 | const at::Tensor rotation_ratio,
64 | const SampleDepthwiseArgs args,
65 | at::Tensor filter_grad);
66 |
67 | void DeformableSampleDepthwiseConv2dBackwardFilter(
68 | const at::Tensor out_grad,
69 | const at::Tensor input,
70 | const at::Tensor offset,
71 | const at::Tensor rotation_ratio,
72 | const SampleDepthwiseArgs args,
73 | at::Tensor filter_grad);
74 |
75 | void DeformableSampleDepthwiseConv2dBackwardOffset(
76 | const at::Tensor out_grad,
77 | const at::Tensor input,
78 | const at::Tensor offset,
79 | const at::Tensor rotation_ratio,
80 | const at::Tensor filter,
81 | const SampleDepthwiseArgs args,
82 | at::Tensor offset_grad);
83 |
84 | void SampleDepthwiseConv2dBackwardRotation(
85 | const at::Tensor out_grad,
86 | const at::Tensor input,
87 | const at::Tensor rotation_ratio,
88 | const at::Tensor filter,
89 | const SampleDepthwiseArgs args,
90 | at::Tensor rotation_grad);
91 |
92 | void DeformableSampleDepthwiseConv2dBackwardRotation(
93 | const at::Tensor out_grad,
94 | const at::Tensor input,
95 | const at::Tensor offset,
96 | const at::Tensor rotation_ratio,
97 | const at::Tensor filter,
98 | const SampleDepthwiseArgs args,
99 | at::Tensor rotation_grad);
100 |
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cmath>
2 | #include <algorithm>
3 | 
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <THC/THCAtomics.cuh>
7 |
8 | #include "filter_sample_depthwise_cuda.h"
9 |
10 | using namespace at;
11 |
12 | #define CUDA_KERNEL_LOOP(i, n) \
13 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
14 | i += blockDim.x * gridDim.x)
15 |
16 | const int CUDA_NUM_THREADS = 1024;
17 | const int kMaxGridNum = 65535;
18 |
19 | inline int GET_BLOCKS(const int N) {
20 | return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
21 | }
22 |
23 | #if !defined(_MSC_VER)
24 | #define CUDA_UNROLL _Pragma("unroll")
25 | #define CUDA_NOUNROLL _Pragma("nounroll")
26 | #else
27 | #define CUDA_UNROLL
28 | #define CUDA_NOUNROLL
29 | #endif
30 |
31 | template <typename scalar_t>
32 | __device__ inline scalar_t ldg(const scalar_t* address) {
33 | #if __CUDA_ARCH__ >= 350
34 | return __ldg(address);
35 | #else
36 | return *address;
37 | #endif
38 | }
39 |
40 | template <typename scalar_t>
41 | inline scalar_t __device__ CudaMax(scalar_t a, scalar_t b) {
42 | return a > b ? a : b;
43 | }
44 |
45 | template <typename scalar_t>
46 | inline scalar_t __device__ CudaMin(scalar_t a, scalar_t b) {
47 | return a < b ? a : b;
48 | }
49 |
50 | // assuming h, w is remainder of division, thus h in [0, height), w in [0, width)
51 | template <typename scalar_t>
52 | __device__ scalar_t planar_bilinear(
53 | const scalar_t *data,
54 | const int height,
55 | const int width,
56 | const scalar_t h,
57 | const scalar_t w) {
58 |
59 | if (h > -1 && w > -1 && h < height && w < width) {
60 | int h_low = floor(h);
61 | int w_low = floor(w);
62 |
63 | int h_high = h_low + 1;
64 | int w_high = w_low + 1;
65 |
66 | const scalar_t lh = h - h_low;
67 | const scalar_t lw = w - w_low;
68 | const scalar_t hh = 1 - lh, hw = 1 - lw;
69 |
70 | scalar_t val = 0;
71 | if (h_low >= 0 && w_low >= 0)
72 | val += hh * hw * ldg(data + h_low * width + w_low);
73 | if (h_low >=0 && w_high <= width - 1)
74 | val += hh * lw * ldg(data + h_low * width + w_high);
75 | if (h_high <= height - 1 && w_low >= 0)
76 | val += lh * hw * ldg(data + h_high * width + w_low);
77 | if (h_high <= height - 1 && w_high <= width - 1)
78 | val += lh * lw * ldg(data + h_high * width + w_high);
79 | return val;
80 | } else {
81 | return 0;
82 | }
83 | }
84 |
85 | template <typename scalar_t>
86 | __device__ void planar_bilinear_backward_data(
87 | const scalar_t partial_sum,
88 | const int height,
89 | const int width,
90 | const scalar_t h,
91 | const scalar_t w,
92 | scalar_t* filter_gradient) {
93 |
94 | if (h > -1 && w > -1 && h < height && w < width) {
95 | int h_low = floor(h);
96 | int w_low = floor(w);
97 |
98 | int h_high = h_low + 1;
99 | int w_high = w_low + 1;
100 |
101 | const scalar_t lh = h - h_low;
102 | const scalar_t lw = w - w_low;
103 | const scalar_t hh = 1 - lh, hw = 1 - lw;
104 |
105 | if (h_low >= 0 && w_low >= 0)
106 | atomicAdd(filter_gradient + h_low * width + w_low, hh * hw * partial_sum);
107 | if (h_low >=0 && w_high <= width - 1)
108 | atomicAdd(filter_gradient + h_low * width + w_high, hh * lw * partial_sum);
109 | if (h_high <= height - 1 && w_low >= 0)
110 | atomicAdd(filter_gradient + h_high * width + w_low, lh * hw * partial_sum);
111 | if (h_high <= height - 1 && w_high <= width - 1)
112 | atomicAdd(filter_gradient + h_high * width + w_high, lh * lw * partial_sum);
113 | }
114 | }
115 |
116 | template <typename scalar_t>
117 | __device__ scalar_t planar_bilinear_backward_coord(
118 | const scalar_t partial_sum,
119 | const scalar_t* filter,
120 | const int height,
121 | const int width,
122 | const scalar_t h,
123 | const scalar_t w,
124 | const int bp_dir) {
125 |
126 | if (h > -1 && w > -1 && h < height && w < width) {
127 | int h_low = floor(h);
128 | int w_low = floor(w);
129 |
130 | int h_high = h_low + 1;
131 | int w_high = w_low + 1;
132 |
133 | const scalar_t lh = h - h_low;
134 | const scalar_t lw = w - w_low;
135 | const scalar_t hh = 1 - lh, hw = 1 - lw;
136 |
137 | if (bp_dir == 0) {
138 | scalar_t gradient_h = 0;
139 | if (h_low >= 0 && w_low >= 0)
140 | gradient_h -= hw * partial_sum * ldg(filter + h_low * width + w_low);
141 | if (h_low >=0 && w_high <= width - 1)
142 | gradient_h -= lw * partial_sum * ldg(filter + h_low * width + w_high);
143 | if (h_high <= height - 1 && w_low >= 0)
144 | gradient_h += hw * partial_sum * ldg(filter + h_high * width + w_low);
145 | if (h_high <= height - 1 && w_high <= width - 1)
146 | gradient_h += lw * partial_sum * ldg(filter + h_high * width + w_high);
147 | return gradient_h;
148 | } else {
149 | scalar_t gradient_w = 0;
150 | if (h_low >= 0 && w_low >= 0)
151 | gradient_w -= hh * partial_sum * ldg(filter + h_low * width + w_low);
152 | if (h_low >=0 && w_high <= width - 1)
153 | gradient_w += hh * partial_sum * ldg(filter + h_low * width + w_high);
154 | if (h_high <= height - 1 && w_low >= 0)
155 | gradient_w -= lh * partial_sum * ldg(filter + h_high * width + w_low);
156 | if (h_high <= height - 1 && w_high <= width - 1)
157 | gradient_w += lh * partial_sum * ldg(filter + h_high * width + w_high);
158 | return gradient_w;
159 | }
160 | } else {
161 | return 0;
162 | }
163 | }
164 |
165 | template <typename scalar_t>
166 | __device__ scalar_t deformable_im2col_bilinear(
167 | const scalar_t *bottom_data,
168 | const int height,
169 | const int width,
170 | scalar_t h,
171 | scalar_t w) {
172 |
173 | int h_low = floor(h);
174 | int w_low = floor(w);
175 | int h_high = h_low + 1;
176 | int w_high = w_low + 1;
177 |
178 | scalar_t lh = h - h_low;
179 | scalar_t lw = w - w_low;
180 | scalar_t hh = 1 - lh, hw = 1 - lw;
181 |
182 | scalar_t v1 = 0;
183 | if (h_low >= 0 && w_low >= 0)
184 | v1 = ldg(bottom_data + h_low * width + w_low);
185 | scalar_t v2 = 0;
186 | if (h_low >=0 && w_high <= width - 1)
187 | v2 = ldg(bottom_data + h_low * width + w_high);
188 | scalar_t v3 = 0;
189 | if (h_high <= height - 1 && w_low >= 0)
190 | v3 = ldg(bottom_data + h_high * width + w_low);
191 | scalar_t v4 = 0;
192 | if (h_high <= height - 1 && w_high <= width - 1)
193 | v4 = ldg(bottom_data + h_high * width + w_high);
194 |
195 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
196 |
197 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
198 | return val;
199 | }
200 |
201 | template <typename scalar_t>
202 | __device__ void deformable_im2col_bilinear_backward(
203 | const scalar_t partial_sum,
204 | const scalar_t h,
205 | const scalar_t w,
206 | const int height,
207 | const int width,
208 | scalar_t* data_gradient) {
209 |
210 | int h_low = floor(h);
211 | int w_low = floor(w);
212 | int h_high = h_low + 1;
213 | int w_high = w_low + 1;
214 |
215 | scalar_t lh = h - h_low;
216 | scalar_t lw = w - w_low;
217 | scalar_t hh = 1 - lh, hw = 1 - lw;
218 |
219 | if (h_low >= 0 && w_low >= 0)
220 | atomicAdd(data_gradient + h_low * width + w_low, hh * hw * partial_sum);
221 | if (h_low >=0 && w_high <= width - 1)
222 | atomicAdd(data_gradient + h_low * width + w_high, hh * lw * partial_sum);
223 | if (h_high <= height - 1 && w_low >= 0)
224 | atomicAdd(data_gradient + h_high * width + w_low, lh * hw * partial_sum);
225 | if (h_high <= height - 1 && w_high <= width - 1)
226 | atomicAdd(data_gradient + h_high * width + w_high, lh * lw * partial_sum);
227 |
228 | return;
229 | }
230 |
231 | template <typename scalar_t>
232 | __device__ scalar_t get_coordinate_weight(
233 | scalar_t argmax_h,
234 | scalar_t argmax_w,
235 | const int height,
236 | const int width,
237 | const scalar_t *im_data,
238 | const int bp_dir) {
239 |
240 | int argmax_h_low = floor(argmax_h);
241 | int argmax_w_low = floor(argmax_w);
242 | int argmax_h_high = argmax_h_low + 1;
243 | int argmax_w_high = argmax_w_low + 1;
244 |
245 | scalar_t weight = 0;
246 |
247 | if (bp_dir == 0) {
248 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
249 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * width + argmax_w_low];
250 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
251 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * width + argmax_w_high];
252 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
253 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * width + argmax_w_low];
254 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
255 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * width + argmax_w_high];
256 | } else if (bp_dir == 1) {
257 | if (argmax_h_low >= 0 && argmax_w_low >= 0)
258 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * width + argmax_w_low];
259 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
260 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * width + argmax_w_high];
261 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
262 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * width + argmax_w_low];
263 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
264 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * width + argmax_w_high];
265 | }
266 |
267 | return weight;
268 | }
269 |
270 | template <typename scalar_t, int kFilterHeight, int kFilterWidth>
271 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dForwardKernel(
272 | int n,
273 | const scalar_t* input,
274 | const scalar_t* rotation_ratio,
275 | const scalar_t* filter,
276 | const SampleDepthwiseArgs args,
277 | scalar_t* output) {
278 |
279 | const int channel = args.channel;
280 | const int in_height = args.in_height;
281 | const int in_width = args.in_width;
282 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
283 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
284 | const int stride_height = args.stride_height;
285 | const int stride_width = args.stride_width;
286 | const int pad_height = args.pad_height;
287 | const int pad_width = args.pad_width;
288 | const int dilation_height = args.dilation_height;
289 | const int dilation_width = args.dilation_width;
290 | const int out_height = args.out_height;
291 | const int out_width = args.out_width;
292 |
293 | const int scope_height = args.scope_height;
294 | const int scope_width = args.scope_width;
295 | const int sampling_group = args.sampling_group;
296 |
297 | CUDA_KERNEL_LOOP(thread_id, n) {
298 | const int out_w = thread_id % out_width;
299 | const int out_h = (thread_id / out_width) % out_height;
300 | const int out_c = (thread_id / out_width / out_height) % channel;
301 | const int out_b = thread_id / out_width / out_height / channel;
302 | const int in_c = out_c;
303 |
304 | const int input_offset_temp =
305 | (out_b * channel + in_c) * (in_height * in_width);
306 |
307 | const int group_id = in_c % sampling_group;
308 | const int rotation_offset_temp =
309 | (out_b * sampling_group + group_id) * (filter_height * filter_width * 2) *
310 | out_height * out_width + (out_h * out_width + out_w);
311 | const int filter_offset_temp = in_c * scope_height * scope_width;
312 |
313 | // Finally, we can iterate over the spatial dimensions and perform the
314 | // convolution, writing into the output at the end.
315 | const int input_h_start = out_h * stride_height - pad_height;
316 | const int input_w_start = out_w * stride_width - pad_width;
317 | const int input_h_end = input_h_start + (filter_height - 1) * dilation_height;
318 | const int input_w_end = input_w_start + (filter_width - 1) * dilation_width;
319 |
320 | scalar_t sum = 0;
321 | if (input_h_start >= 0 && input_w_start >= 0 &&
322 | input_h_end < in_height && input_w_end < in_width) {
323 | // Loop that doesn't need to check for boundary conditions.
324 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
325 | const int in_h = input_h_start + f_h * dilation_height;
326 |
327 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
328 | const int in_w = input_w_start + f_w * dilation_width;
329 | const int input_offset = (input_offset_temp) + (in_h * in_width) + in_w;
330 |
331 | const int rotation_offset_fhw = rotation_offset_temp +
332 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
333 | const scalar_t rotation_ratio_h =
334 | ldg(rotation_ratio + rotation_offset_fhw);
335 | const scalar_t rotation_ratio_w =
336 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width);
337 |
338 | const scalar_t filter_h = f_h + rotation_ratio_h +
339 | (scope_height - filter_height) / 2.0;
340 | const scalar_t filter_w = f_w + rotation_ratio_w +
341 | (scope_width - filter_width) / 2.0;
342 | sum += ldg(input + input_offset) * planar_bilinear(
343 | filter + filter_offset_temp,
344 | scope_height,
345 | scope_width,
346 | filter_h,
347 | filter_w);
348 | }
349 | }
350 | } else {
351 | // Loop that needs to check for boundary conditions.
352 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
353 | const int in_h = input_h_start + f_h * dilation_height;
354 |
355 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
356 | const int in_w = input_w_start + f_w * dilation_width;
357 |
358 | // NOTE(Hang Gao @ 07/25): how much runtime will it save?
359 | if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) {
360 | const int input_offset = input_offset_temp + (in_h * in_width) + in_w;
361 |
362 | const int rotation_offset_fhw = rotation_offset_temp +
363 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
364 | const scalar_t rotation_ratio_h =
365 | ldg(rotation_ratio + rotation_offset_fhw);
366 | const scalar_t rotation_ratio_w =
367 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width);
368 |
369 | const scalar_t filter_h = f_h + rotation_ratio_h +
370 | (scope_height - filter_height) / 2.0;
371 | const scalar_t filter_w = f_w + rotation_ratio_w +
372 | (scope_width - filter_width) / 2.0;
373 | sum += ldg(input + input_offset) * planar_bilinear(
374 | filter + filter_offset_temp,
375 | scope_height,
376 | scope_width,
377 | filter_h,
378 | filter_w);
379 | }
380 | }
381 |
382 | }
383 | }
384 | output[thread_id] = sum;
385 | }
386 | }
387 |
388 | void SampleDepthwiseConv2dForward(
389 | const at::Tensor input,
390 | const at::Tensor rotation_ratio,
391 | const at::Tensor filter,
392 | const SampleDepthwiseArgs args,
393 | at::Tensor output) {
394 |
395 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width;
396 |
397 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
398 | input.type(), "SampleDepthwiseConv2dForward_GPU", ([&] {
399 |         const scalar_t *input_ = input.data<scalar_t>();
400 |         const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
401 |         const scalar_t *filter_ = filter.data<scalar_t>();
402 |         scalar_t *output_ = output.data<scalar_t>();
403 |
404 |
405 | if (args.filter_height == 3 && args.filter_width == 3) {
406 |           SampleDepthwiseConv2dForwardKernel<scalar_t, 3, 3>
407 |               <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
408 | num_kernels,
409 | input_,
410 | rotation_ratio_,
411 | filter_,
412 | args,
413 | output_);
414 | } else {
415 |           SampleDepthwiseConv2dForwardKernel<scalar_t, -1, -1>
416 |               <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
417 | num_kernels,
418 | input_,
419 | rotation_ratio_,
420 | filter_,
421 | args,
422 | output_);
423 | }
424 |
425 | }));
426 |
427 | cudaError_t err = cudaGetLastError();
428 | if (err != cudaSuccess) {
429 | printf("error in SampleDepthwiseConv2dForwardKernel: %s\n", cudaGetErrorString(err));
430 | }
431 | }
432 |
433 | template <typename scalar_t, int kFilterHeight, int kFilterWidth>
434 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dForwardKernel(
435 | int n,
436 | const scalar_t* input,
437 | const scalar_t* offset,
438 | const scalar_t* rotation_ratio,
439 | const scalar_t* filter,
440 | const SampleDepthwiseArgs args,
441 | scalar_t* output) {
442 |
443 | const int channel = args.channel;
444 | const int in_height = args.in_height;
445 | const int in_width = args.in_width;
446 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
447 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
448 | const int stride_height = args.stride_height;
449 | const int stride_width = args.stride_width;
450 | const int pad_height = args.pad_height;
451 | const int pad_width = args.pad_width;
452 | const int dilation_height = args.dilation_height;
453 | const int dilation_width = args.dilation_width;
454 | const int out_height = args.out_height;
455 | const int out_width = args.out_width;
456 |
457 | const int scope_height = args.scope_height;
458 | const int scope_width = args.scope_width;
459 | const int sampling_group = args.sampling_group;
460 |
461 | CUDA_KERNEL_LOOP(thread_id, n) {
462 | const int out_w = thread_id % out_width;
463 | const int out_h = (thread_id / out_width) % out_height;
464 | const int out_c = (thread_id / out_width / out_height) % channel;
465 | const int out_b = thread_id / out_width / out_height / channel;
466 | const int in_c = out_c;
467 |
468 | const int input_offset_temp =
469 | (out_b * channel + in_c) * (in_height * in_width);
470 | const int deformation_offset_temp =
471 | out_b * (filter_height * filter_width * 2) * out_height * out_width +
472 | (out_h * out_width + out_w);
473 | const int group_id = in_c % sampling_group;
474 | const int rotation_offset_temp = (out_b * sampling_group + group_id) *
475 | (filter_height * filter_width * 2) * out_height * out_width +
476 | (out_h * out_width + out_w);
477 | const int filter_offset_temp = in_c * scope_height * scope_width;
478 |
479 | // Finally, we can iterate over the spatial dimensions and perform the
480 | // convolution, writing into the output at the end.
481 | const int input_h_start = out_h * stride_height - pad_height;
482 | const int input_w_start = out_w * stride_width - pad_width;
483 |
484 | scalar_t sum = 0;
485 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
486 | const int in_h = input_h_start + f_h * dilation_height;
487 |
488 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
489 | const int in_w = input_w_start + f_w * dilation_width;
490 | const int deformation_offset_fhw = deformation_offset_temp +
491 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
492 | const int rotation_offset_fhw = rotation_offset_temp +
493 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
494 |
495 | const scalar_t input_h = in_h +
496 | ldg(offset + deformation_offset_fhw);
497 | const scalar_t input_w = in_w +
498 | ldg(offset + deformation_offset_fhw + out_height * out_width);
499 |
500 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) {
501 | const scalar_t rotation_ratio_h =
502 | ldg(rotation_ratio + rotation_offset_fhw);
503 | const scalar_t rotation_ratio_w =
504 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width);
505 |
506 | const scalar_t cur_input = deformable_im2col_bilinear(
507 | input + input_offset_temp,
508 | in_height,
509 | in_width,
510 | input_h,
511 | input_w);
512 |
513 | const scalar_t filter_h = f_h + rotation_ratio_h +
514 | (scope_height - filter_height) / 2.0;
515 | const scalar_t filter_w = f_w + rotation_ratio_w +
516 | (scope_width - filter_width) / 2.0;
517 | sum += cur_input * planar_bilinear(
518 | filter + filter_offset_temp,
519 | scope_height,
520 | scope_width,
521 | filter_h,
522 | filter_w);
523 | }
524 | }
525 | }
526 | output[thread_id] = sum;
527 | }
528 | }
529 |
530 | void DeformableSampleDepthwiseConv2dForward(
531 | const at::Tensor input,
532 | const at::Tensor offset,
533 | const at::Tensor rotation_ratio,
534 | const at::Tensor filter,
535 | const SampleDepthwiseArgs args,
536 | at::Tensor output) {
537 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width;
538 |
539 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
540 | input.type(), "DeformableSampleDepthwiseConv2dForward_GPU", ([&] {
541 |       const scalar_t *input_ = input.data<scalar_t>();
542 |       const scalar_t *offset_ = offset.data<scalar_t>();
543 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
544 |       const scalar_t *filter_ = filter.data<scalar_t>();
545 |       scalar_t *output_ = output.data<scalar_t>();
546 |
547 | if (args.filter_height == 3 && args.filter_width == 3) {
548 |         DeformableSampleDepthwiseConv2dForwardKernel<scalar_t, 3, 3>
549 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
550 | num_kernels,
551 | input_,
552 | offset_,
553 | rotation_ratio_,
554 | filter_,
555 | args,
556 | output_);
557 | } else {
558 |         DeformableSampleDepthwiseConv2dForwardKernel<scalar_t, -1, -1>
559 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
560 | num_kernels,
561 | input_,
562 | offset_,
563 | rotation_ratio_,
564 | filter_,
565 | args,
566 | output_);
567 | }
568 |
569 | }));
570 |
571 | cudaError_t err = cudaGetLastError();
572 | if (err != cudaSuccess) {
573 | printf("error in DeformableSampleDepthwiseConv2dForwardKernel: %s\n", cudaGetErrorString(err));
574 | }
575 | }
576 |
577 | template <typename scalar_t>
578 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardDataKernel(
579 | int n,
580 | const scalar_t* out_grad,
581 | const scalar_t* rotation_ratio,
582 | const scalar_t* filter,
583 | const SampleDepthwiseArgs args,
584 | scalar_t* in_grad) {
585 |
586 | const int channel = args.channel;
587 | const int in_height = args.in_height;
588 | const int in_width = args.in_width;
589 | const int filter_height = args.filter_height;
590 | const int filter_width = args.filter_width;
591 | const int stride_height = args.stride_height;
592 | const int stride_width = args.stride_width;
593 | const int pad_height = args.pad_height;
594 | const int pad_width = args.pad_width;
595 | const int dilation_height = args.dilation_height;
596 | const int dilation_width = args.dilation_width;
597 | const int out_height = args.out_height;
598 | const int out_width = args.out_width;
599 |
600 | const int scope_height = args.scope_height;
601 | const int scope_width = args.scope_width;
602 | const int sampling_group = args.sampling_group;
603 |
604 | CUDA_KERNEL_LOOP(thread_id, n) {
605 | // Compute the indexes of this thread in the input.
606 | const int in_w = thread_id % in_width;
607 | const int in_h = (thread_id / in_width) % in_height;
608 | const int channel_idx = (thread_id / in_width / in_height) % channel;
609 | const int batch_idx = thread_id / channel / in_width / in_height;
610 |
611 | const int out_h_start = CudaMax(
612 | 0, (in_h + pad_height - (filter_height - 1) * dilation_height +
613 | stride_height - 1) / stride_height);
614 | const int out_h_end = CudaMin(
615 | out_height - 1, (in_h + pad_height) / stride_height);
616 | const int out_w_start = CudaMax(
617 | 0, (in_w + pad_width - (filter_width - 1) * dilation_width +
618 | stride_width - 1) / stride_width);
619 | const int out_w_end = CudaMin(
620 | out_width - 1, (in_w + pad_width) / stride_width);
621 |
622 | const int group_id = channel_idx % sampling_group;
623 | const int rotation_offset_temp =
624 | (batch_idx * sampling_group + group_id) *
625 | (filter_height * filter_width * 2) * out_height * out_width;
626 | const int filter_offset_temp = channel_idx * scope_height * scope_width;
627 | const int out_grad_offset_temp =
628 | (batch_idx * channel + channel_idx) * (out_height * out_width);
629 |
630 | scalar_t sum = 0.0f;
631 | for (int out_h = out_h_start; out_h <= out_h_end; ++out_h) {
632 | int f_h = in_h + pad_height - out_h * stride_height;
633 |
634 | if (f_h % dilation_height == 0) {
635 | f_h /= dilation_height;
636 | const int out_grad_offset_h = out_grad_offset_temp + out_h * out_width;
637 |
638 | for (int out_w = out_w_start; out_w <= out_w_end; ++out_w) {
639 | int f_w = in_w + pad_width - out_w * stride_width;
640 |
641 | if (f_w % dilation_width == 0) {
642 | f_w /= dilation_width;
643 | const int out_grad_offset = out_grad_offset_h + out_w;
644 |
645 | const int rotation_offset_fhw = rotation_offset_temp +
646 | (f_h * filter_width + f_w) * 2 * out_height * out_width +
647 | (out_h * out_width + out_w);
648 | const scalar_t rotation_ratio_h =
649 | ldg(rotation_ratio + rotation_offset_fhw);
650 | const scalar_t rotation_ratio_w =
651 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width);
652 |
653 | const scalar_t filter_h = f_h + rotation_ratio_h +
654 | (scope_height - filter_height) / 2.0;
655 | const scalar_t filter_w = f_w + rotation_ratio_w +
656 | (scope_width - filter_width) / 2.0;
657 | sum += ldg(out_grad + out_grad_offset) * planar_bilinear(
658 | filter + filter_offset_temp,
659 | scope_height,
660 | scope_width,
661 | filter_h,
662 | filter_w);
663 | }
664 | }
665 | }
666 | }
667 | in_grad[thread_id] = sum;
668 | }
669 | }
670 |
671 | void SampleDepthwiseConv2dBackwardData(
672 | const at::Tensor out_grad,
673 | const at::Tensor rotation_ratio,
674 | const at::Tensor filter,
675 | const SampleDepthwiseArgs args,
676 | at::Tensor in_grad) {
677 |
678 | int num_kernels = args.batch * args.channel * args.in_height * args.in_width;
679 |
680 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
681 | out_grad.type(), "SampleDepthwiseConv2dBackwardData_GPU", ([&] {
682 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
683 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
684 |       const scalar_t *filter_ = filter.data<scalar_t>();
685 |       scalar_t *in_grad_ = in_grad.data<scalar_t>();
686 |
687 |       SampleDepthwiseConv2dBackwardDataKernel<scalar_t>
688 |           <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
689 | num_kernels,
690 | out_grad_,
691 | rotation_ratio_,
692 | filter_,
693 | args,
694 | in_grad_);
695 |
696 | }));
697 |
698 | cudaError_t err = cudaGetLastError();
699 | if (err != cudaSuccess) {
700 | printf("error in SampleDepthwiseConv2dBackwardDataKernel: %s\n", cudaGetErrorString(err));
701 | }
702 | }
703 |
704 | template <typename scalar_t>
705 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardDataKernel(
706 | int n,
707 | const scalar_t* out_backprop,
708 | const scalar_t* offset,
709 | const scalar_t* rotation_ratio,
710 | const scalar_t* filter,
711 | const SampleDepthwiseArgs args,
712 | scalar_t* in_grad) {
713 |
714 | const int channel = args.channel;
715 | const int in_height = args.in_height;
716 | const int in_width = args.in_width;
717 | const int filter_height = args.filter_height;
718 | const int filter_width = args.filter_width;
719 | const int stride_height = args.stride_height;
720 | const int stride_width = args.stride_width;
721 | const int pad_height = args.pad_height;
722 | const int pad_width = args.pad_width;
723 | const int dilation_height = args.dilation_height;
724 | const int dilation_width = args.dilation_width;
725 | const int out_height = args.out_height;
726 | const int out_width = args.out_width;
727 |
728 | const int scope_height = args.scope_height;
729 | const int scope_width = args.scope_width;
730 | const int sampling_group = args.sampling_group;
731 |
732 | CUDA_KERNEL_LOOP(thread_id, n) {
733 | // Compute the indexes of this thread in the output.
734 | const int out_w = thread_id % out_width;
735 | const int out_h = (thread_id / out_width) % out_height;
736 | const int in_c = (thread_id / out_width / out_height) % channel;
737 | // NOTE(Hang Gao @ 07/26): why feed data like this? -- because
738 | const int f_w = (thread_id / out_width / out_height / channel) % filter_width;
739 | const int f_h = (thread_id / out_width / out_height / channel / filter_width) % filter_height;
740 | const int out_b = (thread_id / out_width / out_height / channel / filter_width) / filter_height;
741 |
742 | // Decide if all input is valid, if yes, we can skip the boundary checks
743 | // for each input.
744 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height;
745 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width;
746 |
747 | const int deformable_offset_temp =
748 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 *
749 | out_height * out_width + (out_h * out_width + out_w);
750 | const int group_id = in_c % sampling_group;
751 | const int rotation_offset_temp =
752 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) +
753 | (f_h * filter_width + f_w)) * 2 * out_height * out_width +
754 | (out_h * out_width + out_w);
755 | const scalar_t input_h = in_row + ldg(offset + deformable_offset_temp);
756 | const scalar_t input_w = in_col + ldg(
757 | offset + deformable_offset_temp + out_height * out_width);
758 |
759 | // Avoid repeated computation.
760 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) {
761 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width);
762 | const int filter_offset_temp = in_c * scope_height * scope_width;
763 | const scalar_t out_bp = ldg(
764 | out_backprop +
765 | (out_b * channel + in_c) * (out_height * out_width) +
766 | (out_h * out_width + out_w));
767 |
768 | const scalar_t rotation_ratio_h = ldg(
769 | rotation_ratio + rotation_offset_temp);
770 | const scalar_t rotation_ratio_w = ldg(
771 | rotation_ratio + rotation_offset_temp + out_height * out_width);
772 |
773 | scalar_t cur_weight = 0;
774 | const scalar_t filter_h = f_h + rotation_ratio_h +
775 | (scope_height - filter_height) / 2.0;
776 | const scalar_t filter_w = f_w + rotation_ratio_w +
777 | (scope_width - filter_width) / 2.0;
778 | cur_weight = planar_bilinear(
779 | filter + filter_offset_temp,
780 | scope_height,
781 | scope_width,
782 | filter_h,
783 | filter_w);
784 |
785 | const scalar_t partial_sum = cur_weight * out_bp;
786 | deformable_im2col_bilinear_backward(
787 | partial_sum,
788 | input_h,
789 | input_w,
790 | in_height,
791 | in_width,
792 | in_grad + input_offset_temp);
793 | }
794 | }
795 | }
796 |
797 | void DeformableSampleDepthwiseConv2dBackwardData(
798 | const at::Tensor out_grad,
799 | const at::Tensor offset,
800 | const at::Tensor rotation_ratio,
801 | const at::Tensor filter,
802 | const SampleDepthwiseArgs args,
803 | at::Tensor in_grad) {
804 |
805 | int num_kernels = args.batch * args.filter_height * args.filter_width *
806 | args.channel * args.out_height * args.out_width;
807 |
808 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
809 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardData_GPU", ([&] {
810 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
811 |       const scalar_t *offset_ = offset.data<scalar_t>();
812 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
813 |       const scalar_t *filter_ = filter.data<scalar_t>();
814 |       scalar_t *in_grad_ = in_grad.data<scalar_t>();
815 |
816 |       DeformableSampleDepthwiseConv2dBackwardDataKernel<scalar_t>
817 |           <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
818 | num_kernels,
819 | out_grad_,
820 | offset_,
821 | rotation_ratio_,
822 | filter_,
823 | args,
824 | in_grad_);
825 |
826 | }));
827 |
828 | cudaError_t err = cudaGetLastError();
829 | if (err != cudaSuccess) {
830 | printf(
831 | "error in DeformableSampleDepthwiseConv2dBackwardDataKernel: %s\n",
832 | cudaGetErrorString(err));
833 | }
834 | }
835 |
836 | template <typename scalar_t, int kFilterHeight, int kFilterWidth>
837 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardFilterKernel(
838 | int n,
839 | const scalar_t* out_backprop,
840 | const scalar_t* input,
841 | const scalar_t* rotation_ratio,
842 | const SampleDepthwiseArgs args,
843 | scalar_t* filter_backprop) {
844 |
845 | const int channel = args.channel;
846 | const int in_height = args.in_height;
847 | const int in_width = args.in_width;
848 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
849 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
850 | const int stride_height = args.stride_height;
851 | const int stride_width = args.stride_width;
852 | const int pad_height = args.pad_height;
853 | const int pad_width = args.pad_width;
854 | const int dilation_height = args.dilation_height;
855 | const int dilation_width = args.dilation_width;
856 | const int out_height = args.out_height;
857 | const int out_width = args.out_width;
858 |
859 | const int scope_height = args.scope_height;
860 | const int scope_width = args.scope_width;
861 | const int sampling_group = args.sampling_group;
862 |
863 | CUDA_KERNEL_LOOP(thread_id, n) {
864 | // Compute the indexes of this thread in the output.
865 | const int out_w = thread_id % out_width;
866 | const int out_h = (thread_id / out_width) % out_height;
867 | const int out_c = (thread_id / out_width / out_height) % channel;
868 | const int out_b = thread_id / out_width / out_height / channel;
869 | const int in_c = out_c;
870 |
871 | // Decide if all input is valid, if yes, we can skip the boundary checks
872 | // for each input.
873 | const int in_row_start = out_h * stride_height - pad_height;
874 | const int in_col_start = out_w * stride_width - pad_width;
875 | const int in_row_end = in_row_start + (filter_height - 1) * dilation_height;
876 | const int in_col_end = in_col_start + (filter_width - 1) * dilation_width;
877 |
878 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width);
879 | const int group_id = in_c % sampling_group;
880 | const int rotation_offset_temp = (out_b * sampling_group + group_id) *
881 | (filter_height * filter_width * 2) * out_height * out_width +
882 | (out_h * out_width + out_w);
883 | const int filter_offset_temp = in_c * scope_height * scope_width;
884 |
885 | const scalar_t out_bp = ldg(out_backprop + thread_id);
886 | if (in_row_start >= 0 && in_col_start >= 0 &&
887 | in_row_end < in_height && in_col_end < in_width) {
888 |
889 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
890 | const int in_row = in_row_start + f_h * dilation_height;
891 | // Avoid repeated computation.
892 | const int input_offset_local = input_offset_temp + in_row * in_width;
893 |
894 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
895 | const int in_col = in_col_start + f_w * dilation_width;
896 | const int input_offset = input_offset_local + in_col;
897 |
898 | const int rotation_offset_fhw =
899 | rotation_offset_temp + (f_h * filter_width + f_w) * 2 * out_height * out_width;
900 | const scalar_t rotation_ratio_h = ldg(
901 | rotation_ratio + rotation_offset_fhw);
902 | const scalar_t rotation_ratio_w = ldg(
903 | rotation_ratio + rotation_offset_fhw + out_height * out_width);
904 |
905 | scalar_t partial_sum = ldg(input + input_offset) * out_bp;
906 | const scalar_t filter_h = f_h + rotation_ratio_h +
907 | (scope_height - filter_height) / 2.0;
908 | const scalar_t filter_w = f_w + rotation_ratio_w +
909 | (scope_width - filter_width) / 2.0;
910 | planar_bilinear_backward_data(
911 | partial_sum,
912 | scope_height,
913 | scope_width,
914 | filter_h,
915 | filter_w,
916 | filter_backprop + filter_offset_temp);
917 | }
918 | }
919 | } else {
920 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
921 | const int in_row = in_row_start + f_h * dilation_height;
922 | // Avoid repeated computation.
923 | const int input_offset_local = input_offset_temp + in_row * in_width;
924 |
925 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
926 |         const int in_col = in_col_start + f_w * dilation_width;
927 |
928 | if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) {
929 | const int input_offset = input_offset_local + in_col;
930 |
931 | const int rotation_offset_fhw = rotation_offset_temp +
932 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
933 | const scalar_t rotation_ratio_h = ldg(
934 | rotation_ratio + rotation_offset_fhw);
935 | const scalar_t rotation_ratio_w = ldg(
936 | rotation_ratio + rotation_offset_fhw + out_height * out_width);
937 |
938 | scalar_t partial_sum = ldg(input + input_offset) * out_bp;
939 | const scalar_t filter_h = f_h + rotation_ratio_h +
940 | (scope_height - filter_height) / 2.0;
941 | const scalar_t filter_w = f_w + rotation_ratio_w +
942 | (scope_width - filter_width) / 2.0;
943 | planar_bilinear_backward_data(
944 | partial_sum,
945 | scope_height,
946 | scope_width,
947 | filter_h,
948 | filter_w,
949 | filter_backprop + filter_offset_temp);
950 | }
951 | }
952 | }
953 | }
954 | }
955 | }
956 |
957 | void SampleDepthwiseConv2dBackwardFilter(
958 | const at::Tensor out_grad,
959 | const at::Tensor input,
960 | const at::Tensor rotation_ratio,
961 | const SampleDepthwiseArgs args,
962 | at::Tensor filter_grad) {
963 |
964 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width;
965 |
966 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
967 | out_grad.type(), "SampleDepthwiseConv2dBackwardFilter_GPU", ([&] {
968 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
969 |       const scalar_t *input_ = input.data<scalar_t>();
970 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
971 |       scalar_t *filter_grad_ = filter_grad.data<scalar_t>();
972 |
973 | if (args.filter_height == 3 && args.filter_width == 3) {
974 |         SampleDepthwiseConv2dBackwardFilterKernel<scalar_t, 3, 3>
975 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
976 | num_kernels,
977 | out_grad_,
978 | input_,
979 | rotation_ratio_,
980 | args,
981 | filter_grad_);
982 | } else {
983 |         SampleDepthwiseConv2dBackwardFilterKernel<scalar_t, -1, -1>
984 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
985 | num_kernels,
986 | out_grad_,
987 | input_,
988 | rotation_ratio_,
989 | args,
990 | filter_grad_);
991 | }
992 |
993 | }));
994 |
995 | cudaError_t err = cudaGetLastError();
996 | if (err != cudaSuccess) {
997 | printf("error in SampleDepthwiseConv2dBackwardFilterKernel: %s\n", cudaGetErrorString(err));
998 | }
999 | }
1000 |
1001 | // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
1002 | template <typename scalar_t, int kFilterHeight, int kFilterWidth>
1003 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardFilterKernel(
1004 | int n,
1005 | const scalar_t* out_backprop,
1006 | const scalar_t* input,
1007 | const scalar_t* offset,
1008 | const scalar_t* rotation_ratio,
1009 | const SampleDepthwiseArgs args,
1010 | scalar_t* filter_backprop) {
1011 |
1012 | const int channel = args.channel;
1013 | const int in_height = args.in_height;
1014 | const int in_width = args.in_width;
1015 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
1016 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
1017 | const int stride_height = args.stride_height;
1018 | const int stride_width = args.stride_width;
1019 | const int pad_height = args.pad_height;
1020 | const int pad_width = args.pad_width;
1021 | const int dilation_height = args.dilation_height;
1022 | const int dilation_width = args.dilation_width;
1023 | const int out_height = args.out_height;
1024 | const int out_width = args.out_width;
1025 |
1026 | const int scope_height = args.scope_height;
1027 | const int scope_width = args.scope_width;
1028 | const int sampling_group = args.sampling_group;
1029 |
1030 | CUDA_KERNEL_LOOP(thread_id, n) {
1031 | // Compute the indexes of this thread in the output.
1032 | const int out_w = thread_id % out_width;
1033 | const int out_h = (thread_id / out_width) % out_height;
1034 | const int out_c = (thread_id / out_width / out_height) % channel;
1035 | const int out_b = thread_id / out_width / out_height / channel;
1036 | const int in_c = out_c;
1037 |
1038 | // Decide if all input is valid, if yes, we can skip the boundary checks
1039 | // for each input.
1040 | const int in_row_start = out_h * stride_height - pad_height;
1041 | const int in_col_start = out_w * stride_width - pad_width;
1042 |
1043 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width);
1044 | const int deformation_offset_temp = out_b * (filter_height * filter_width * 2) *
1045 | out_height * out_width + (out_h * out_width + out_w);
1046 | const int group_id = in_c % sampling_group;
1047 | const int rotation_offset_temp = (out_b * sampling_group + group_id) *
1048 | (filter_height * filter_width * 2) * out_height * out_width +
1049 | (out_h * out_width + out_w);
1050 | const int filter_offset_temp = in_c * scope_height * scope_width;
1051 |
1052 | const scalar_t out_bp = ldg(out_backprop + thread_id);
1053 |
1054 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
1055 | const int in_row = in_row_start + f_h * dilation_height;
1056 |
1057 | // Avoid repeated computation.
1058 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
1059 | const int in_col = in_col_start + f_w * dilation_width;
1060 |
1061 | const int deformation_offset_fhw = deformation_offset_temp +
1062 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
1063 | const int rotation_offset_fhw = rotation_offset_temp +
1064 | (f_h * filter_width + f_w) * 2 * out_height * out_width;
1065 | const scalar_t input_h = in_row + ldg(
1066 | offset + deformation_offset_fhw);
1067 | const scalar_t input_w = in_col + ldg(
1068 | offset + deformation_offset_fhw + out_height * out_width);
1069 |
1070 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) {
1071 | const scalar_t rotation_ratio_h = ldg(
1072 | rotation_ratio + rotation_offset_fhw);
1073 | const scalar_t rotation_ratio_w = ldg(
1074 | rotation_ratio + rotation_offset_fhw + out_height * out_width);
1075 |
1076 | const scalar_t partial_sum = deformable_im2col_bilinear(
1077 | input + input_offset_temp,
1078 | in_height,
1079 | in_width,
1080 | input_h,
1081 | input_w) * out_bp;
1082 | const scalar_t filter_h = f_h + rotation_ratio_h +
1083 | (scope_height - filter_height) / 2.0;
1084 | const scalar_t filter_w = f_w + rotation_ratio_w +
1085 | (scope_width - filter_width) / 2.0;
1086 | planar_bilinear_backward_data(
1087 | partial_sum,
1088 | scope_height,
1089 | scope_width,
1090 | filter_h,
1091 | filter_w,
1092 | filter_backprop + filter_offset_temp);
1093 | }
1094 | }
1095 | }
1096 | }
1097 | }
1098 |
1099 | void DeformableSampleDepthwiseConv2dBackwardFilter(
1100 | const at::Tensor out_grad,
1101 | const at::Tensor input,
1102 | const at::Tensor offset,
1103 | const at::Tensor rotation_ratio,
1104 | const SampleDepthwiseArgs args,
1105 | at::Tensor filter_grad) {
1106 |
1107 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width;
1108 |
1109 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1110 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardFilter_GPU", ([&] {
1111 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
1112 |       const scalar_t *input_ = input.data<scalar_t>();
1113 |       const scalar_t *offset_ = offset.data<scalar_t>();
1114 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
1115 |       scalar_t *filter_grad_ = filter_grad.data<scalar_t>();
1116 |
1117 | if (args.filter_height == 3 && args.filter_width == 3) {
1118 |         DeformableSampleDepthwiseConv2dBackwardFilterKernel<scalar_t, 3, 3>
1119 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
1120 | num_kernels,
1121 | out_grad_,
1122 | input_,
1123 | offset_,
1124 | rotation_ratio_,
1125 | args,
1126 | filter_grad_);
1127 | } else {
1128 |         DeformableSampleDepthwiseConv2dBackwardFilterKernel<scalar_t, -1, -1>
1129 |             <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
1130 | num_kernels,
1131 | out_grad_,
1132 | input_,
1133 | offset_,
1134 | rotation_ratio_,
1135 | args,
1136 | filter_grad_);
1137 | }
1138 |
1139 | }));
1140 |
1141 | cudaError_t err = cudaGetLastError();
1142 | if (err != cudaSuccess) {
1143 | printf("error in DeformableSampleDepthwiseConv2dBackwardFilterKernel: %s\n", cudaGetErrorString(err));
1144 | }
1145 | }
1146 |
1147 | template <typename scalar_t>
1148 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardOffsetKernel(
1149 | int n,
1150 | const scalar_t* out_backprop,
1151 | const scalar_t* input,
1152 | const scalar_t* offset,
1153 | const scalar_t* rotation_ratio,
1154 | const scalar_t* filter,
1155 | const SampleDepthwiseArgs args,
1156 | scalar_t* offset_backprop) {
1157 |
1158 | const int channel = args.channel;
1159 | const int in_height = args.in_height;
1160 | const int in_width = args.in_width;
1161 | const int filter_height = args.filter_height;
1162 | const int filter_width = args.filter_width;
1163 | const int stride_height = args.stride_height;
1164 | const int stride_width = args.stride_width;
1165 | const int pad_height = args.pad_height;
1166 | const int pad_width = args.pad_width;
1167 | const int dilation_height = args.dilation_height;
1168 | const int dilation_width = args.dilation_width;
1169 | const int out_height = args.out_height;
1170 | const int out_width = args.out_width;
1171 |
1172 | const int scope_height = args.scope_height;
1173 | const int scope_width = args.scope_width;
1174 | const int sampling_group = args.sampling_group;
1175 |
1176 | CUDA_KERNEL_LOOP(thread_id, n) {
1177 | // Compute the indexes of this thread in the output.
1178 | const int out_w = thread_id % out_width;
1179 | const int out_h = (thread_id / out_width) % out_height;
1180 | const int bp_dir = (thread_id / out_width / out_height) % 2;
1181 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width;
1182 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height;
1183 | const int out_b = (thread_id / out_width / out_height / 2 / filter_width) / filter_height;
1184 |
1185 | // Decide if all input is valid, if yes, we can skip the boundary checks
1186 | // for each input.
1187 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height;
1188 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width;
1189 |
1190 | const int deformable_offset_temp =
1191 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 *
1192 | out_height * out_width +
1193 | (out_h * out_width + out_w);
1194 | const scalar_t input_h = in_row + ldg(
1195 | offset + deformable_offset_temp);
1196 | const scalar_t input_w = in_col + ldg(
1197 | offset + deformable_offset_temp + out_height * out_width);
1198 |
1199 | scalar_t coord_gradient = 0;
1200 | // Avoid repeated computation.
1201 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) {
1202 |
1203 | for (int in_c = 0; in_c < channel; in_c++) {
1204 | const int group_id = in_c % sampling_group;
1205 | const int rotation_offset_temp = ((out_b * sampling_group + group_id) *
1206 | (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 *
1207 | out_height * out_width + (out_h * out_width + out_w);
1208 | const scalar_t rotation_ratio_h = ldg(
1209 | rotation_ratio + rotation_offset_temp);
1210 | const scalar_t rotation_ratio_w = ldg(
1211 | rotation_ratio + rotation_offset_temp + out_height * out_width);
1212 | scalar_t filter_h = f_h + rotation_ratio_h +
1213 | (scope_height - filter_height) / 2.0;
1214 | scalar_t filter_w = f_w + rotation_ratio_w +
1215 | (scope_width - filter_width) / 2.0;
1216 |
1217 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width);
1218 | const int filter_offset_temp = in_c * scope_height * scope_width;
1219 | const scalar_t out_bp = ldg(
1220 | out_backprop +
1221 | (out_b * channel + in_c) * (out_height * out_width) +
1222 | (out_h * out_width + out_w));
1223 |
1224 | scalar_t cur_weight = planar_bilinear(
1225 | filter + filter_offset_temp,
1226 | scope_height,
1227 | scope_width,
1228 | filter_h,
1229 | filter_w);
1230 | scalar_t partial_sum = cur_weight * out_bp;
1231 | coord_gradient += get_coordinate_weight(
1232 | input_h,
1233 | input_w,
1234 | in_height,
1235 | in_width,
1236 | input + input_offset_temp,
1237 | bp_dir) * partial_sum;
1238 | }
1239 | }
1240 |
1241 | offset_backprop[thread_id] = coord_gradient;
1242 | }
1243 | }
1244 |
1245 | void DeformableSampleDepthwiseConv2dBackwardOffset(
1246 | const at::Tensor out_grad,
1247 | const at::Tensor input,
1248 | const at::Tensor offset,
1249 | const at::Tensor rotation_ratio,
1250 | const at::Tensor filter,
1251 | const SampleDepthwiseArgs args,
1252 | at::Tensor offset_grad) {
1253 |
1254 | int num_kernels = args.batch * args.filter_height * args.filter_width * 2 *
1255 | args.out_height * args.out_width;
1256 |
1257 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1258 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardOffset_GPU", ([&] {
1259 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
1260 |       const scalar_t *input_ = input.data<scalar_t>();
1261 |       const scalar_t *offset_ = offset.data<scalar_t>();
1262 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
1263 |       const scalar_t *filter_ = filter.data<scalar_t>();
1264 |       scalar_t *offset_grad_ = offset_grad.data<scalar_t>();
1265 |
1266 |       DeformableSampleDepthwiseConv2dBackwardOffsetKernel<scalar_t>
1267 |           <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
1268 | num_kernels,
1269 | out_grad_,
1270 | input_,
1271 | offset_,
1272 | rotation_ratio_,
1273 | filter_,
1274 | args,
1275 | offset_grad_);
1276 |
1277 | }));
1278 |
1279 | cudaError_t err = cudaGetLastError();
1280 | if (err != cudaSuccess) {
1281 | printf("error in DeformableSampleDepthwiseConv2dBackwardOffsetKernel: %s\n", cudaGetErrorString(err));
1282 | }
1283 | }
1284 |
1285 | template <typename scalar_t>
1286 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardRotationKernel(
1287 | int n,
1288 | const scalar_t* out_backprop,
1289 | const scalar_t* input,
1290 | const scalar_t* rotation_ratio,
1291 | const scalar_t* filter,
1292 | const SampleDepthwiseArgs args,
1293 | scalar_t* rotation_backprop) {
1294 |
1295 | const int channel = args.channel;
1296 | const int in_height = args.in_height;
1297 | const int in_width = args.in_width;
1298 | const int filter_height = args.filter_height;
1299 | const int filter_width = args.filter_width;
1300 | const int stride_height = args.stride_height;
1301 | const int stride_width = args.stride_width;
1302 | const int pad_height = args.pad_height;
1303 | const int pad_width = args.pad_width;
1304 | const int dilation_height = args.dilation_height;
1305 | const int dilation_width = args.dilation_width;
1306 | const int out_height = args.out_height;
1307 | const int out_width = args.out_width;
1308 |
1309 | const int scope_height = args.scope_height;
1310 | const int scope_width = args.scope_width;
1311 | const int sampling_group = args.sampling_group;
1312 |
1313 | CUDA_KERNEL_LOOP(thread_id, n) {
1314 | // Compute the indexes of this thread in the output.
1315 | const int out_w = thread_id % out_width;
1316 | const int out_h = (thread_id / out_width) % out_height;
1317 | const int bp_dir = (thread_id / out_width / out_height) % 2;
1318 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width;
1319 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height;
1320 | const int group_id = (thread_id / out_width / out_height / 2 /
1321 | filter_width / filter_height) % sampling_group;
1322 | const int out_b = (thread_id / out_width / out_height / 2 /
1323 | filter_width / filter_height) / sampling_group;
1324 |
1325 | // Decide if all input is valid, if yes, we can skip the boundary checks
1326 | // for each input.
1327 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height;
1328 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width;
1329 |
1330 | const int rotation_offset_temp =
1331 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) +
1332 | (f_h * filter_width + f_w)) * 2 * out_height * out_width +
1333 | (out_h * out_width + out_w);
1334 | const scalar_t rotation_ratio_h = ldg(
1335 | rotation_ratio + rotation_offset_temp);
1336 | const scalar_t rotation_ratio_w = ldg(
1337 | rotation_ratio + rotation_offset_temp + out_height * out_width);
1338 | scalar_t filter_h = f_h + rotation_ratio_h +
1339 | (scope_height - filter_height) / 2.0;
1340 | scalar_t filter_w = f_w + rotation_ratio_w +
1341 | (scope_width - filter_width) / 2.0;
1342 |
1343 | scalar_t coord_gradient = 0;
1344 | // Avoid repeated computation.
1345 | if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) {
1346 | for (int in_c = group_id; in_c < channel; in_c += sampling_group) {
1347 | const int input_offset_temp =
1348 | (out_b * channel + in_c) * (in_height * in_width) +
1349 | (in_row * in_width + in_col);
1350 | const int filter_offset_temp = in_c * scope_height * scope_width;
1351 | const scalar_t out_bp = ldg(
1352 | out_backprop +
1353 | (out_b * channel + in_c) * (out_height * out_width) +
1354 | (out_h * out_width + out_w));
1355 |
1356 | scalar_t partial_sum = ldg(input + input_offset_temp) * out_bp;
1357 | coord_gradient += planar_bilinear_backward_coord(
1358 | partial_sum,
1359 | filter + filter_offset_temp,
1360 | scope_height,
1361 | scope_width,
1362 | filter_h,
1363 | filter_w,
1364 | bp_dir);
1365 | }
1366 | }
1367 |
1368 | rotation_backprop[thread_id] = coord_gradient;
1369 | }
1370 | }
1371 |
1372 | void SampleDepthwiseConv2dBackwardRotation(
1373 | const at::Tensor out_grad,
1374 | const at::Tensor input,
1375 | const at::Tensor rotation_ratio,
1376 | const at::Tensor filter,
1377 | const SampleDepthwiseArgs args,
1378 | at::Tensor rotation_grad) {
1379 |
1380 | int num_kernels = args.batch *
1381 | args.sampling_group * args.filter_height * args.filter_width * 2 *
1382 | args.out_height * args.out_width;
1383 |
1384 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1385 | out_grad.type(), "SampleDepthwiseConv2dBackwardRotation_GPU", ([&] {
1386 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
1387 |       const scalar_t *input_ = input.data<scalar_t>();
1388 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
1389 |       const scalar_t *filter_ = filter.data<scalar_t>();
1390 |       scalar_t *rotation_grad_ = rotation_grad.data<scalar_t>();
1391 |
1392 |       SampleDepthwiseConv2dBackwardRotationKernel<scalar_t>
1393 |           <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
1394 | num_kernels,
1395 | out_grad_,
1396 | input_,
1397 | rotation_ratio_,
1398 | filter_,
1399 | args,
1400 | rotation_grad_);
1401 |
1402 | }));
1403 |
1404 | cudaError_t err = cudaGetLastError();
1405 | if (err != cudaSuccess) {
1406 | printf("error in SampleDepthwiseConv2dBackwardRotationKernel: %s\n", cudaGetErrorString(err));
1407 | }
1408 | }
1409 |
1410 | template <typename scalar_t>
1411 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardRotationKernel(
1412 | int n,
1413 | const scalar_t* out_backprop,
1414 | const scalar_t* input,
1415 | const scalar_t* offset,
1416 | const scalar_t* rotation_ratio,
1417 | const scalar_t* filter,
1418 | const SampleDepthwiseArgs args,
1419 | scalar_t* rotation_backprop) {
1420 |
1421 | const int channel = args.channel;
1422 | const int in_height = args.in_height;
1423 | const int in_width = args.in_width;
1424 | const int filter_height = args.filter_height;
1425 | const int filter_width = args.filter_width;
1426 | const int stride_height = args.stride_height;
1427 | const int stride_width = args.stride_width;
1428 | const int pad_height = args.pad_height;
1429 | const int pad_width = args.pad_width;
1430 | const int dilation_height = args.dilation_height;
1431 | const int dilation_width = args.dilation_width;
1432 | const int out_height = args.out_height;
1433 | const int out_width = args.out_width;
1434 |
1435 | const int scope_height = args.scope_height;
1436 | const int scope_width = args.scope_width;
1437 | const int sampling_group = args.sampling_group;
1438 |
1439 | CUDA_KERNEL_LOOP(thread_id, n) {
1440 | // Compute the indexes of this thread in the output.
1441 | const int out_w = thread_id % out_width;
1442 | const int out_h = (thread_id / out_width) % out_height;
1443 | const int bp_dir = (thread_id / out_width / out_height) % 2;
1444 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width;
1445 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height;
1446 | const int group_id = (thread_id / out_width / out_height / 2 /
1447 | filter_width / filter_height) % sampling_group;
1448 | const int out_b = (thread_id / out_width / out_height / 2 /
1449 | filter_width / filter_height) / sampling_group;
1450 |
1451 | // Decide if all input is valid, if yes, we can skip the boundary checks
1452 | // for each input.
1453 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height;
1454 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width;
1455 |
1456 | const int deformable_offset_temp =
1457 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 *
1458 | out_height * out_width + (out_h * out_width + out_w);
1459 | const int rotation_offset_temp =
1460 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) +
1461 | (f_h * filter_width + f_w)) * 2 * out_height * out_width +
1462 | (out_h * out_width + out_w);
1463 | const scalar_t input_h = in_row + ldg(
1464 | offset + deformable_offset_temp);
1465 | const scalar_t input_w = in_col + ldg(
1466 | offset + deformable_offset_temp + out_height * out_width);
1467 |
1468 | scalar_t coord_gradient = 0;
1469 | // Avoid repeated computation.
1470 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) {
1471 | const scalar_t rotation_ratio_h = ldg(
1472 | rotation_ratio + rotation_offset_temp);
1473 | const scalar_t rotation_ratio_w = ldg(
1474 | rotation_ratio + rotation_offset_temp + out_height * out_width);
1475 | scalar_t filter_h = f_h + rotation_ratio_h +
1476 | (scope_height - filter_height) / 2.0;
1477 | scalar_t filter_w = f_w + rotation_ratio_w +
1478 | (scope_width - filter_width) / 2.0;
1479 |
1480 | for (int in_c = group_id; in_c < channel; in_c += sampling_group) {
1481 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width);
1482 | const int filter_offset_temp = in_c * scope_height * scope_width;
1483 | const scalar_t out_bp = ldg(
1484 | out_backprop +
1485 | (out_b * channel + in_c) * (out_height * out_width) +
1486 | (out_h * out_width + out_w));
1487 |
1488 | scalar_t partial_sum = deformable_im2col_bilinear(
1489 | input + input_offset_temp,
1490 | in_height,
1491 | in_width,
1492 | input_h,
1493 | input_w) * out_bp;
1494 | coord_gradient += planar_bilinear_backward_coord(
1495 | partial_sum,
1496 | filter + filter_offset_temp,
1497 | scope_height,
1498 | scope_width,
1499 | filter_h,
1500 | filter_w,
1501 | bp_dir);
1502 | }
1503 | }
1504 |
1505 | rotation_backprop[thread_id] = coord_gradient;
1506 | }
1507 | }
1508 |
1509 | void DeformableSampleDepthwiseConv2dBackwardRotation(
1510 | const at::Tensor out_grad,
1511 | const at::Tensor input,
1512 | const at::Tensor offset,
1513 | const at::Tensor rotation_ratio,
1514 | const at::Tensor filter,
1515 | const SampleDepthwiseArgs args,
1516 | at::Tensor rotation_grad) {
1517 |
1518 | int num_kernels = args.batch *
1519 | args.sampling_group * args.filter_height * args.filter_width * 2 *
1520 | args.out_height * args.out_width;
1521 |
1522 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1523 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardRotation_GPU", ([&] {
1524 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
1525 |       const scalar_t *input_ = input.data<scalar_t>();
1526 |       const scalar_t *offset_ = offset.data<scalar_t>();
1527 |       const scalar_t *rotation_ratio_ = rotation_ratio.data<scalar_t>();
1528 |       const scalar_t *filter_ = filter.data<scalar_t>();
1529 |       scalar_t *rotation_grad_ = rotation_grad.data<scalar_t>();
1530 |
1531 |       DeformableSampleDepthwiseConv2dBackwardRotationKernel<scalar_t>
1532 |           <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
1533 | num_kernels,
1534 | out_grad_,
1535 | input_,
1536 | offset_,
1537 | rotation_ratio_,
1538 | filter_,
1539 | args,
1540 | rotation_grad_);
1541 |
1542 | }));
1543 |
1544 | cudaError_t err = cudaGetLastError();
1545 | if (err != cudaSuccess) {
1546 | printf("error in DeformableSampleDepthwiseConv2dBackwardRotationKernel: %s\n", cudaGetErrorString(err));
1547 | }
1548 | }
1549 |
--------------------------------------------------------------------------------
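
Before moving on to the generic n-d sampler, it helps to isolate the one formula the depthwise kernels above all share: every tap `(f_h, f_w)` of the effective kernel is not read from a fixed weight, but bilinearly interpolated out of a larger `scope_height x scope_width` filter plane, at a position shifted by the learned rotation ratio and re-centered inside the scope. The sketch below reproduces that lookup in plain PyTorch; the function name and argument layout are assumptions made for illustration, not part of this extension's API.

```python
import math
import torch

def sample_filter_weight(filter_plane, f_h, f_w, rotation_h, rotation_w,
                         filter_size=(3, 3)):
    """Bilinearly read one tap weight from a (scope_h, scope_w) filter plane."""
    scope_h, scope_w = filter_plane.shape
    # Mirrors `f_h + rotation_ratio_h + (scope_height - filter_height) / 2.0`
    # from the CUDA kernels: the tap is shifted by the rotation ratio and
    # centered inside the larger filter scope.
    y = f_h + rotation_h + (scope_h - filter_size[0]) / 2.0
    x = f_w + rotation_w + (scope_w - filter_size[1]) / 2.0

    y0, x0 = math.floor(y), math.floor(x)
    dy, dx = y - y0, x - x0

    weight = filter_plane.new_zeros(())
    for cy, wy in ((y0, 1.0 - dy), (y0 + 1, dy)):
        for cx, wx in ((x0, 1.0 - dx), (x0 + 1, dx)):
            # Corners that fall outside the scope are simply dropped in this sketch.
            if 0 <= cy < scope_h and 0 <= cx < scope_w:
                weight = weight + wy * wx * filter_plane[cy, cx]
    return weight
```

Summing `sample_filter_weight(...)` times the (possibly offset-shifted) input value over the filter taps is, per channel, essentially the accumulation the forward kernels perform; the backward kernels differentiate the same expression with respect to the input, the filter plane, the offsets, and the rotation ratios respectively.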
/deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <cmath>
4 | #include <vector>
5 |
6 | #include "nd_linear_sample_cuda.h"
7 |
8 | /*------------------------------------------------------------------------------------------------------------*/
9 |
10 | int nd_linear_sample_forward_cuda(at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor output) {
11 | data = data.contiguous();
12 | shape = shape.contiguous();
13 | coord = coord.contiguous();
14 |
15 | SampleArgs args;
16 | args.batch = coord.size(0);
17 | args.channel = data.size(0);
18 | args.spatial_dims = data.dim() - 1;
19 | args.prod_shape = data.stride(0);
20 |
21 | output = output.view({args.batch, args.channel});
22 |
23 | NdLinearSampleForward(data, shape, coord, args, output);
24 |
25 | return 1;
26 | }
27 |
28 |
29 | int nd_linear_sample_backward_data_cuda(at::Tensor out_grad, at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor in_grad) {
30 | out_grad = out_grad.contiguous();
31 | shape = shape.contiguous();
32 | coord = coord.contiguous();
33 |
34 | SampleArgs args;
35 | args.batch = coord.size(0);
36 | args.channel = data.size(0);
37 | args.spatial_dims = data.dim() - 1;
38 | args.prod_shape = data.stride(0);
39 |
40 | in_grad = in_grad.view_as(data).zero_();
41 |
42 | NdLinearSampleBackwardData(out_grad, shape, coord, args, in_grad);
43 |
44 | return 1;
45 | }
46 |
47 | int nd_linear_sample_backward_coord_cuda(at::Tensor out_grad, at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor coord_grad_c) {
48 | out_grad = out_grad.contiguous();
49 | data = data.contiguous();
50 | shape = shape.contiguous();
51 | coord = coord.contiguous();
52 |
53 | SampleArgs args;
54 | args.batch = coord.size(0);
55 | args.channel = data.size(0);
56 | args.spatial_dims = data.dim() - 1;
57 | args.prod_shape = data.stride(0);
58 |
59 | coord_grad_c = coord_grad_c.view({args.batch, args.spatial_dims, args.channel}).zero_();
60 |
61 | NdLinearSampleBackwardCoord(out_grad, data, shape, coord, args, coord_grad_c);
62 |
63 | return 1;
64 | }
65 |
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68 | m.def("nd_linear_sample_forward_cuda", &nd_linear_sample_forward_cuda,
69 | "nd_linear_sample forward (CUDA)");
70 | m.def("nd_linear_sample_backward_data_cuda", &nd_linear_sample_backward_data_cuda,
71 | "nd_linear_sample_backward_data (CUDA)");
72 | m.def("nd_linear_sample_backward_coord_cuda", &nd_linear_sample_backward_coord_cuda,
73 | "nd_linear_sample_backward_coord (CUDA)");
74 | }
75 |
--------------------------------------------------------------------------------
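
Read together with the header below, the wrappers above fix the shape contract of this extension: `data` is laid out as `(channel, d1, ..., dn)`, `shape` holds the n spatial extents, `coord` is a `(batch, n)` tensor of fractional coordinates, and the forward pass fills a `(batch, channel)` output. A minimal call sketch follows, assuming the compiled module is importable as `nd_linear_sample_cuda` from this package (the import path mirrors how `filter_sample_depthwise_cuda` is imported elsewhere and is an assumption here):

```python
import torch

from deformable_kernels.ops.deform_kernel import nd_linear_sample_cuda  # assumed import path

channel, d1, d2 = 8, 16, 16   # data laid out as (channel, d1, ..., dn)
batch = 32                    # number of points to sample

data = torch.randn(channel, d1, d2, device="cuda")
shape = torch.tensor([d1, d2], dtype=data.dtype, device="cuda")   # spatial extents
coord = torch.rand(batch, 2, device="cuda") * (d1 - 1)            # fractional coordinates
output = data.new_empty(batch, channel)

# output[b, c] = n-d linear interpolation of data[c] at coord[b],
# with zero padding outside the volume.
nd_linear_sample_cuda.nd_linear_sample_forward_cuda(data, shape, coord, output)
```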
/deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda.h:
--------------------------------------------------------------------------------
1 | struct SampleArgs {
2 | // Input layer dimensions
3 | int batch;
4 | int channel;
5 | int spatial_dims;
6 | int prod_shape;
7 | };
8 |
9 | void NdLinearSampleForward(const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor output);
10 |
11 | void NdLinearSampleBackwardData(const at::Tensor out_grad, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor in_grad);
12 |
13 | void NdLinearSampleBackwardCoord(const at::Tensor out_grad, const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor coord_grad_c);
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <THC/THCAtomics.cuh>
3 | #include <cstdio>
4 | #include <cmath>
5 | #include <algorithm>
6 |
7 | #include "nd_linear_sample_cuda.h"
8 |
9 | using namespace at;
10 |
11 | #define CUDA_KERNEL_LOOP(i, n) \
12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
13 | i += blockDim.x * gridDim.x)
14 |
15 | const int CUDA_NUM_THREADS = 1024;
16 | const int kMaxGridNum = 65535;
17 |
18 | inline int GET_BLOCKS(const int N) {
19 | return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
20 | }
21 |
22 | #if !defined(_MSC_VER)
23 | #define CUDA_UNROLL _Pragma("unroll")
24 | #define CUDA_NOUNROLL _Pragma("nounroll")
25 | #else
26 | #define CUDA_UNROLL
27 | #define CUDA_NOUNROLL
28 | #endif
29 |
30 | template <typename DType>
31 | __device__ inline DType ldg(const DType* address) {
32 | #if __CUDA_ARCH__ >= 350
33 | return __ldg(address);
34 | #else
35 | return *address;
36 | #endif
37 | }
38 |
39 | /*------------------------------------------------------------------------------------------------------------*/
40 |
41 |
42 | // data: d1, d2, ..., dn
43 | // shape: s1, s2, ..., sn
44 | // coord: x1, x2, ..., xn
45 | template <typename scalar_t>
46 | inline __device__ scalar_t nd_linear(const scalar_t *data, const scalar_t* shape, const scalar_t* coord, const int dims) {
47 | for (int d = 0; d < dims; d++) {
48 | scalar_t x = ldg(coord + d);
49 |     int s = static_cast<int>(ldg(shape + d));
50 | if (x <= -1 || x >= s) {
51 | return 0;
52 | }
53 | }
54 |
55 | const uint corners = 1 << dims;
56 | scalar_t val = 0;
57 | for (uint i = 0; i < corners; i++) {
58 | int data_offset = 0;
59 | scalar_t data_weight = 1;
60 | bool out_of_scope = false;
61 |
62 | for (uint d = 0; d < dims; d++) {
63 | scalar_t x = ldg(coord + d);
64 |       int s = static_cast<int>(ldg(shape + d));
65 |
66 | data_offset *= s;
67 | uint offset = (i >> d) & 1;
68 | int x_low = floor(x);
69 | scalar_t w = x - x_low;
70 |
71 | if (offset == 0 && x_low >= 0) {
72 | data_offset += x_low;
73 | data_weight *= 1-w;
74 | }
75 | else if (offset == 1 && x_low + 1 <= s - 1) {
76 | data_offset += x_low + 1;
77 | data_weight *= w;
78 | }
79 | else {
80 | out_of_scope = true;
81 | break;
82 | }
83 | }
84 | if (!out_of_scope)
85 | val += data_weight * ldg(data + data_offset);
86 | }
87 | return val;
88 | }
89 |
90 | template <typename scalar_t>
91 | inline __device__ void nd_linear_backward_data(const scalar_t top_gradient, const scalar_t* shape, const scalar_t* coord, const int dims, scalar_t* data_gradient) {
92 | for (int d = 0; d < dims; d++) {
93 | scalar_t x = ldg(coord + d);
94 |     int s = static_cast<int>(ldg(shape + d));
95 | if (x <= -1 || x >= s) {
96 | return;
97 | }
98 | }
99 |
100 | const uint corners = 1 << dims;
101 | for (uint i = 0; i < corners; i++) {
102 | int data_offset = 0;
103 | scalar_t data_weight = 1;
104 | bool out_of_scope = false;
105 |
106 | for (uint d = 0; d < dims; d++) {
107 | scalar_t x = ldg(coord + d);
108 |       int s = static_cast<int>(ldg(shape + d));
109 |
110 | data_offset *= s;
111 | uint offset = (i >> d) & 1;
112 | int x_low = floor(x);
113 | scalar_t w = x - x_low;
114 |
115 | if (offset == 0 && x_low >= 0) {
116 | data_offset += x_low;
117 | data_weight *= 1-w;
118 | }
119 | else if (offset == 1 && x_low + 1 <= s - 1) {
120 | data_offset += x_low + 1;
121 | data_weight *= w;
122 | }
123 | else {
124 | out_of_scope = true;
125 | break;
126 | }
127 | }
128 | if (!out_of_scope)
129 | atomicAdd(data_gradient + data_offset, data_weight * top_gradient);
130 | }
131 | }
132 |
133 | template <typename scalar_t>
134 | inline __device__ scalar_t nd_linear_backward_coord(const scalar_t top_gradient, const scalar_t *data, const scalar_t* shape, const scalar_t* coord, const int dims, const int gdim) {
135 | for (int d = 0; d < dims; d++) {
136 | scalar_t x = ldg(coord + d);
137 |     int s = static_cast<int>(ldg(shape + d));
138 | if (x <= -1 || x >= s) {
139 | return 0;
140 | }
141 | }
142 |
143 | scalar_t grad = 0;
144 | const uint corners = 1 << dims;
145 | for (uint i = 0; i < corners; i++) {
146 | int data_offset = 0;
147 | scalar_t data_weight = 1;
148 | bool out_of_scope = false;
149 |
150 | for (uint d = 0; d < dims; d++) {
151 | scalar_t x = ldg(coord + d);
152 |       int s = static_cast<int>(ldg(shape + d));
153 |
154 | data_offset *= s;
155 | uint offset = (i >> d) & 1;
156 | int x_low = floor(x);
157 | scalar_t w = x - x_low;
158 |
159 | if (offset == 0 && x_low >= 0) {
160 | data_offset += x_low;
161 | data_weight *= (d == gdim) ? scalar_t(-1) : 1-w;
162 | }
163 | else if (offset == 1 && x_low + 1 <= s - 1) {
164 | data_offset += x_low + 1;
165 | data_weight *= (d == gdim) ? scalar_t(1) : w;
166 | }
167 | else {
168 | out_of_scope = true;
169 | break;
170 | }
171 | }
172 | if (!out_of_scope) {
173 | grad += top_gradient * ldg(data + data_offset) * data_weight;
174 | }
175 | }
176 | return grad;
177 | }
178 |
179 |
180 | /*------------------------------------------------------------------------------------------------------------*/
181 |
182 |
183 | template <typename scalar_t>
184 | __global__ void NdLinearSampleForwardKernel(int n,
185 | const scalar_t* data,
186 | const scalar_t *shape,
187 | const scalar_t* coord,
188 | const SampleArgs args,
189 | scalar_t* output) {
190 |
191 | //const int batch = args.batch;
192 | const int channel = args.channel;
193 | const int spatial_dims = args.spatial_dims;
194 | const int prod_shape = args.prod_shape;
195 |
196 | CUDA_KERNEL_LOOP(thread_id, n) {
197 | const int out_c = thread_id % channel;
198 | const int out_n = thread_id / channel;
199 | output[thread_id] = nd_linear(data + out_c * prod_shape, shape, coord + out_n * spatial_dims, spatial_dims);
200 | }
201 | }
202 |
203 |
204 | void NdLinearSampleForward(const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor output) {
205 | int num_kernels = args.batch * args.channel;
206 |
207 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
208 | data.type(), "NdLinearSampleForward_GPU", ([&] {
209 |       const scalar_t *data_ = data.data<scalar_t>();
210 |       const scalar_t *shape_ = shape.data<scalar_t>();
211 |       const scalar_t *coord_ = coord.data<scalar_t>();
212 |       scalar_t *output_ = output.data<scalar_t>();
213 |
214 |       NdLinearSampleForwardKernel<scalar_t><<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
215 | num_kernels, data_, shape_, coord_, args, output_);
216 | }));
217 |
218 | cudaError_t err = cudaGetLastError();
219 | if (err != cudaSuccess) {
220 | printf("error in NdLinearSampleForwardKernel: %s\n", cudaGetErrorString(err));
221 | }
222 | }
223 |
224 |
225 | template <typename scalar_t>
226 | __global__ void NdLinearSampleBackwardDataKernel(int n,
227 | const scalar_t* out_grad,
228 | const scalar_t *shape,
229 | const scalar_t* coord,
230 | const SampleArgs args,
231 | scalar_t* in_grad) {
232 |
233 | //const int batch = args.batch;
234 | const int channel = args.channel;
235 | const int spatial_dims = args.spatial_dims;
236 | const int prod_shape = args.prod_shape;
237 |
238 | CUDA_KERNEL_LOOP(thread_id, n) {
239 | const int out_c = thread_id % channel;
240 | const int out_n = thread_id / channel;
241 | nd_linear_backward_data(ldg(out_grad + thread_id), shape, coord + out_n * spatial_dims, spatial_dims, in_grad + out_c * prod_shape);
242 | }
243 | }
244 |
245 |
246 | void NdLinearSampleBackwardData(const at::Tensor out_grad, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor in_grad) {
247 | int num_kernels = args.batch * args.channel;
248 |
249 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
250 | out_grad.type(), "NdLinearSampleBackwardData_GPU", ([&] {
251 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
252 |       const scalar_t *shape_ = shape.data<scalar_t>();
253 |       const scalar_t *coord_ = coord.data<scalar_t>();
254 |       scalar_t *in_grad_ = in_grad.data<scalar_t>();
255 |
256 |       NdLinearSampleBackwardDataKernel<scalar_t><<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
257 | num_kernels, out_grad_, shape_, coord_, args, in_grad_);
258 | }));
259 |
260 | cudaError_t err = cudaGetLastError();
261 | if (err != cudaSuccess) {
262 | printf("error in NdLinearSampleBackwardDataKernel: %s\n", cudaGetErrorString(err));
263 | }
264 | }
265 |
266 |
267 | template <typename scalar_t>
268 | __global__ void NdLinearSampleBackwardCoordKernel(int n,
269 | const scalar_t* out_grad,
270 | const scalar_t* data,
271 | const scalar_t *shape,
272 | const scalar_t* coord,
273 | const SampleArgs args,
274 | scalar_t* coord_grad_c) {
275 |
276 | //const int batch = args.batch;
277 | const int channel = args.channel;
278 | const int spatial_dims = args.spatial_dims;
279 | const int prod_shape = args.prod_shape;
280 |
281 | CUDA_KERNEL_LOOP(thread_id, n) {
282 | const int out_c = thread_id % channel;
283 | const int out_d = (thread_id / channel) % spatial_dims;
284 | const int out_n = (thread_id / channel) / spatial_dims;
285 | coord_grad_c[thread_id] = nd_linear_backward_coord(ldg(out_grad + out_n * channel + out_c), data + out_c * prod_shape, shape, coord + out_n * spatial_dims, spatial_dims, out_d);
286 | }
287 | }
288 |
289 |
290 | void NdLinearSampleBackwardCoord(const at::Tensor out_grad, const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor coord_grad_c) {
291 | int num_kernels = args.batch * args.spatial_dims * args.channel;
292 |
293 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
294 | data.type(), "NdLinearSampleBackwardCoord_GPU", ([&] {
295 |       const scalar_t *out_grad_ = out_grad.data<scalar_t>();
296 |       const scalar_t *data_ = data.data<scalar_t>();
297 |       const scalar_t *shape_ = shape.data<scalar_t>();
298 |       const scalar_t *coord_ = coord.data<scalar_t>();
299 |       scalar_t *coord_grad_c_ = coord_grad_c.data<scalar_t>();
300 |
301 |       NdLinearSampleBackwardCoordKernel<scalar_t><<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
302 | num_kernels, out_grad_, data_, shape_, coord_, args, coord_grad_c_);
303 | }));
304 |
305 | cudaError_t err = cudaGetLastError();
306 | if (err != cudaSuccess) {
307 | printf("error in NdLinearSampleBackwardCoordKernel: %s\n", cudaGetErrorString(err));
308 | }
309 | }
310 |
--------------------------------------------------------------------------------
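The launchers above flatten their work into a single linear thread index. Below is a minimal Python sketch (the names are illustrative, not part of the extension) of the index decomposition used by NdLinearSampleBackwardCoordKernel; it checks that batch * spatial_dims * channel threads cover every (sample, dimension, channel) triple exactly once.

def decompose(thread_id, channel, spatial_dims):
    # Mirrors the kernel: out_c = thread_id % channel;
    # out_d = (thread_id / channel) % spatial_dims; out_n = (thread_id / channel) / spatial_dims.
    out_c = thread_id % channel
    out_d = (thread_id // channel) % spatial_dims
    out_n = (thread_id // channel) // spatial_dims
    return out_n, out_d, out_c

batch, spatial_dims, channel = 2, 3, 4
triples = {decompose(t, channel, spatial_dims)
           for t in range(batch * spatial_dims * channel)}
assert len(triples) == batch * spatial_dims * channel  # no collisions, full coverage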
/deformable_kernels/ops/deform_kernel/functions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | from .filter_sample_depthwise import (
11 | sample_depthwise,
12 | deform_sample_depthwise,
13 | )
14 | from .nd_linear_sample import nd_linear_sample
15 |
16 | __all__ = [
17 | 'sample_depthwise',
18 | 'deform_sample_depthwise',
19 | 'nd_linear_sample',
20 | ]
21 |
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/functions/filter_sample_depthwise.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : filter_sample_depthwise.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import torch
11 | from torch.autograd import Function
12 | from apex import amp
13 |
14 | from .. import filter_sample_depthwise_cuda
15 |
16 | __all__ = ['sample_depthwise', 'deform_sample_depthwise']
17 |
18 |
19 | class SampleDepthwiseFunction(Function):
20 | @staticmethod
21 | def forward(
22 | ctx,
23 | input,
24 | rotation,
25 | weight,
26 | kernel_size=(3, 3),
27 | stride=(1, 1),
28 | padding=(1, 1),
29 | dilation=(1, 1),
30 | rotation_groups=1,
31 | ):
32 | ctx.kernel_size = kernel_size
33 | ctx.stride = stride
34 | ctx.padding = padding
35 | ctx.dilation = dilation
36 | ctx.rotation_groups = rotation_groups
37 |
38 | ctx.save_for_backward(input, rotation, weight)
39 | output = input.new_empty(
40 | SampleDepthwiseFunction._output_size(
41 | input, kernel_size, stride, padding, dilation))
42 |
43 | if not input.is_cuda:
44 | raise NotImplementedError
45 | else:
46 | filter_sample_depthwise_cuda.sample_depthwise_forward_cuda(
47 | input, rotation, weight, output,
48 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
49 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
50 | weight.size(2), weight.size(3), rotation_groups)
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, rotation, weight = ctx.saved_tensors
56 |
57 | grad_input = grad_rotation = grad_weight = None
58 |
59 | if not grad_output.is_cuda:
60 | raise NotImplementedError
61 | else:
62 |
63 | if ctx.needs_input_grad[0]:
64 | grad_input = torch.zeros_like(input)
65 | filter_sample_depthwise_cuda.sample_depthwise_backward_data_cuda(
66 | grad_output, input, rotation, weight, grad_input,
67 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
68 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
69 | weight.size(2), weight.size(3), ctx.rotation_groups)
70 |
71 | if ctx.needs_input_grad[1]:
72 | grad_rotation = torch.zeros_like(rotation)
73 | filter_sample_depthwise_cuda.sample_depthwise_backward_rotation_cuda(
74 | grad_output, input, rotation, weight, grad_rotation,
75 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
76 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
77 | weight.size(2), weight.size(3), ctx.rotation_groups)
78 |
79 | if ctx.needs_input_grad[2]:
80 | grad_weight = torch.zeros_like(weight)
81 | filter_sample_depthwise_cuda.sample_depthwise_backward_filter_cuda(
82 | grad_output, input, rotation, weight, grad_weight,
83 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
84 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
85 | weight.size(2), weight.size(3), ctx.rotation_groups)
86 |
87 |         return (grad_input, grad_rotation, grad_weight,) + (None,) * 5  # one gradient slot per forward argument (3 tensors + 5 config args)
88 |
89 | @staticmethod
90 | def _output_size(input, kernel_size, stride, padding, dilation):
91 | output_size = (input.size(0), input.size(1))
92 | for d in range(input.dim() - 2):
93 | in_size = input.size(d + 2)
94 | pad = padding[d]
95 | kernel = dilation[d] * (kernel_size[d] - 1) + 1
96 | stride_ = stride[d]
97 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
98 | if not all(map(lambda s: s > 0, output_size)):
99 | raise ValueError(
100 | "convolution input is too small (output would be {})".format(
101 | 'x'.join(map(str, output_size))))
102 | return output_size
103 |
104 |
105 | class DeformableSampleDepthwiseFunction(Function):
106 | @staticmethod
107 | def forward(
108 | ctx,
109 | input,
110 | offset,
111 | rotation,
112 | weight,
113 | kernel_size=(3, 3),
114 | stride=(1, 1),
115 | padding=(1, 1),
116 | dilation=(1, 1),
117 | rotation_groups=1,
118 | ):
119 | ctx.kernel_size = kernel_size
120 | ctx.stride = stride
121 | ctx.padding = padding
122 | ctx.dilation = dilation
123 | ctx.rotation_groups = rotation_groups
124 |
125 | ctx.save_for_backward(input, offset, rotation, weight)
126 | output = input.new_empty(
127 | DeformableSampleDepthwiseFunction._output_size(
128 | input, kernel_size, stride, padding, dilation))
129 |
130 | if not input.is_cuda:
131 | raise NotImplementedError
132 | else:
133 | filter_sample_depthwise_cuda. \
134 | deformable_sample_depthwise_forward_cuda(
135 | input, offset, rotation, weight, output,
136 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0],
137 | ctx.stride[1], ctx.padding[0], ctx.padding[1],
138 | ctx.dilation[0], ctx.dilation[1], weight.size(2),
139 | weight.size(3), rotation_groups)
140 | return output
141 |
142 | @staticmethod
143 | def backward(ctx, grad_output):
144 | input, offset, rotation, weight = ctx.saved_tensors
145 |
146 | grad_input = grad_offset = grad_rotation = grad_weight = None
147 |
148 | if not grad_output.is_cuda:
149 | raise NotImplementedError
150 | else:
151 | if ctx.needs_input_grad[0]:
152 | grad_input = torch.zeros_like(input)
153 | filter_sample_depthwise_cuda. \
154 | deformable_sample_depthwise_backward_data_cuda(
155 | grad_output, input, offset, rotation, weight,
156 | grad_input, ctx.kernel_size[0], ctx.kernel_size[1],
157 | ctx.stride[0], ctx.stride[1], ctx.padding[0],
158 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
159 | weight.size(2), weight.size(3), ctx.rotation_groups)
160 |
161 | if ctx.needs_input_grad[1]:
162 | grad_offset = torch.zeros_like(offset)
163 | filter_sample_depthwise_cuda. \
164 | deformable_sample_depthwise_backward_offset_cuda(
165 | grad_output, input, offset, rotation, weight,
166 | grad_offset, ctx.kernel_size[0], ctx.kernel_size[1],
167 | ctx.stride[0], ctx.stride[1], ctx.padding[0],
168 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
169 | weight.size(2), weight.size(3), ctx.rotation_groups)
170 |
171 | if ctx.needs_input_grad[2]:
172 | grad_rotation = torch.zeros_like(rotation)
173 | filter_sample_depthwise_cuda. \
174 | deformable_sample_depthwise_backward_rotation_cuda(
175 | grad_output, input, offset, rotation, weight,
176 | grad_rotation, ctx.kernel_size[0], ctx.kernel_size[1],
177 | ctx.stride[0], ctx.stride[1], ctx.padding[0],
178 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
179 | weight.size(2), weight.size(3), ctx.rotation_groups)
180 |
181 | if ctx.needs_input_grad[3]:
182 | grad_weight = torch.zeros_like(weight)
183 | filter_sample_depthwise_cuda. \
184 | deformable_sample_depthwise_backward_filter_cuda(
185 | grad_output, input, offset, rotation, weight,
186 | grad_weight, ctx.kernel_size[0], ctx.kernel_size[1],
187 | ctx.stride[0], ctx.stride[1], ctx.padding[0],
188 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1],
189 | weight.size(2), weight.size(3), ctx.rotation_groups)
190 |
191 |         return (grad_input, grad_offset, grad_rotation, grad_weight,) + (None,) * 5  # one gradient slot per forward argument (4 tensors + 5 config args)
192 |
193 | @staticmethod
194 | def _output_size(input, kernel_size, stride, padding, dilation):
195 | output_size = (input.size(0), input.size(1))
196 | for d in range(input.dim() - 2):
197 | in_size = input.size(d + 2)
198 | pad = padding[d]
199 | kernel = dilation[d] * (kernel_size[d] - 1) + 1
200 | stride_ = stride[d]
201 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
202 | if not all(map(lambda s: s > 0, output_size)):
203 | raise ValueError(
204 | "convolution input is too small (output would be {})".format(
205 | 'x'.join(map(str, output_size))))
206 | return output_size
207 |
208 |
209 | # Register as fp32-only functions: under apex amp, half-precision inputs are cast to float32 before these ops run.
210 | sample_depthwise = amp.float_function(SampleDepthwiseFunction.apply)
211 | deform_sample_depthwise = amp.float_function(DeformableSampleDepthwiseFunction.apply)
212 |
--------------------------------------------------------------------------------
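A minimal usage sketch of the two functions above, assuming the filter_sample_depthwise_cuda extension has been built, apex is installed, and a CUDA device is available. The shapes mirror the defaults used by the modules in modules/filter_sample_depthwise.py (all-zero rotation/offset maps); the concrete sizes are illustrative only.

import torch
from deformable_kernels.ops.deform_kernel.functions import (
    sample_depthwise,
    deform_sample_depthwise,
)

N, C, H, W = 2, 8, 16, 16          # batch, channels, spatial size
kh, kw = 3, 3                      # sampling kernel size
scope = (4, 4)                     # per-channel filter scope
rotation_groups = 1

x = torch.randn(N, C, H, W, device='cuda')
weight = torch.randn(C, 1, *scope, device='cuda', requires_grad=True)
# With stride=1, padding=1, dilation=1 and a 3x3 kernel, the output stays H x W.
rotation = torch.zeros(N, rotation_groups * kh * kw * 2, H, W, device='cuda')
offset = torch.zeros(N, kh * kw * 2, H, W, device='cuda')

out = sample_depthwise(x, rotation, weight,
                       (kh, kw), (1, 1), (1, 1), (1, 1), rotation_groups)
out_deform = deform_sample_depthwise(x, offset, rotation, weight,
                                     (kh, kw), (1, 1), (1, 1), (1, 1),
                                     rotation_groups)
print(out.shape, out_deform.shape)  # torch.Size([2, 8, 16, 16]) for both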
/deformable_kernels/ops/deform_kernel/functions/nd_linear_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : nd_linear_sample.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import torch
11 | from torch.autograd import Function
12 |
13 | from .. import nd_linear_sample_cuda
14 |
15 | __all__ = ['nd_linear_sample']
16 |
17 |
18 | class NdLinearSampleFunction(Function):
19 | @staticmethod
20 | def forward(ctx, input, coord):
21 | ctx.save_for_backward(input, coord)
22 | shape = torch.tensor(input.shape[1:], dtype=input.dtype, device=input.device)
23 | output = input.new_empty(NdLinearSampleFunction._output_size(input, coord))
24 |
25 | if not input.is_cuda:
26 | raise NotImplementedError
27 | else:
28 | nd_linear_sample_cuda.nd_linear_sample_forward_cuda(
29 | input, shape, coord, output
30 | )
31 | return output
32 |
33 | @staticmethod
34 | def backward(ctx, grad_output):
35 | input, coord = ctx.saved_tensors
36 | shape = torch.tensor(input.shape[1:], dtype=input.dtype, device=input.device)
37 | grad_input = grad_coord = None
38 |
39 | if not grad_output.is_cuda:
40 | raise NotImplementedError
41 | else:
42 | if ctx.needs_input_grad[0]:
43 |                 grad_input = torch.zeros_like(input)  # must start from zeros: the kernel only scatters into sampled locations
44 | nd_linear_sample_cuda.nd_linear_sample_backward_data_cuda(
45 | grad_output, input, shape, coord, grad_input
46 | )
47 |
48 | if ctx.needs_input_grad[1]:
49 | grad_coord_c = coord.new_empty(
50 | coord.size(0), coord.size(1), input.size(0)
51 | )
52 | nd_linear_sample_cuda.nd_linear_sample_backward_coord_cuda(
53 | grad_output, input, shape, coord, grad_coord_c
54 | )
55 | grad_coord = grad_coord_c.sum(2)
56 |
57 | return (grad_input, grad_coord)
58 |
59 | @staticmethod
60 | def _output_size(input, coord):
61 | output_size = (coord.size(0), input.size(0))
62 | return output_size
63 |
64 |
65 | nd_linear_sample = NdLinearSampleFunction.apply
66 |
--------------------------------------------------------------------------------
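A minimal smoke-test sketch for nd_linear_sample, assuming the nd_linear_sample_cuda extension is built and CUDA is available. The shapes follow NdLinearSampleFunction above: data is (channel, *spatial), coord is (num_points, spatial_dims), and the output is (num_points, channel); the concrete sizes are illustrative.

import torch
from deformable_kernels.ops.deform_kernel.functions import nd_linear_sample

data = torch.randn(4, 5, 6, device='cuda', requires_grad=True)      # (C, H, W)
coord = torch.rand(10, 2, device='cuda', requires_grad=True) * 4.0  # points inside the 5x6 grid

out = nd_linear_sample(data, coord)  # (10, 4): one interpolated value per point and channel
out.sum().backward()                 # exercises both the backward-data and backward-coord kernels
print(out.shape, data.grad.shape, coord.grad.shape)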
/deformable_kernels/ops/deform_kernel/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | from .filter_sample_depthwise import (
11 | SampleDepthwise,
12 | DeformableSampleDepthwise,
13 | )
14 |
15 | __all__ = [
16 | 'SampleDepthwise',
17 | 'DeformableSampleDepthwise',
18 | ]
19 |
--------------------------------------------------------------------------------
/deformable_kernels/ops/deform_kernel/modules/filter_sample_depthwise.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : filter_sample_depthwise.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import math
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.nn.modules.utils import _pair
15 |
16 | from ..functions.filter_sample_depthwise import (
17 | sample_depthwise,
18 | deform_sample_depthwise,
19 | )
20 |
21 | __all__ = [
22 | 'SampleDepthwise',
23 | 'DeformableSampleDepthwise',
24 | ]
25 |
26 |
27 | class SampleDepthwise(nn.Module):
28 | def __init__(self,
29 | scope_size,
30 | in_channels,
31 | out_channels,
32 | kernel_size=3,
33 | stride=1,
34 | padding=1,
35 | dilation=1,
36 | groups=1,
37 | rotation_groups=1,
38 | bias=False):
39 | super(SampleDepthwise, self).__init__()
40 | self.in_channels = in_channels
41 | assert in_channels == out_channels and groups == in_channels
42 | assert in_channels % rotation_groups == 0 and \
43 | out_channels % rotation_groups == 0
44 | self.scope_size = scope_size
45 | self.kernel_size = _pair(kernel_size)
46 | self.stride = _pair(stride)
47 | self.padding = _pair(padding)
48 | self.dilation = _pair(dilation)
49 | self.rotation_groups = rotation_groups
50 | self.bias = bias
51 |
52 | self.weight = nn.Parameter(torch.Tensor(self.in_channels, 1, *self.scope_size))
53 | if not self.bias:
54 | self.bias = None
55 | else:
56 | self.bias = nn.Parameter(torch.Tensor(self.in_channels))
57 | self.reset_parameters()
58 |
59 | def reset_parameters(self):
60 | n = 1
61 | for k in self.kernel_size:
62 | n *= k
63 | stdv = 1. / math.sqrt(n)
64 | self.weight.data.uniform_(-stdv, stdv)
65 | if self.bias is not None:
66 | self.bias.data.zero_()
67 |
68 | def forward(self, input, rotation=None):
69 | if rotation is None:
70 | output_size = self._output_size(input, self.weight)
71 | rotation = input.new_zeros(
72 | input.size(0),
73 | self.rotation_groups * self.kernel_size[0] *
74 | self.kernel_size[1] * 2,
75 | output_size[2],
76 | output_size[3])
77 | out = sample_depthwise(
78 | input, rotation, self.weight, self.kernel_size, self.stride,
79 | self.padding, self.dilation, self.rotation_groups)
80 | if self.bias is not None:
81 | out += self.bias.view(1, self.in_channels, 1, 1)
82 | return out
83 |
84 | def _output_size(self, input, weight):
85 | channels = weight.size(0)
86 |
87 | output_size = (input.size(0), channels)
88 | for d in range(input.dim() - 2):
89 | in_size = input.size(d + 2)
90 | pad = self.padding[d]
91 | kernel = self.dilation[d] * (self.kernel_size[d] - 1) + 1
92 | stride = self.stride[d]
93 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
94 | if not all(map(lambda s: s > 0, output_size)):
95 | raise ValueError(
96 | "convolution input is too small (output would be {})".format(
97 | 'x'.join(map(str, output_size))))
98 | return output_size
99 |
100 | def extra_repr(self):
101 | s = ('scope_size={scope_size}, in_channels={in_channels}, '
102 | 'kernel_size={kernel_size}, stride={stride}')
103 | if self.padding != (0,) * len(self.padding):
104 | s += ', padding={padding}'
105 | if self.dilation != (1,) * len(self.dilation):
106 | s += ', dilation={dilation}'
107 | if self.rotation_groups != 1:
108 | s += ', rotation_groups={rotation_groups}'
109 | if self.bias is None:
110 | s += ', bias=False'
111 | return s.format(**self.__dict__)
112 |
113 |
114 | class DeformableSampleDepthwise(nn.Module):
115 | def __init__(self,
116 | scope_size,
117 | in_channels,
118 | out_channels,
119 | kernel_size=3,
120 | stride=1,
121 | padding=1,
122 | dilation=1,
123 | groups=1,
124 | rotation_groups=1,
125 | bias=False):
126 | super(DeformableSampleDepthwise, self).__init__()
127 | self.in_channels = in_channels
128 | assert in_channels == out_channels and groups == in_channels
129 | assert in_channels % rotation_groups == 0 and \
130 | out_channels % rotation_groups == 0
131 | self.scope_size = scope_size
132 | self.kernel_size = _pair(kernel_size)
133 | self.stride = _pair(stride)
134 | self.padding = _pair(padding)
135 | self.dilation = _pair(dilation)
136 | self.rotation_groups = rotation_groups
137 | self.bias = bias
138 |
139 | self.weight = nn.Parameter(torch.Tensor(self.in_channels, 1, *self.scope_size))
140 | if not self.bias:
141 | self.bias = None
142 | else:
143 | self.bias = nn.Parameter(torch.Tensor(self.in_channels))
144 | self.reset_parameters()
145 |
146 | def reset_parameters(self):
147 | n = 1
148 | for k in self.kernel_size:
149 | n *= k
150 | stdv = 1. / math.sqrt(n)
151 | self.weight.data.uniform_(-stdv, stdv)
152 | if self.bias is not None:
153 | self.bias.data.zero_()
154 |
155 | def forward(self, input, offset=None, rotation=None):
156 | if offset is None:
157 | output_size = self._output_size(input, self.weight)
158 | offset = input.new_zeros(
159 | input.size(0),
160 | self.kernel_size[0] * self.kernel_size[1] * 2,
161 | output_size[2],
162 | output_size[3])
163 | if rotation is None:
164 | output_size = self._output_size(input, self.weight)
165 | rotation = input.new_zeros(
166 | input.size(0),
167 | self.rotation_groups * self.kernel_size[0] *
168 | self.kernel_size[1] * 2,
169 | output_size[2],
170 | output_size[3])
171 |
172 | out = deform_sample_depthwise(
173 | input, offset, rotation, self.weight, self.kernel_size,
174 | self.stride, self.padding, self.dilation, self.rotation_groups)
175 | if self.bias is not None:
176 | out += self.bias.view(1, self.in_channels, 1, 1)
177 | return out
178 |
179 | def _output_size(self, input, weight):
180 | channels = weight.size(0)
181 |
182 | output_size = (input.size(0), channels)
183 | for d in range(input.dim() - 2):
184 | in_size = input.size(d + 2)
185 | pad = self.padding[d]
186 | kernel = self.dilation[d] * (self.kernel_size[d] - 1) + 1
187 | stride = self.stride[d]
188 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
189 | if not all(map(lambda s: s > 0, output_size)):
190 | raise ValueError(
191 | "convolution input is too small (output would be {})".format(
192 | 'x'.join(map(str, output_size))))
193 | return output_size
194 |
195 | def extra_repr(self):
196 | s = ('scope_size={scope_size}, in_channels={in_channels}, '
197 | 'kernel_size={kernel_size}, stride={stride}')
198 | if self.padding != (0,) * len(self.padding):
199 | s += ', padding={padding}'
200 | if self.dilation != (1,) * len(self.dilation):
201 | s += ', dilation={dilation}'
202 | if self.rotation_groups != 1:
203 | s += ', rotation_groups={rotation_groups}'
204 | if self.bias is None:
205 | s += ', bias=False'
206 | return s.format(**self.__dict__)
207 |
--------------------------------------------------------------------------------
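A minimal sketch of the two modules above, assuming a CUDA device and the compiled extension. Note that in_channels, out_channels and groups must all match (depthwise), and scope_size is passed as a tuple because it is unpacked into the weight shape; the sizes below are illustrative.

import torch
from deformable_kernels.ops.deform_kernel.modules import (
    SampleDepthwise,
    DeformableSampleDepthwise,
)

conv = SampleDepthwise((4, 4), 8, 8, kernel_size=3, stride=1,
                       padding=1, groups=8).cuda()
dconv = DeformableSampleDepthwise((4, 4), 8, 8, kernel_size=3, stride=1,
                                  padding=1, groups=8).cuda()

x = torch.randn(2, 8, 16, 16, device='cuda')
y = conv(x)    # rotation defaults to an all-zero map
z = dconv(x)   # offset and rotation both default to all-zero maps
print(y.shape, z.shape)  # torch.Size([2, 8, 16, 16]) for both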
/deformable_kernels/ops/deform_kernel/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : setup.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | from setuptools import setup
11 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
12 |
13 | setup(
14 | name='filter_sample_depthwise',
15 | ext_modules=[
16 | CUDAExtension(
17 | 'filter_sample_depthwise_cuda',
18 | [
19 | 'csrc/filter_sample_depthwise_cuda.cpp',
20 | 'csrc/filter_sample_depthwise_cuda_kernel.cu',
21 | ]
22 | ),
23 | ],
24 | cmdclass={'build_ext': BuildExtension}
25 | )
26 |
27 | setup(
28 | name="nd_linear_sample",
29 | ext_modules=[
30 | CUDAExtension(
31 | "nd_linear_sample_cuda",
32 | [
33 | "csrc/nd_linear_sample_cuda.cpp",
34 | "csrc/nd_linear_sample_cuda_kernel.cu",
35 | ],
36 | )
37 | ],
38 | cmdclass={"build_ext": BuildExtension},
39 | )
40 |
--------------------------------------------------------------------------------
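Both extensions are built from this directory; a common invocation is python setup.py build_ext --inplace, which runs the two setup() calls in sequence and drops the compiled modules next to the package. A quick import check after building, as a sketch:

# Run after the in-place build; the .so files should sit inside
# deformable_kernels/ops/deform_kernel, where the package imports them from.
import filter_sample_depthwise_cuda
import nd_linear_sample_cuda

print(filter_sample_depthwise_cuda.__file__)
print(nd_linear_sample_cuda.__file__)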
/environment.yml:
--------------------------------------------------------------------------------
1 | name: deformable-kernels
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.6
7 | - pytorch=1.3.1
8 | - torchvision=0.4.2
9 | - cudatoolkit=10
10 |
--------------------------------------------------------------------------------
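Note that apex, which functions/filter_sample_depthwise.py imports for amp.float_function, is not listed in the environment above and has to be installed separately (for example from the NVIDIA/apex repository). A small sanity-check sketch for the resulting environment:

import torch

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
try:
    from apex import amp  # noqa: F401
    print('apex available')
except ImportError:
    print('apex missing: install it before using the filter sampling ops')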