├── .gitignore ├── LICENSE ├── README.md ├── deformable_kernels ├── __init__.py ├── modules │ ├── __init__.py │ ├── cond_conv.py │ ├── deform_conv.py │ └── deform_kernel.py └── ops │ ├── __init__.py │ └── deform_kernel │ ├── __init__.py │ ├── csrc │ ├── filter_sample_depthwise_cuda.cpp │ ├── filter_sample_depthwise_cuda.h │ ├── filter_sample_depthwise_cuda_kernel.cu │ ├── nd_linear_sample_cuda.cpp │ ├── nd_linear_sample_cuda.h │ └── nd_linear_sample_cuda_kernel.cu │ ├── functions │ ├── __init__.py │ ├── filter_sample_depthwise.py │ └── nd_linear_sample.py │ ├── modules │ ├── __init__.py │ └── filter_sample_depthwise.py │ └── setup.py └── environment.yml /.gitignore: -------------------------------------------------------------------------------- 1 | # Global gitignore file 2 | # Adapted from https://github.com/cowboy/dotfiles/blob/master/link/.gitignore_global 3 | # Adatped from https://github.com/wookayin/dotfiles/blob/master/git/gitignore#L14 4 | 5 | # Direnv stuffs # 6 | .direnv 7 | .envrc 8 | 9 | # Compiled source # 10 | *.class 11 | *.dll 12 | *.exe 13 | *.o 14 | *.so 15 | *.pyc 16 | **/__pycache__/ 17 | 18 | # Packages # 19 | # It's better to unpack these files and commit the raw source because 20 | # git has its own built in compression methods. 21 | *.7z 22 | *.jar 23 | *.rar 24 | *.zip 25 | *.gz 26 | *.bzip 27 | *.xz 28 | *.lzma 29 | 30 | # packing-only formats 31 | *.iso 32 | *.tar 33 | 34 | # package management formats 35 | *.dmg 36 | *.xpi 37 | *.gem 38 | *.egg 39 | *.egg-info 40 | *.deb 41 | *.rpm 42 | 43 | # Logs and databases # 44 | *.log 45 | *.sqlite 46 | 47 | # OS generated files # 48 | .DS_Store 49 | .Spotlight-V100 50 | .Trashes 51 | ._* 52 | 53 | # Linux 54 | .fuse_hidden* 55 | .nfs* 56 | 57 | # Windows image file caches 58 | Thumbs.db 59 | 60 | # Folder config file 61 | Desktop.ini 62 | 63 | # Vim 64 | .*.s[a-w][a-z] 65 | 66 | # IDE stuffs 67 | .idea/ 68 | *.iml 69 | .project 70 | .classpath 71 | .settings/ 72 | .ipynb_checkpoints/ 73 | 74 | # Tex 75 | *.aux 76 | *.bbl 77 | *.blg 78 | *.brf 79 | *.fdb_latexmk 80 | *.fls 81 | *.log 82 | *.out 83 | *.pdf 84 | *.synctex.gz 85 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hang Gao, Xizhou Zhu, Steve Lin, Jifeng Dai. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deformable Kernels [[ICLR 2020]](https://arxiv.org/abs/1910.02940) [[Website]](https://people.eecs.berkeley.edu/~hangg/deformable-kernels/) 2 | 3 | 4 | 5 | **Deformable Kernels: Adapting Effective Receptive Fields for Object Deformation**
6 | [Hang Gao*](http://people.eecs.berkeley.edu/~hangg/), [Xizhou 7 | Zhu*](https://scholar.google.com/citations?user=02RXI00AAAAJ&hl=en&oi=ao), [Steve Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en&oi=ao), [Jifeng Dai](https://jifengdai.org/).
8 | In [ICLR, 2020](https://arxiv.org/abs/1910.02940). 9 | 10 | This repository contains the official implementation of deformable kernels.
11 | 12 | **Table of contents**
13 | 1. [Customized operators for deformable kernels, along with their variants.](#1-customized-operators)
14 | 2. [Instructions to use our operators.](#2-quickstart)
15 | 3. [Results on ImageNet & COCO benchmarks, with pretrained models for 16 | reproduction.](#3-results--pretrained-models)
17 | 4. [Training and evaluation code.](#4-training--evaluation-code)
18 | 
19 | ## (0) Getting started
20 | 
21 | ### PyTorch
22 | - Get [CUDA 10.1](https://developer.nvidia.com/cuda-10.1-download-archive-base)
23 | installed on your machine.
24 | - Install PyTorch ([pytorch.org](http://pytorch.org)).
25 | - `conda env create -f environment.yml`.
26 | 
27 | ### Apex
28 | - Install [Apex](https://github.com/NVIDIA/apex/) from its official repo. This
29 | requires CUDA 10.1 to work with the latest PyTorch version (we tested against
30 | `pytorch=1.3.1`). Apex is used for fast mixed-precision
31 | inference and should work out of the box.
32 | 
33 | ### Compile our operators
34 | ```bash
35 | # assumes you are at the project root
36 | (
37 | cd deformable_kernels/ops/deform_kernel;
38 | pip install -e .;
39 | )
40 | ```
41 | 
42 | 
43 | ## (1) Customized operators
44 | 
45 | 
46 | 
47 | This repo includes all deformable kernel variants described in our paper, namely:
48 | 
49 | - Global Deformable Kernels;
50 | - Local Deformable Kernels;
51 | - Local Deformable Kernels integrated with Deformable Convolutions.
52 | 
53 | Instead of learning offsets on image space, we propose to deform and resample
54 | on kernel space. This enables powerful dynamic inference capacity. For more
55 | technical details, please refer to their
56 | [definitions](deformable_kernels/modules/deform_kernel.py).
57 | 
58 | We also provide implementations of the competing methods we compare against, namely:
59 | 
60 | - [Deformable Convolutions](https://arxiv.org/abs/1703.06211);
61 | - [Soft Conditional Computation](https://arxiv.org/abs/1904.04971).
62 | 
63 | Please refer to their module definitions under the `deformable_kernels/modules` folder.
64 | 
65 | 
66 | ## (2) Quickstart
67 | The following snippet constructs the deformable kernels we used for our experiments:
68 | 
69 | ```python
70 | from deformable_kernels.modules import (
71 |     GlobalDeformKernel2d,
72 |     DeformKernel2d,
73 |     DeformKernelConv2d,
74 | )
75 | 
76 | # global DK with scope size 2, kernel size 1, stride 1, padding 0, depthwise convolution.
77 | gdk = GlobalDeformKernel2d((2, 2), inplanes, inplanes, groups=inplanes)
78 | # (local) DK with scope size 4, kernel size 3, stride 1, padding 1, depthwise convolution.
79 | dk = DeformKernel2d((4, 4), inplanes, inplanes, 3, 1, 1, groups=inplanes)
80 | # (local) DK integrating with dcn, with kernel & image offsets separately learnt.
81 | dkc = DeformKernelConv2d((4, 4), inplanes, inplanes, 3, 1, 1, groups=inplanes)
82 | ```
83 | 
84 | Note that all of our customized operators currently support only depthwise
85 | convolutions, mainly because efficiently resampling kernels at runtime is
86 | extremely slow if we compute over each channel independently. We are working to
87 | lift this restriction by iterating on our CUDA implementation. Any contributions
88 | are welcome! (Two illustrative sketches, of kernel-space resampling and of a drop-in usage, appear after the license section below.)
89 | 
90 | 
91 | ## (3) Results & pretrained models
92 | Under construction.
93 | 
94 | 
95 | ## (4) Training & evaluation code
96 | Under construction.
97 | 
98 | 
99 | ## (A) License
100 | This project is released under the [MIT license](LICENSE).
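
**Illustrative sketch: resampling on kernel space.** To make the idea in section (1) concrete, below is a minimal, plain-PyTorch sketch of bilinearly sampling a small `k x k` depthwise kernel out of a larger learnable "scope" filter. `resample_kernel` and its argument layout are assumptions for illustration only; the actual operators predict offsets per output location and fuse this sampling into the depthwise convolution inside the CUDA kernels under `deformable_kernels/ops/deform_kernel`.

```python
import torch
import torch.nn.functional as F


def resample_kernel(scope_filter, kernel_offsets, kernel_size):
    """Bilinearly sample a k x k depthwise kernel out of an S x S scope filter.

    scope_filter:   (C, 1, S, S) learnable filter bank defined on kernel space.
    kernel_offsets: (k * k, 2) kernel-space shifts in (x, y) order, in scope cells.
    returns:        (C, 1, k, k) kernel usable by a depthwise convolution.
    """
    C, _, S, _ = scope_filter.shape
    k = kernel_size
    # Centered k x k lattice inside the scope, in absolute cell coordinates.
    base = torch.arange(k, dtype=scope_filter.dtype) + (S - k) / 2.0
    xs = base.view(1, k).expand(k, k)  # x varies along columns
    ys = base.view(k, 1).expand(k, k)  # y varies along rows
    coords = torch.stack([xs, ys], dim=-1).reshape(k * k, 2) + kernel_offsets
    # Normalize to [-1, 1] as grid_sample expects (align_corners=True convention).
    grid = (coords / (S - 1) * 2.0 - 1.0).view(1, k, k, 2).expand(C, -1, -1, -1)
    return F.grid_sample(scope_filter, grid, mode="bilinear", align_corners=True)


# Zero offsets recover a plain, centered 3x3 kernel from a 4x4 scope.
C, S, k = 8, 4, 3
scope = torch.randn(C, 1, S, S)
weight = resample_kernel(scope, torch.zeros(k * k, 2), k)
x = torch.randn(2, C, 16, 16)
y = F.conv2d(x, weight, padding=1, groups=C)  # depthwise conv with the sampled kernel
```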
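
**Illustrative sketch: dropping a DK into a separable block.** A hedged example of how the modules from section (2) might be wired as the depthwise stage of a depthwise-separable block. `SeparableDKBlock` is a hypothetical wrapper, not code from this repo; running it assumes the operators have been compiled as in section (0) and a CUDA device is available.

```python
import torch
from torch import nn

from deformable_kernels.modules import DeformKernel2d


class SeparableDKBlock(nn.Module):
    """Hypothetical separable block: DK depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, inplanes, outplanes):
        super().__init__()
        # (local) DK with scope size 4, kernel size 3, stride 1, padding 1, depthwise.
        self.depthwise = DeformKernel2d(
            (4, 4), inplanes, inplanes, 3, 1, 1, groups=inplanes
        )
        self.pointwise = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(outplanes)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.pointwise(self.depthwise(x))))


# block = SeparableDKBlock(64, 128).cuda()
# y = block(torch.randn(2, 64, 56, 56).cuda())  # -> (2, 128, 56, 56)
```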
101 | 
102 | 
103 | ## (B) Citation & Contact
104 | If you find this repo useful for your research, please consider citing this
105 | bibtex:
106 | 
107 | ```tex
108 | @article{gao2019deformable,
109 |   title={Deformable Kernels: Adapting Effective Receptive Fields for Object Deformation},
110 |   author={Gao, Hang and Zhu, Xizhou and Lin, Steve and Dai, Jifeng},
111 |   journal={arXiv preprint arXiv:1910.02940},
112 |   year={2019}
113 | }
114 | ```
115 | 
116 | Please contact Hang Gao `` and Xizhou Zhu
117 | `` with any comments or feedback.
118 | 
-------------------------------------------------------------------------------- /deformable_kernels/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/15/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 | 
-------------------------------------------------------------------------------- /deformable_kernels/modules/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/26/2019
7 | #
8 | # Distributed under terms of the MIT license.
9 | 
10 | from .cond_conv import CondConv2d
11 | from .deform_conv import DeformConv2d
12 | from .deform_kernel import (
13 |     GlobalDeformKernel2d,
14 |     LocalDeformKernel2d,
15 |     DeformKernel2d,
16 |     DeformKernelConv2d,
17 | )
18 | 
19 | __all__ = [
20 |     'CondConv2d',
21 |     'DeformConv2d',
22 |     'GlobalDeformKernel2d',
23 |     'LocalDeformKernel2d',
24 |     'DeformKernel2d',
25 |     'DeformKernelConv2d',
26 | ]
27 | 
-------------------------------------------------------------------------------- /deformable_kernels/modules/cond_conv.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : cond_conv.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 12/25/2019
7 | #
8 | # Distributed under terms of the MIT license.
9 | 
10 | import torch
11 | 
12 | from apex import amp
13 | from torch import nn
14 | 
15 | 
16 | class CondConv2d(nn.Module):
17 | 
18 |     def __init__(self, num_experts, in_channels, out_channels, kernel_size,
19 |                  stride=1, padding=0, dilation=1, groups=1, bias=False,
20 |                  padding_mode='zeros'):
21 |         super().__init__()
22 |         self.num_experts = num_experts
23 |         self.in_channels = in_channels
24 |         self.out_channels = out_channels
25 |         self.kernel_size = kernel_size
26 |         self.stride = stride
27 |         self.padding = padding
28 |         self.dilation = dilation
29 |         self.groups = groups
30 |         self.padding_mode = padding_mode
31 |         assert not bias
32 | 
33 |         # One bank of expert kernels, stored as num_experts stacked conv weights.
34 |         self.weight = nn.Parameter(
35 |             torch.Tensor(
36 |                 num_experts * out_channels,
37 |                 in_channels // self.groups,
38 |                 kernel_size,
39 |                 kernel_size,
40 |             )
41 |         )
42 |         self.fc = nn.Linear(in_channels, num_experts)
43 |         self.fc.zero_init = True
44 | 
45 |     @amp.float_function
46 |     def dynamic_inference(self, x, weight):
47 |         # TODO(Hang Gao @ 12/26): make sure passing weight to amp is necessary.
48 |         n = x.shape[0]
49 | 
50 |         # Per-example routing: global average pooling followed by sigmoid gates over experts.
51 |         avg_x = x.mean((2, 3))
52 |         gate_x = torch.sigmoid(self.fc(avg_x))
53 | 
54 |         # Mix the expert kernels with the gates, folding the batch into the output channels.
55 |         weight = torch.mm(
56 |             gate_x,
57 |             self.weight.reshape(self.num_experts, -1)
58 |         ).reshape(
59 |             n * self.out_channels,
60 |             self.in_channels // self.groups,
61 |             self.kernel_size,
62 |             self.kernel_size,
63 |         )
64 |         return weight
65 | 
66 |     def forward(self, x):
67 |         n, _, h, w = x.shape
68 |         weight = self.dynamic_inference(x, self.weight)
69 | 
70 |         # Grouped-conv trick: fold the batch into channels so that each example
71 |         # is convolved with its own mixed kernel (F.conv2d supports zero padding only).
72 |         out = nn.functional.conv2d(
73 |             x.reshape(1, n * self.in_channels, h, w),
74 |             weight,
75 |             stride=self.stride,
76 |             padding=self.padding,
77 |             dilation=self.dilation,
78 |             groups=n * self.groups,
79 |         )
80 |         out = out.reshape(n, self.out_channels, *out.shape[2:])
81 |         return out
82 | 
83 |     def extra_repr(self):
84 |         s = ('{num_experts}, {in_channels}, {out_channels}'
85 |              ', kernel_size={kernel_size}, stride={stride}')
86 |         if self.padding != 0:
87 |             s += ', padding={padding}'
88 |         if self.dilation != 1:
89 |             s += ', dilation={dilation}'
90 |         if self.groups != 1:
91 |             s += ', groups={groups}'
92 |         return s.format(**self.__dict__)
93 | 
-------------------------------------------------------------------------------- /deformable_kernels/modules/deform_conv.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : deform_conv.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 | 
10 | from .deform_kernel import DeformKernelConv2d
11 | 
12 | __all__ = ['DeformConv2d']
13 | 
14 | 
15 | class DeformConv2d(DeformKernelConv2d):
16 |     """
17 |     Depthwise deformable convolution.
18 |     """
19 |     def __init__(
20 |         self,
21 |         in_planes,
22 |         out_planes,
23 |         kernel_size=1,
24 |         stride=1,
25 |         padding=0,
26 |         dilation=1,
27 |         groups=1,
28 |         bias=False,
29 |         offset_clip=None,
30 |     ):
31 |         super().__init__(
32 |             (kernel_size, kernel_size),
33 |             in_planes,
34 |             out_planes,
35 |             kernel_size,
36 |             stride,
37 |             padding,
38 |             dilation,
39 |             groups,
40 |             bias,
41 |             offset_clip,
42 |         )
43 | 
-------------------------------------------------------------------------------- /deformable_kernels/modules/deform_kernel.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : deform_kernel.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | # Date : 01/17/2020
7 | #
8 | # Distributed under terms of the MIT license.
9 | 10 | import torch 11 | from torch import nn 12 | from apex import amp 13 | 14 | from ..ops.deform_kernel.functions import nd_linear_sample 15 | from ..ops.deform_kernel.modules import ( 16 | SampleDepthwise, 17 | DeformableSampleDepthwise, 18 | ) 19 | 20 | __all__ = [ 21 | 'GlobalDeformKernel2d', 22 | 'LocalDeformKernel2d', 23 | 'DeformKernel2d', 24 | 'DeformKernelConv2d', 25 | ] 26 | 27 | 28 | class GlobalDeformKernel2d(nn.Module): 29 | 30 | def __init__( 31 | self, 32 | weight_shape, 33 | in_planes, 34 | out_planes, 35 | kernel_size=1, 36 | stride=1, 37 | padding=0, 38 | dilation=1, 39 | groups=1, 40 | bias=False, 41 | ): 42 | super().__init__() 43 | self.kernel_size = kernel_size 44 | self.weight_shape = weight_shape 45 | self.weight_dilate = 1 46 | self.stride = stride 47 | self.padding = padding 48 | self.dilation = dilation 49 | self.out_planes = out_planes 50 | self.in_planes = in_planes 51 | self.group = groups 52 | assert not bias 53 | 54 | self.weight = nn.Parameter( 55 | torch.Tensor(out_planes, in_planes // self.group, *self.weight_shape) 56 | ) 57 | self.fc = nn.Linear( 58 | in_planes, kernel_size * kernel_size * len(self.weight_shape) 59 | ) 60 | self.fc.zero_init = True 61 | 62 | assert len(self.weight_shape) >= 2 63 | 64 | start_h = (weight_shape[0] - (kernel_size - 1) * self.weight_dilate - 1) / 2.0 65 | start_w = (weight_shape[1] - (kernel_size - 1) * self.weight_dilate - 1) / 2.0 66 | self.fc_bias = [] 67 | for h in range(kernel_size): 68 | for w in range(kernel_size): 69 | self.fc_bias += [ 70 | start_h + h * self.weight_dilate, 71 | start_w + w * self.weight_dilate, 72 | ] 73 | for i in range(len(self.weight_shape) - 2): 74 | self.fc_bias += [(self.weight_shape[i + 2] - 1) / 2.0] 75 | 76 | @amp.float_function 77 | def dynamic_weight(self, x, weight): 78 | n, c, h, w = x.shape 79 | avg_x = x.view(n, c, -1).mean(2) 80 | coord = self.fc(avg_x) * self.weight_dilate + torch.tensor( 81 | self.fc_bias, dtype=x.dtype, device=x.device 82 | ).unsqueeze(0) 83 | coord = torch.clamp(coord, 0, self.weight_shape[0] - 1) 84 | 85 | weight = weight.view( 86 | self.out_planes * self.in_planes // self.group, *self.weight_shape 87 | ) 88 | coord = coord.view( 89 | n * self.kernel_size * self.kernel_size, len(self.weight_shape) 90 | ) 91 | 92 | weight_sample = nd_linear_sample(weight, coord).view( 93 | n, 94 | self.kernel_size * self.kernel_size, 95 | self.out_planes * self.in_planes // self.group, 96 | ) 97 | weight = weight_sample.transpose(1, 2).reshape( 98 | n * self.out_planes, 99 | self.in_planes // self.group, 100 | self.kernel_size, 101 | self.kernel_size, 102 | ) 103 | return weight 104 | 105 | def forward(self, x): 106 | n, c, h, w = x.shape 107 | weight = self.dynamic_weight(x, self.weight) 108 | 109 | out = nn.functional.conv2d( 110 | x.view(1, n * c, h, w), 111 | weight, 112 | stride=self.stride, 113 | padding=self.padding, 114 | dilation=self.dilation, 115 | groups=n * self.group, 116 | ) 117 | out = out.view(n, self.out_planes, out.shape[2], out.shape[3]) 118 | return out 119 | 120 | def extra_repr(self): 121 | s = ( 122 | "{in_planes}, {out_planes}, weight_shape={weight_shape}, " 123 | "kernel_size={kernel_size}, stride={stride}, " 124 | "weight_dilate={weight_dilate}" 125 | ) 126 | if self.padding != 0: 127 | s += ", padding={padding}" 128 | if self.dilation != 1: 129 | s += ", dilation={dilation}" 130 | if self.group != 1: 131 | s += ", group={group}" 132 | return s.format(**self.__dict__) 133 | 134 | 135 | class LocalDeformKernel2d(nn.Module): 136 | 137 | def 
__init__( 138 | self, 139 | weight_shape, 140 | in_planes, 141 | out_planes, 142 | kernel_size=1, 143 | stride=1, 144 | padding=0, 145 | dilation=1, 146 | groups=1, 147 | rotation_groups=1, 148 | bias=False, 149 | rotation_clip=None, 150 | ): 151 | super().__init__() 152 | self.kernel_size = kernel_size 153 | self.weight_shape = weight_shape 154 | self.weight_dilate = 1 155 | self.stride = stride 156 | self.padding = padding 157 | self.dilation = dilation 158 | self.out_planes = out_planes 159 | self.in_planes = in_planes 160 | self.rotation_groups = rotation_groups 161 | self.group = groups 162 | assert not bias 163 | assert len(self.weight_shape) >= 2 164 | 165 | self.rotation_conv = nn.Conv2d( 166 | in_planes, rotation_groups * kernel_size * kernel_size * 2, 167 | kernel_size, stride, padding, dilation, bias=True 168 | ) 169 | self.rotation_conv.zero_init = True 170 | self.rotation_clip = rotation_clip 171 | 172 | self.inner_conv = SampleDepthwise( 173 | weight_shape, 174 | in_planes, 175 | out_planes, 176 | kernel_size=kernel_size, 177 | stride=stride, 178 | padding=padding, 179 | dilation=dilation, 180 | groups=groups, 181 | rotation_groups=rotation_groups, 182 | bias=bias, 183 | ) 184 | 185 | def _clip_rotation(self, rotation): 186 | if isinstance(self.rotation_clip, tuple): 187 | return rotation.clamp(**self.rotation_clip) 188 | elif self.rotation_clip == 'scope': 189 | if not hasattr(self, 'fc_bias'): 190 | start_h = (self.weight_shape[0] - (self.kernel_size - 1) * 191 | self.weight_dilate - 1) / 2.0 192 | start_w = (self.weight_shape[1] - (self.kernel_size - 1) * 193 | self.weight_dilate - 1) / 2.0 194 | fc_bias = [] 195 | for h in range(self.kernel_size): 196 | for w in range(self.kernel_size): 197 | fc_bias += [ 198 | start_h + h * self.weight_dilate, 199 | start_w + w * self.weight_dilate, 200 | ] 201 | for i in range(len(self.weight_shape) - 2): 202 | fc_bias += [(self.weight_shape[i + 2] - 1) / 2] 203 | self.fc_bias = rotation.new_tensor(fc_bias) \ 204 | .repeat(self.rotation_groups)[None, :, None, None] 205 | coord = (rotation * self.weight_dilate + self.fc_bias).clamp( 206 | 0, self.weight_shape[0] - 1) 207 | return (coord - self.fc_bias) / self.weight_dilate 208 | else: 209 | raise NotImplementedError( 210 | f'Expect rotation_clip to be tuple or "scope", ' 211 | f'but get {self.rotation_clip}' 212 | ) 213 | 214 | def forward(self, x): 215 | rotation = self.rotation_conv(x) 216 | if self.rotation_clip is not None: 217 | rotation = self._clip_rotation(rotation) 218 | rotation *= self.weight_dilate 219 | out = self.inner_conv(x, rotation) 220 | 221 | return out 222 | 223 | 224 | # refer to local deformable kernel as the default. 
225 | DeformKernel2d = LocalDeformKernel2d 226 | 227 | 228 | class DeformKernelConv2d(nn.Module): 229 | 230 | def __init__( 231 | self, 232 | weight_shape, 233 | in_planes, 234 | out_planes, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0, 238 | dilation=1, 239 | groups=1, 240 | bias=False, 241 | offset_clip=None, 242 | ): 243 | super().__init__() 244 | self.kernel_size = kernel_size 245 | self.stride = stride 246 | self.padding = padding 247 | self.dilation = dilation 248 | self.out_planes = out_planes 249 | self.in_planes = in_planes 250 | self.group = groups 251 | assert not bias 252 | 253 | self.offset_conv = nn.Conv2d( 254 | in_planes, kernel_size * kernel_size * 2, 255 | kernel_size, stride, padding, dilation, bias=True 256 | ) 257 | self.offset_conv.zero_init = True 258 | self.offset_clip = offset_clip 259 | 260 | self.inner_conv = DeformableSampleDepthwise( 261 | weight_shape, 262 | in_planes, 263 | out_planes, 264 | kernel_size=kernel_size, 265 | stride=stride, 266 | padding=padding, 267 | dilation=dilation, 268 | groups=groups, 269 | bias=bias, 270 | ) 271 | 272 | def forward(self, x): 273 | offset = self.offset_conv(x) 274 | if self.offset_clip is not None: 275 | offset = offset.clamp(**self.offset_clip) 276 | offset *= self.dilation 277 | 278 | rotation = None 279 | out = self.inner_conv(x, offset, rotation) 280 | 281 | return out 282 | -------------------------------------------------------------------------------- /deformable_kernels/ops/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 12/25/2019 7 | # 8 | # Distributed under terms of the MIT license. 9 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 
9 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "filter_sample_depthwise_cuda.h" 7 | 8 | int sample_depthwise_forward_cuda( 9 | at::Tensor input, 10 | at::Tensor rotation, 11 | at::Tensor filter, 12 | at::Tensor output, 13 | int kH, 14 | int kW, 15 | int dH, 16 | int dW, 17 | int padH, 18 | int padW, 19 | int dilationH, 20 | int dilationW, 21 | int scopeH, 22 | int scopeW, 23 | int groupS) { 24 | 25 | input = input.contiguous(); 26 | rotation = rotation.contiguous(); 27 | filter = filter.contiguous(); 28 | 29 | SampleDepthwiseArgs sdw_args; 30 | sdw_args.batch = input.size(0); 31 | sdw_args.channel = input.size(1); 32 | sdw_args.in_height = input.size(2); 33 | sdw_args.in_width = input.size(3); 34 | sdw_args.filter_height = kH; 35 | sdw_args.filter_width = kW; 36 | sdw_args.stride_height = dH; 37 | sdw_args.stride_width = dW; 38 | sdw_args.pad_height = padH; 39 | sdw_args.pad_width = padW; 40 | sdw_args.dilation_height = dilationH; 41 | sdw_args.dilation_width = dilationW; 42 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 43 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 44 | sdw_args.scope_height = scopeH; 45 | sdw_args.scope_width = scopeW; 46 | sdw_args.sampling_group = groupS; 47 | 48 | output = output.view({ 49 | sdw_args.batch, 50 | sdw_args.channel, 51 | sdw_args.out_height, 52 | sdw_args.out_width}); 53 | 54 | SampleDepthwiseConv2dForward( 55 | input, 56 | rotation, 57 | filter, 58 | sdw_args, 59 | output); 60 | 61 | return 1; 62 | } 63 | 64 | int sample_depthwise_backward_data_cuda( 65 | at::Tensor gradOutput, 66 | at::Tensor input, 67 | at::Tensor rotation, 68 | at::Tensor filter, 69 | at::Tensor gradInput, 70 | int kH, 71 | int kW, 72 | int dH, 73 | int dW, 74 | int padH, 75 | int padW, 76 | int dilationH, 77 | int dilationW, 78 | int scopeH, 79 | int scopeW, 80 | int groupS) { 81 | 82 | gradOutput = gradOutput.contiguous(); 83 | rotation = rotation.contiguous(); 84 | filter = filter.contiguous(); 85 | 86 | SampleDepthwiseArgs sdw_args; 87 | sdw_args.batch = input.size(0); 88 | sdw_args.channel = input.size(1); 89 | sdw_args.in_height = input.size(2); 90 | sdw_args.in_width = input.size(3); 91 | sdw_args.filter_height = kH; 92 | sdw_args.filter_width = kW; 93 | sdw_args.stride_height = dH; 94 | sdw_args.stride_width = dW; 95 | sdw_args.pad_height = padH; 96 | sdw_args.pad_width = padW; 97 | sdw_args.dilation_height = dilationH; 98 | sdw_args.dilation_width = dilationW; 99 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 100 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 101 | sdw_args.scope_height = scopeH; 102 | sdw_args.scope_width = scopeW; 103 | sdw_args.sampling_group = groupS; 104 | 105 | gradInput = gradInput.view({ 106 | sdw_args.batch, 107 | sdw_args.channel, 108 | sdw_args.in_height, 109 | sdw_args.in_width}); 110 | 111 | SampleDepthwiseConv2dBackwardData( 112 | gradOutput, 113 | rotation, 114 | filter, 115 | sdw_args, 116 | gradInput); 117 | 118 | return 1; 119 | } 120 | 121 | int sample_depthwise_backward_filter_cuda( 122 | at::Tensor gradOutput, 123 | at::Tensor input, 124 | at::Tensor rotation, 125 | at::Tensor filter, 126 | 
at::Tensor gradFilter, 127 | int kH, 128 | int kW, 129 | int dH, 130 | int dW, 131 | int padH, 132 | int padW, 133 | int dilationH, 134 | int dilationW, 135 | int scopeH, 136 | int scopeW, 137 | int groupS) { 138 | 139 | gradOutput = gradOutput.contiguous(); 140 | input = input.contiguous(); 141 | rotation = rotation.contiguous(); 142 | 143 | SampleDepthwiseArgs sdw_args; 144 | sdw_args.batch = input.size(0); 145 | sdw_args.channel = input.size(1); 146 | sdw_args.in_height = input.size(2); 147 | sdw_args.in_width = input.size(3); 148 | sdw_args.filter_height = kH; 149 | sdw_args.filter_width = kW; 150 | sdw_args.stride_height = dH; 151 | sdw_args.stride_width = dW; 152 | sdw_args.pad_height = padH; 153 | sdw_args.pad_width = padW; 154 | sdw_args.dilation_height = dilationH; 155 | sdw_args.dilation_width = dilationW; 156 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 157 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 158 | sdw_args.scope_height = scopeH; 159 | sdw_args.scope_width = scopeW; 160 | sdw_args.sampling_group = groupS; 161 | 162 | gradFilter = gradFilter.view({ 163 | sdw_args.channel, 164 | 1, 165 | sdw_args.scope_height, 166 | sdw_args.scope_width}); 167 | 168 | SampleDepthwiseConv2dBackwardFilter( 169 | gradOutput, 170 | input, 171 | rotation, 172 | sdw_args, 173 | gradFilter); 174 | 175 | return 1; 176 | } 177 | 178 | int sample_depthwise_backward_rotation_cuda( 179 | at::Tensor gradOutput, 180 | at::Tensor input, 181 | at::Tensor rotation, 182 | at::Tensor filter, 183 | at::Tensor gradRotation, 184 | int kH, 185 | int kW, 186 | int dH, 187 | int dW, 188 | int padH, 189 | int padW, 190 | int dilationH, 191 | int dilationW, 192 | int scopeH, 193 | int scopeW, 194 | int groupS) { 195 | 196 | gradOutput = gradOutput.contiguous(); 197 | input = input.contiguous(); 198 | rotation = rotation.contiguous(); 199 | filter = filter.contiguous(); 200 | 201 | SampleDepthwiseArgs sdw_args; 202 | sdw_args.batch = input.size(0); 203 | sdw_args.channel = input.size(1); 204 | sdw_args.in_height = input.size(2); 205 | sdw_args.in_width = input.size(3); 206 | sdw_args.filter_height = kH; 207 | sdw_args.filter_width = kW; 208 | sdw_args.stride_height = dH; 209 | sdw_args.stride_width = dW; 210 | sdw_args.pad_height = padH; 211 | sdw_args.pad_width = padW; 212 | sdw_args.dilation_height = dilationH; 213 | sdw_args.dilation_width = dilationW; 214 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 215 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 216 | sdw_args.scope_height = scopeH; 217 | sdw_args.scope_width = scopeW; 218 | sdw_args.sampling_group = groupS; 219 | 220 | gradRotation = gradRotation.view({ 221 | sdw_args.batch, 222 | groupS * kH * kW * 2, 223 | sdw_args.out_height, 224 | sdw_args.out_width}); 225 | 226 | SampleDepthwiseConv2dBackwardRotation( 227 | gradOutput, 228 | input, 229 | rotation, 230 | filter, 231 | sdw_args, 232 | gradRotation); 233 | 234 | return 1; 235 | } 236 | 237 | int deformable_sample_depthwise_forward_cuda( 238 | at::Tensor input, 239 | at::Tensor offset, 240 | at::Tensor rotation, 241 | at::Tensor filter, 242 | at::Tensor output, 243 | int kH, 244 | int kW, 245 | int dH, 246 | int dW, 247 | int padH, 248 | int padW, 249 | int dilationH, 250 | int dilationW, 251 | int scopeH, 252 | int scopeW, 253 | int groupS) { 254 | 255 | input = input.contiguous(); 256 | offset = 
offset.contiguous(); 257 | rotation = rotation.contiguous(); 258 | filter = filter.contiguous(); 259 | 260 | SampleDepthwiseArgs sdw_args; 261 | sdw_args.batch = input.size(0); 262 | sdw_args.channel = input.size(1); 263 | sdw_args.in_height = input.size(2); 264 | sdw_args.in_width = input.size(3); 265 | sdw_args.filter_height = kH; 266 | sdw_args.filter_width = kW; 267 | sdw_args.stride_height = dH; 268 | sdw_args.stride_width = dW; 269 | sdw_args.pad_height = padH; 270 | sdw_args.pad_width = padW; 271 | sdw_args.dilation_height = dilationH; 272 | sdw_args.dilation_width = dilationW; 273 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 274 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 275 | sdw_args.scope_height = scopeH; 276 | sdw_args.scope_width = scopeW; 277 | sdw_args.sampling_group = groupS; 278 | 279 | output = output.view({ 280 | sdw_args.batch, 281 | sdw_args.channel, 282 | sdw_args.out_height, 283 | sdw_args.out_width}); 284 | 285 | DeformableSampleDepthwiseConv2dForward( 286 | input, 287 | offset, 288 | rotation, 289 | filter, 290 | sdw_args, 291 | output); 292 | 293 | return 1; 294 | } 295 | 296 | int deformable_sample_depthwise_backward_data_cuda( 297 | at::Tensor gradOutput, 298 | at::Tensor input, 299 | at::Tensor offset, 300 | at::Tensor rotation, 301 | at::Tensor filter, 302 | at::Tensor gradInput, 303 | int kH, 304 | int kW, 305 | int dH, 306 | int dW, 307 | int padH, 308 | int padW, 309 | int dilationH, 310 | int dilationW, 311 | int scopeH, 312 | int scopeW, 313 | int groupS) { 314 | 315 | gradOutput = gradOutput.contiguous(); 316 | offset = offset.contiguous(); 317 | rotation = rotation.contiguous(); 318 | filter = filter.contiguous(); 319 | 320 | SampleDepthwiseArgs sdw_args; 321 | sdw_args.batch = input.size(0); 322 | sdw_args.channel = input.size(1); 323 | sdw_args.in_height = input.size(2); 324 | sdw_args.in_width = input.size(3); 325 | sdw_args.filter_height = kH; 326 | sdw_args.filter_width = kW; 327 | sdw_args.stride_height = dH; 328 | sdw_args.stride_width = dW; 329 | sdw_args.pad_height = padH; 330 | sdw_args.pad_width = padW; 331 | sdw_args.dilation_height = dilationH; 332 | sdw_args.dilation_width = dilationW; 333 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 334 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 335 | sdw_args.scope_height = scopeH; 336 | sdw_args.scope_width = scopeW; 337 | sdw_args.sampling_group = groupS; 338 | 339 | gradInput = gradInput.view({ 340 | sdw_args.batch, 341 | sdw_args.channel, 342 | sdw_args.in_height, 343 | sdw_args.in_width}); 344 | 345 | DeformableSampleDepthwiseConv2dBackwardData( 346 | gradOutput, 347 | offset, 348 | rotation, 349 | filter, 350 | sdw_args, 351 | gradInput); 352 | 353 | return 1; 354 | } 355 | 356 | int deformable_sample_depthwise_backward_filter_cuda( 357 | at::Tensor gradOutput, 358 | at::Tensor input, 359 | at::Tensor offset, 360 | at::Tensor rotation, 361 | at::Tensor filter, 362 | at::Tensor gradFilter, 363 | int kH, 364 | int kW, 365 | int dH, 366 | int dW, 367 | int padH, 368 | int padW, 369 | int dilationH, 370 | int dilationW, 371 | int scopeH, 372 | int scopeW, 373 | int groupS) { 374 | 375 | gradOutput = gradOutput.contiguous(); 376 | input = input.contiguous(); 377 | offset = offset.contiguous(); 378 | rotation = rotation.contiguous(); 379 | 380 | SampleDepthwiseArgs sdw_args; 381 | sdw_args.batch = 
input.size(0); 382 | sdw_args.channel = input.size(1); 383 | sdw_args.in_height = input.size(2); 384 | sdw_args.in_width = input.size(3); 385 | sdw_args.filter_height = kH; 386 | sdw_args.filter_width = kW; 387 | sdw_args.stride_height = dH; 388 | sdw_args.stride_width = dW; 389 | sdw_args.pad_height = padH; 390 | sdw_args.pad_width = padW; 391 | sdw_args.dilation_height = dilationH; 392 | sdw_args.dilation_width = dilationW; 393 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 394 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 395 | sdw_args.scope_height = scopeH; 396 | sdw_args.scope_width = scopeW; 397 | sdw_args.sampling_group = groupS; 398 | 399 | gradFilter = gradFilter.view({ 400 | sdw_args.channel, 401 | 1, 402 | sdw_args.scope_height, 403 | sdw_args.scope_width}); 404 | 405 | DeformableSampleDepthwiseConv2dBackwardFilter( 406 | gradOutput, 407 | input, 408 | offset, 409 | rotation, 410 | sdw_args, 411 | gradFilter); 412 | 413 | return 1; 414 | } 415 | 416 | int deformable_sample_depthwise_backward_offset_cuda( 417 | at::Tensor gradOutput, 418 | at::Tensor input, 419 | at::Tensor offset, 420 | at::Tensor rotation, 421 | at::Tensor filter, 422 | at::Tensor gradOffset, 423 | int kH, 424 | int kW, 425 | int dH, 426 | int dW, 427 | int padH, 428 | int padW, 429 | int dilationH, 430 | int dilationW, 431 | int scopeH, 432 | int scopeW, 433 | int groupS) { 434 | 435 | gradOutput = gradOutput.contiguous(); 436 | input = input.contiguous(); 437 | offset = offset.contiguous(); 438 | rotation = rotation.contiguous(); 439 | filter = filter.contiguous(); 440 | 441 | SampleDepthwiseArgs sdw_args; 442 | sdw_args.batch = input.size(0); 443 | sdw_args.channel = input.size(1); 444 | sdw_args.in_height = input.size(2); 445 | sdw_args.in_width = input.size(3); 446 | sdw_args.filter_height = kH; 447 | sdw_args.filter_width = kW; 448 | sdw_args.stride_height = dH; 449 | sdw_args.stride_width = dW; 450 | sdw_args.pad_height = padH; 451 | sdw_args.pad_width = padW; 452 | sdw_args.dilation_height = dilationH; 453 | sdw_args.dilation_width = dilationW; 454 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 455 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 456 | sdw_args.scope_height = scopeH; 457 | sdw_args.scope_width = scopeW; 458 | sdw_args.sampling_group = groupS; 459 | 460 | gradOffset = gradOffset.view({ 461 | sdw_args.batch, 462 | kH * kW * 2, 463 | sdw_args.out_height, 464 | sdw_args.out_width}); 465 | 466 | DeformableSampleDepthwiseConv2dBackwardOffset( 467 | gradOutput, 468 | input, 469 | offset, 470 | rotation, 471 | filter, 472 | sdw_args, 473 | gradOffset); 474 | 475 | return 1; 476 | } 477 | 478 | int deformable_sample_depthwise_backward_rotation_cuda( 479 | at::Tensor gradOutput, 480 | at::Tensor input, 481 | at::Tensor offset, 482 | at::Tensor rotation, 483 | at::Tensor filter, 484 | at::Tensor gradRotation, 485 | int kH, 486 | int kW, 487 | int dH, 488 | int dW, 489 | int padH, 490 | int padW, 491 | int dilationH, 492 | int dilationW, 493 | int scopeH, 494 | int scopeW, 495 | int groupS) { 496 | 497 | gradOutput = gradOutput.contiguous(); 498 | input = input.contiguous(); 499 | offset = offset.contiguous(); 500 | rotation = rotation.contiguous(); 501 | filter = filter.contiguous(); 502 | 503 | SampleDepthwiseArgs sdw_args; 504 | sdw_args.batch = input.size(0); 505 | sdw_args.channel = input.size(1); 506 | 
sdw_args.in_height = input.size(2); 507 | sdw_args.in_width = input.size(3); 508 | sdw_args.filter_height = kH; 509 | sdw_args.filter_width = kW; 510 | sdw_args.stride_height = dH; 511 | sdw_args.stride_width = dW; 512 | sdw_args.pad_height = padH; 513 | sdw_args.pad_width = padW; 514 | sdw_args.dilation_height = dilationH; 515 | sdw_args.dilation_width = dilationW; 516 | sdw_args.out_height = (sdw_args.in_height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 517 | sdw_args.out_width = (sdw_args.in_width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 518 | sdw_args.scope_height = scopeH; 519 | sdw_args.scope_width = scopeW; 520 | sdw_args.sampling_group = groupS; 521 | 522 | gradRotation = gradRotation.view({ 523 | sdw_args.batch, 524 | groupS * kH * kW * 2, 525 | sdw_args.out_height, 526 | sdw_args.out_width}); 527 | 528 | DeformableSampleDepthwiseConv2dBackwardRotation( 529 | gradOutput, 530 | input, 531 | offset, 532 | rotation, 533 | filter, 534 | sdw_args, 535 | gradRotation); 536 | 537 | return 1; 538 | } 539 | 540 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 541 | m.def("sample_depthwise_forward_cuda", 542 | &sample_depthwise_forward_cuda, 543 | "sample_depthwise_forward (CUDA)"); 544 | m.def("sample_depthwise_backward_data_cuda", 545 | &sample_depthwise_backward_data_cuda, 546 | "sample_depthwise_backward_data (CUDA)"); 547 | m.def("sample_depthwise_backward_filter_cuda", 548 | &sample_depthwise_backward_filter_cuda, 549 | "sample_depthwise_backward_filter (CUDA)"); 550 | m.def("sample_depthwise_backward_rotation_cuda", 551 | &sample_depthwise_backward_rotation_cuda, 552 | "sample_depthwise_backward_rotation (CUDA)"); 553 | 554 | m.def("deformable_sample_depthwise_forward_cuda", 555 | &deformable_sample_depthwise_forward_cuda, 556 | "deformable_sample_depthwise_forward (CUDA)"); 557 | m.def("deformable_sample_depthwise_backward_data_cuda", 558 | &deformable_sample_depthwise_backward_data_cuda, 559 | "deformable_sample_depthwise_backward_data (CUDA)"); 560 | m.def("deformable_sample_depthwise_backward_filter_cuda", 561 | &deformable_sample_depthwise_backward_filter_cuda, 562 | "deformable_sample_depthwise_backward_filter (CUDA)"); 563 | m.def("deformable_sample_depthwise_backward_offset_cuda", 564 | &deformable_sample_depthwise_backward_offset_cuda, 565 | "deformable_sample_depthwise_backward_offset (CUDA)"); 566 | m.def("deformable_sample_depthwise_backward_rotation_cuda", 567 | &deformable_sample_depthwise_backward_rotation_cuda, 568 | "deformable_sample_depthwise_backward_rotation (CUDA)"); 569 | } 570 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda.h: -------------------------------------------------------------------------------- 1 | struct DepthwiseArgs { 2 | // Input layer dimensions 3 | int batch; 4 | int channel; 5 | int in_height; 6 | int in_width; 7 | 8 | // Weight layer dimensions 9 | int filter_height; 10 | int filter_width; 11 | int stride_height; 12 | int stride_width; 13 | int pad_height; 14 | int pad_width; 15 | int dilation_height; 16 | int dilation_width; 17 | 18 | // Output layer dimensions 19 | int out_height; 20 | int out_width; 21 | }; 22 | 23 | struct SampleDepthwiseArgs : DepthwiseArgs { 24 | // Weight layer dimensions 25 | int scope_height; 26 | int scope_width; 27 | int sampling_group; 28 | }; 29 | 30 | void SampleDepthwiseConv2dForward( 31 | const at::Tensor input, 32 | const at::Tensor rotation_ratio, 33 | const at::Tensor filter, 34 | const 
SampleDepthwiseArgs args, 35 | at::Tensor output); 36 | 37 | void DeformableSampleDepthwiseConv2dForward( 38 | const at::Tensor input, 39 | const at::Tensor offset, 40 | const at::Tensor rotation_ratio, 41 | const at::Tensor filter, 42 | const SampleDepthwiseArgs args, 43 | at::Tensor output); 44 | 45 | void SampleDepthwiseConv2dBackwardData( 46 | const at::Tensor out_grad, 47 | const at::Tensor rotation_ratio, 48 | const at::Tensor filter, 49 | const SampleDepthwiseArgs args, 50 | at::Tensor in_grad); 51 | 52 | void DeformableSampleDepthwiseConv2dBackwardData( 53 | const at::Tensor out_grad, 54 | const at::Tensor offset, 55 | const at::Tensor rotation_ratio, 56 | const at::Tensor filter, 57 | const SampleDepthwiseArgs args, 58 | at::Tensor in_grad); 59 | 60 | void SampleDepthwiseConv2dBackwardFilter( 61 | const at::Tensor out_grad, 62 | const at::Tensor input, 63 | const at::Tensor rotation_ratio, 64 | const SampleDepthwiseArgs args, 65 | at::Tensor filter_grad); 66 | 67 | void DeformableSampleDepthwiseConv2dBackwardFilter( 68 | const at::Tensor out_grad, 69 | const at::Tensor input, 70 | const at::Tensor offset, 71 | const at::Tensor rotation_ratio, 72 | const SampleDepthwiseArgs args, 73 | at::Tensor filter_grad); 74 | 75 | void DeformableSampleDepthwiseConv2dBackwardOffset( 76 | const at::Tensor out_grad, 77 | const at::Tensor input, 78 | const at::Tensor offset, 79 | const at::Tensor rotation_ratio, 80 | const at::Tensor filter, 81 | const SampleDepthwiseArgs args, 82 | at::Tensor offset_grad); 83 | 84 | void SampleDepthwiseConv2dBackwardRotation( 85 | const at::Tensor out_grad, 86 | const at::Tensor input, 87 | const at::Tensor rotation_ratio, 88 | const at::Tensor filter, 89 | const SampleDepthwiseArgs args, 90 | at::Tensor rotation_grad); 91 | 92 | void DeformableSampleDepthwiseConv2dBackwardRotation( 93 | const at::Tensor out_grad, 94 | const at::Tensor input, 95 | const at::Tensor offset, 96 | const at::Tensor rotation_ratio, 97 | const at::Tensor filter, 98 | const SampleDepthwiseArgs args, 99 | at::Tensor rotation_grad); 100 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/filter_sample_depthwise_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "filter_sample_depthwise_cuda.h" 9 | 10 | using namespace at; 11 | 12 | #define CUDA_KERNEL_LOOP(i, n) \ 13 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | 16 | const int CUDA_NUM_THREADS = 1024; 17 | const int kMaxGridNum = 65535; 18 | 19 | inline int GET_BLOCKS(const int N) { 20 | return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); 21 | } 22 | 23 | #if !defined(_MSC_VER) 24 | #define CUDA_UNROLL _Pragma("unroll") 25 | #define CUDA_NOUNROLL _Pragma("nounroll") 26 | #else 27 | #define CUDA_UNROLL 28 | #define CUDA_NOUNROLL 29 | #endif 30 | 31 | template 32 | __device__ inline scalar_t ldg(const scalar_t* address) { 33 | #if __CUDA_ARCH__ >= 350 34 | return __ldg(address); 35 | #else 36 | return *address; 37 | #endif 38 | } 39 | 40 | template 41 | inline scalar_t __device__ CudaMax(scalar_t a, scalar_t b) { 42 | return a > b ? a : b; 43 | } 44 | 45 | template 46 | inline scalar_t __device__ CudaMin(scalar_t a, scalar_t b) { 47 | return a < b ? 
a : b; 48 | } 49 | 50 | // assuming h, w is remainder of division, thus h in [0, height), w in [0, width) 51 | template 52 | __device__ scalar_t planar_bilinear( 53 | const scalar_t *data, 54 | const int height, 55 | const int width, 56 | const scalar_t h, 57 | const scalar_t w) { 58 | 59 | if (h > -1 && w > -1 && h < height && w < width) { 60 | int h_low = floor(h); 61 | int w_low = floor(w); 62 | 63 | int h_high = h_low + 1; 64 | int w_high = w_low + 1; 65 | 66 | const scalar_t lh = h - h_low; 67 | const scalar_t lw = w - w_low; 68 | const scalar_t hh = 1 - lh, hw = 1 - lw; 69 | 70 | scalar_t val = 0; 71 | if (h_low >= 0 && w_low >= 0) 72 | val += hh * hw * ldg(data + h_low * width + w_low); 73 | if (h_low >=0 && w_high <= width - 1) 74 | val += hh * lw * ldg(data + h_low * width + w_high); 75 | if (h_high <= height - 1 && w_low >= 0) 76 | val += lh * hw * ldg(data + h_high * width + w_low); 77 | if (h_high <= height - 1 && w_high <= width - 1) 78 | val += lh * lw * ldg(data + h_high * width + w_high); 79 | return val; 80 | } else { 81 | return 0; 82 | } 83 | } 84 | 85 | template 86 | __device__ void planar_bilinear_backward_data( 87 | const scalar_t partial_sum, 88 | const int height, 89 | const int width, 90 | const scalar_t h, 91 | const scalar_t w, 92 | scalar_t* filter_gradient) { 93 | 94 | if (h > -1 && w > -1 && h < height && w < width) { 95 | int h_low = floor(h); 96 | int w_low = floor(w); 97 | 98 | int h_high = h_low + 1; 99 | int w_high = w_low + 1; 100 | 101 | const scalar_t lh = h - h_low; 102 | const scalar_t lw = w - w_low; 103 | const scalar_t hh = 1 - lh, hw = 1 - lw; 104 | 105 | if (h_low >= 0 && w_low >= 0) 106 | atomicAdd(filter_gradient + h_low * width + w_low, hh * hw * partial_sum); 107 | if (h_low >=0 && w_high <= width - 1) 108 | atomicAdd(filter_gradient + h_low * width + w_high, hh * lw * partial_sum); 109 | if (h_high <= height - 1 && w_low >= 0) 110 | atomicAdd(filter_gradient + h_high * width + w_low, lh * hw * partial_sum); 111 | if (h_high <= height - 1 && w_high <= width - 1) 112 | atomicAdd(filter_gradient + h_high * width + w_high, lh * lw * partial_sum); 113 | } 114 | } 115 | 116 | template 117 | __device__ scalar_t planar_bilinear_backward_coord( 118 | const scalar_t partial_sum, 119 | const scalar_t* filter, 120 | const int height, 121 | const int width, 122 | const scalar_t h, 123 | const scalar_t w, 124 | const int bp_dir) { 125 | 126 | if (h > -1 && w > -1 && h < height && w < width) { 127 | int h_low = floor(h); 128 | int w_low = floor(w); 129 | 130 | int h_high = h_low + 1; 131 | int w_high = w_low + 1; 132 | 133 | const scalar_t lh = h - h_low; 134 | const scalar_t lw = w - w_low; 135 | const scalar_t hh = 1 - lh, hw = 1 - lw; 136 | 137 | if (bp_dir == 0) { 138 | scalar_t gradient_h = 0; 139 | if (h_low >= 0 && w_low >= 0) 140 | gradient_h -= hw * partial_sum * ldg(filter + h_low * width + w_low); 141 | if (h_low >=0 && w_high <= width - 1) 142 | gradient_h -= lw * partial_sum * ldg(filter + h_low * width + w_high); 143 | if (h_high <= height - 1 && w_low >= 0) 144 | gradient_h += hw * partial_sum * ldg(filter + h_high * width + w_low); 145 | if (h_high <= height - 1 && w_high <= width - 1) 146 | gradient_h += lw * partial_sum * ldg(filter + h_high * width + w_high); 147 | return gradient_h; 148 | } else { 149 | scalar_t gradient_w = 0; 150 | if (h_low >= 0 && w_low >= 0) 151 | gradient_w -= hh * partial_sum * ldg(filter + h_low * width + w_low); 152 | if (h_low >=0 && w_high <= width - 1) 153 | gradient_w += hh * partial_sum * ldg(filter + 
h_low * width + w_high); 154 | if (h_high <= height - 1 && w_low >= 0) 155 | gradient_w -= lh * partial_sum * ldg(filter + h_high * width + w_low); 156 | if (h_high <= height - 1 && w_high <= width - 1) 157 | gradient_w += lh * partial_sum * ldg(filter + h_high * width + w_high); 158 | return gradient_w; 159 | } 160 | } else { 161 | return 0; 162 | } 163 | } 164 | 165 | template 166 | __device__ scalar_t deformable_im2col_bilinear( 167 | const scalar_t *bottom_data, 168 | const int height, 169 | const int width, 170 | scalar_t h, 171 | scalar_t w) { 172 | 173 | int h_low = floor(h); 174 | int w_low = floor(w); 175 | int h_high = h_low + 1; 176 | int w_high = w_low + 1; 177 | 178 | scalar_t lh = h - h_low; 179 | scalar_t lw = w - w_low; 180 | scalar_t hh = 1 - lh, hw = 1 - lw; 181 | 182 | scalar_t v1 = 0; 183 | if (h_low >= 0 && w_low >= 0) 184 | v1 = ldg(bottom_data + h_low * width + w_low); 185 | scalar_t v2 = 0; 186 | if (h_low >=0 && w_high <= width - 1) 187 | v2 = ldg(bottom_data + h_low * width + w_high); 188 | scalar_t v3 = 0; 189 | if (h_high <= height - 1 && w_low >= 0) 190 | v3 = ldg(bottom_data + h_high * width + w_low); 191 | scalar_t v4 = 0; 192 | if (h_high <= height - 1 && w_high <= width - 1) 193 | v4 = ldg(bottom_data + h_high * width + w_high); 194 | 195 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 196 | 197 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 198 | return val; 199 | } 200 | 201 | template 202 | __device__ void deformable_im2col_bilinear_backward( 203 | const scalar_t partial_sum, 204 | const scalar_t h, 205 | const scalar_t w, 206 | const int height, 207 | const int width, 208 | scalar_t* data_gradient) { 209 | 210 | int h_low = floor(h); 211 | int w_low = floor(w); 212 | int h_high = h_low + 1; 213 | int w_high = w_low + 1; 214 | 215 | scalar_t lh = h - h_low; 216 | scalar_t lw = w - w_low; 217 | scalar_t hh = 1 - lh, hw = 1 - lw; 218 | 219 | if (h_low >= 0 && w_low >= 0) 220 | atomicAdd(data_gradient + h_low * width + w_low, hh * hw * partial_sum); 221 | if (h_low >=0 && w_high <= width - 1) 222 | atomicAdd(data_gradient + h_low * width + w_high, hh * lw * partial_sum); 223 | if (h_high <= height - 1 && w_low >= 0) 224 | atomicAdd(data_gradient + h_high * width + w_low, lh * hw * partial_sum); 225 | if (h_high <= height - 1 && w_high <= width - 1) 226 | atomicAdd(data_gradient + h_high * width + w_high, lh * lw * partial_sum); 227 | 228 | return; 229 | } 230 | 231 | template 232 | __device__ scalar_t get_coordinate_weight( 233 | scalar_t argmax_h, 234 | scalar_t argmax_w, 235 | const int height, 236 | const int width, 237 | const scalar_t *im_data, 238 | const int bp_dir) { 239 | 240 | int argmax_h_low = floor(argmax_h); 241 | int argmax_w_low = floor(argmax_w); 242 | int argmax_h_high = argmax_h_low + 1; 243 | int argmax_w_high = argmax_w_low + 1; 244 | 245 | scalar_t weight = 0; 246 | 247 | if (bp_dir == 0) { 248 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 249 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * width + argmax_w_low]; 250 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 251 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * width + argmax_w_high]; 252 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 253 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * width + argmax_w_low]; 254 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 255 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * width + argmax_w_high]; 256 | } 
else if (bp_dir == 1) { 257 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 258 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * width + argmax_w_low]; 259 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 260 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * width + argmax_w_high]; 261 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 262 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * width + argmax_w_low]; 263 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 264 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * width + argmax_w_high]; 265 | } 266 | 267 | return weight; 268 | } 269 | 270 | template 271 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dForwardKernel( 272 | int n, 273 | const scalar_t* input, 274 | const scalar_t* rotation_ratio, 275 | const scalar_t* filter, 276 | const SampleDepthwiseArgs args, 277 | scalar_t* output) { 278 | 279 | const int channel = args.channel; 280 | const int in_height = args.in_height; 281 | const int in_width = args.in_width; 282 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; 283 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; 284 | const int stride_height = args.stride_height; 285 | const int stride_width = args.stride_width; 286 | const int pad_height = args.pad_height; 287 | const int pad_width = args.pad_width; 288 | const int dilation_height = args.dilation_height; 289 | const int dilation_width = args.dilation_width; 290 | const int out_height = args.out_height; 291 | const int out_width = args.out_width; 292 | 293 | const int scope_height = args.scope_height; 294 | const int scope_width = args.scope_width; 295 | const int sampling_group = args.sampling_group; 296 | 297 | CUDA_KERNEL_LOOP(thread_id, n) { 298 | const int out_w = thread_id % out_width; 299 | const int out_h = (thread_id / out_width) % out_height; 300 | const int out_c = (thread_id / out_width / out_height) % channel; 301 | const int out_b = thread_id / out_width / out_height / channel; 302 | const int in_c = out_c; 303 | 304 | const int input_offset_temp = 305 | (out_b * channel + in_c) * (in_height * in_width); 306 | 307 | const int group_id = in_c % sampling_group; 308 | const int rotation_offset_temp = 309 | (out_b * sampling_group + group_id) * (filter_height * filter_width * 2) * 310 | out_height * out_width + (out_h * out_width + out_w); 311 | const int filter_offset_temp = in_c * scope_height * scope_width; 312 | 313 | // Finally, we can iterate over the spatial dimensions and perform the 314 | // convolution, writing into the output at the end. 315 | const int input_h_start = out_h * stride_height - pad_height; 316 | const int input_w_start = out_w * stride_width - pad_width; 317 | const int input_h_end = input_h_start + (filter_height - 1) * dilation_height; 318 | const int input_w_end = input_w_start + (filter_width - 1) * dilation_width; 319 | 320 | scalar_t sum = 0; 321 | if (input_h_start >= 0 && input_w_start >= 0 && 322 | input_h_end < in_height && input_w_end < in_width) { 323 | // Loop that doesn't need to check for boundary conditions. 
324 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 325 | const int in_h = input_h_start + f_h * dilation_height; 326 | 327 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 328 | const int in_w = input_w_start + f_w * dilation_width; 329 | const int input_offset = (input_offset_temp) + (in_h * in_width) + in_w; 330 | 331 | const int rotation_offset_fhw = rotation_offset_temp + 332 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 333 | const scalar_t rotation_ratio_h = 334 | ldg(rotation_ratio + rotation_offset_fhw); 335 | const scalar_t rotation_ratio_w = 336 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width); 337 | 338 | const scalar_t filter_h = f_h + rotation_ratio_h + 339 | (scope_height - filter_height) / 2.0; 340 | const scalar_t filter_w = f_w + rotation_ratio_w + 341 | (scope_width - filter_width) / 2.0; 342 | sum += ldg(input + input_offset) * planar_bilinear( 343 | filter + filter_offset_temp, 344 | scope_height, 345 | scope_width, 346 | filter_h, 347 | filter_w); 348 | } 349 | } 350 | } else { 351 | // Loop that needs to check for boundary conditions. 352 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 353 | const int in_h = input_h_start + f_h * dilation_height; 354 | 355 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 356 | const int in_w = input_w_start + f_w * dilation_width; 357 | 358 | // NOTE(Hang Gao @ 07/25): how much runtime will it save? 359 | if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) { 360 | const int input_offset = input_offset_temp + (in_h * in_width) + in_w; 361 | 362 | const int rotation_offset_fhw = rotation_offset_temp + 363 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 364 | const scalar_t rotation_ratio_h = 365 | ldg(rotation_ratio + rotation_offset_fhw); 366 | const scalar_t rotation_ratio_w = 367 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width); 368 | 369 | const scalar_t filter_h = f_h + rotation_ratio_h + 370 | (scope_height - filter_height) / 2.0; 371 | const scalar_t filter_w = f_w + rotation_ratio_w + 372 | (scope_width - filter_width) / 2.0; 373 | sum += ldg(input + input_offset) * planar_bilinear( 374 | filter + filter_offset_temp, 375 | scope_height, 376 | scope_width, 377 | filter_h, 378 | filter_w); 379 | } 380 | } 381 | 382 | } 383 | } 384 | output[thread_id] = sum; 385 | } 386 | } 387 | 388 | void SampleDepthwiseConv2dForward( 389 | const at::Tensor input, 390 | const at::Tensor rotation_ratio, 391 | const at::Tensor filter, 392 | const SampleDepthwiseArgs args, 393 | at::Tensor output) { 394 | 395 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width; 396 | 397 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 398 | input.type(), "SampleDepthwiseConv2dForward_GPU", ([&] { 399 | const scalar_t *input_ = input.data(); 400 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 401 | const scalar_t *filter_ = filter.data(); 402 | scalar_t *output_ = output.data(); 403 | 404 | 405 | if (args.filter_height == 3 && args.filter_width == 3) { 406 | SampleDepthwiseConv2dForwardKernel 407 | <<>>( 408 | num_kernels, 409 | input_, 410 | rotation_ratio_, 411 | filter_, 412 | args, 413 | output_); 414 | } else { 415 | SampleDepthwiseConv2dForwardKernel 416 | <<>>( 417 | num_kernels, 418 | input_, 419 | rotation_ratio_, 420 | filter_, 421 | args, 422 | output_); 423 | } 424 | 425 | })); 426 | 427 | cudaError_t err = cudaGetLastError(); 428 | if (err != cudaSuccess) { 429 | printf("error in 
SampleDepthwiseConv2dForwardKernel: %s\n", cudaGetErrorString(err)); 430 | } 431 | } 432 | 433 | template 434 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dForwardKernel( 435 | int n, 436 | const scalar_t* input, 437 | const scalar_t* offset, 438 | const scalar_t* rotation_ratio, 439 | const scalar_t* filter, 440 | const SampleDepthwiseArgs args, 441 | scalar_t* output) { 442 | 443 | const int channel = args.channel; 444 | const int in_height = args.in_height; 445 | const int in_width = args.in_width; 446 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; 447 | const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; 448 | const int stride_height = args.stride_height; 449 | const int stride_width = args.stride_width; 450 | const int pad_height = args.pad_height; 451 | const int pad_width = args.pad_width; 452 | const int dilation_height = args.dilation_height; 453 | const int dilation_width = args.dilation_width; 454 | const int out_height = args.out_height; 455 | const int out_width = args.out_width; 456 | 457 | const int scope_height = args.scope_height; 458 | const int scope_width = args.scope_width; 459 | const int sampling_group = args.sampling_group; 460 | 461 | CUDA_KERNEL_LOOP(thread_id, n) { 462 | const int out_w = thread_id % out_width; 463 | const int out_h = (thread_id / out_width) % out_height; 464 | const int out_c = (thread_id / out_width / out_height) % channel; 465 | const int out_b = thread_id / out_width / out_height / channel; 466 | const int in_c = out_c; 467 | 468 | const int input_offset_temp = 469 | (out_b * channel + in_c) * (in_height * in_width); 470 | const int deformation_offset_temp = 471 | out_b * (filter_height * filter_width * 2) * out_height * out_width + 472 | (out_h * out_width + out_w); 473 | const int group_id = in_c % sampling_group; 474 | const int rotation_offset_temp = (out_b * sampling_group + group_id) * 475 | (filter_height * filter_width * 2) * out_height * out_width + 476 | (out_h * out_width + out_w); 477 | const int filter_offset_temp = in_c * scope_height * scope_width; 478 | 479 | // Finally, we can iterate over the spatial dimensions and perform the 480 | // convolution, writing into the output at the end. 
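// Editor's note (not in the original source): the index arithmetic above implies the
// following layouts. `offset` is addressed as
//   [batch][filter_height * filter_width][2][out_height][out_width],
// with the h-displacement plane of each tap stored immediately before its
// w-displacement plane. `rotation_ratio` is addressed the same way but with a leading
// [batch][sampling_group] pair, and channels are assigned to rotation groups
// round-robin via group_id = in_c % sampling_group. Unlike the plain
// SampleDepthwiseConv2dForwardKernel above, the input here is read at the fractional,
// offset-shifted location via deformable_im2col_bilinear, and taps falling outside
// (-1, in_height) x (-1, in_width) contribute nothing.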
481 | const int input_h_start = out_h * stride_height - pad_height; 482 | const int input_w_start = out_w * stride_width - pad_width; 483 | 484 | scalar_t sum = 0; 485 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 486 | const int in_h = input_h_start + f_h * dilation_height; 487 | 488 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 489 | const int in_w = input_w_start + f_w * dilation_width; 490 | const int deformation_offset_fhw = deformation_offset_temp + 491 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 492 | const int rotation_offset_fhw = rotation_offset_temp + 493 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 494 | 495 | const scalar_t input_h = in_h + 496 | ldg(offset + deformation_offset_fhw); 497 | const scalar_t input_w = in_w + 498 | ldg(offset + deformation_offset_fhw + out_height * out_width); 499 | 500 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) { 501 | const scalar_t rotation_ratio_h = 502 | ldg(rotation_ratio + rotation_offset_fhw); 503 | const scalar_t rotation_ratio_w = 504 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width); 505 | 506 | const scalar_t cur_input = deformable_im2col_bilinear( 507 | input + input_offset_temp, 508 | in_height, 509 | in_width, 510 | input_h, 511 | input_w); 512 | 513 | const scalar_t filter_h = f_h + rotation_ratio_h + 514 | (scope_height - filter_height) / 2.0; 515 | const scalar_t filter_w = f_w + rotation_ratio_w + 516 | (scope_width - filter_width) / 2.0; 517 | sum += cur_input * planar_bilinear( 518 | filter + filter_offset_temp, 519 | scope_height, 520 | scope_width, 521 | filter_h, 522 | filter_w); 523 | } 524 | } 525 | } 526 | output[thread_id] = sum; 527 | } 528 | } 529 | 530 | void DeformableSampleDepthwiseConv2dForward( 531 | const at::Tensor input, 532 | const at::Tensor offset, 533 | const at::Tensor rotation_ratio, 534 | const at::Tensor filter, 535 | const SampleDepthwiseArgs args, 536 | at::Tensor output) { 537 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width; 538 | 539 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 540 | input.type(), "DeformableSampleDepthwiseConv2dForward_GPU", ([&] { 541 | const scalar_t *input_ = input.data(); 542 | const scalar_t *offset_ = offset.data(); 543 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 544 | const scalar_t *filter_ = filter.data(); 545 | scalar_t *output_ = output.data(); 546 | 547 | if (args.filter_height == 3 && args.filter_width == 3) { 548 | DeformableSampleDepthwiseConv2dForwardKernel 549 | <<>>( 550 | num_kernels, 551 | input_, 552 | offset_, 553 | rotation_ratio_, 554 | filter_, 555 | args, 556 | output_); 557 | } else { 558 | DeformableSampleDepthwiseConv2dForwardKernel 559 | <<>>( 560 | num_kernels, 561 | input_, 562 | offset_, 563 | rotation_ratio_, 564 | filter_, 565 | args, 566 | output_); 567 | } 568 | 569 | })); 570 | 571 | cudaError_t err = cudaGetLastError(); 572 | if (err != cudaSuccess) { 573 | printf("error in DeformableSampleDepthwiseConv2dForwardKernel: %s\n", cudaGetErrorString(err)); 574 | } 575 | } 576 | 577 | template 578 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardDataKernel( 579 | int n, 580 | const scalar_t* out_grad, 581 | const scalar_t* rotation_ratio, 582 | const scalar_t* filter, 583 | const SampleDepthwiseArgs args, 584 | scalar_t* in_grad) { 585 | 586 | const int channel = args.channel; 587 | const int in_height = args.in_height; 588 | const int in_width = args.in_width; 589 | 
const int filter_height = args.filter_height; 590 | const int filter_width = args.filter_width; 591 | const int stride_height = args.stride_height; 592 | const int stride_width = args.stride_width; 593 | const int pad_height = args.pad_height; 594 | const int pad_width = args.pad_width; 595 | const int dilation_height = args.dilation_height; 596 | const int dilation_width = args.dilation_width; 597 | const int out_height = args.out_height; 598 | const int out_width = args.out_width; 599 | 600 | const int scope_height = args.scope_height; 601 | const int scope_width = args.scope_width; 602 | const int sampling_group = args.sampling_group; 603 | 604 | CUDA_KERNEL_LOOP(thread_id, n) { 605 | // Compute the indexes of this thread in the input. 606 | const int in_w = thread_id % in_width; 607 | const int in_h = (thread_id / in_width) % in_height; 608 | const int channel_idx = (thread_id / in_width / in_height) % channel; 609 | const int batch_idx = thread_id / channel / in_width / in_height; 610 | 611 | const int out_h_start = CudaMax( 612 | 0, (in_h + pad_height - (filter_height - 1) * dilation_height + 613 | stride_height - 1) / stride_height); 614 | const int out_h_end = CudaMin( 615 | out_height - 1, (in_h + pad_height) / stride_height); 616 | const int out_w_start = CudaMax( 617 | 0, (in_w + pad_width - (filter_width - 1) * dilation_width + 618 | stride_width - 1) / stride_width); 619 | const int out_w_end = CudaMin( 620 | out_width - 1, (in_w + pad_width) / stride_width); 621 | 622 | const int group_id = channel_idx % sampling_group; 623 | const int rotation_offset_temp = 624 | (batch_idx * sampling_group + group_id) * 625 | (filter_height * filter_width * 2) * out_height * out_width; 626 | const int filter_offset_temp = channel_idx * scope_height * scope_width; 627 | const int out_grad_offset_temp = 628 | (batch_idx * channel + channel_idx) * (out_height * out_width); 629 | 630 | scalar_t sum = 0.0f; 631 | for (int out_h = out_h_start; out_h <= out_h_end; ++out_h) { 632 | int f_h = in_h + pad_height - out_h * stride_height; 633 | 634 | if (f_h % dilation_height == 0) { 635 | f_h /= dilation_height; 636 | const int out_grad_offset_h = out_grad_offset_temp + out_h * out_width; 637 | 638 | for (int out_w = out_w_start; out_w <= out_w_end; ++out_w) { 639 | int f_w = in_w + pad_width - out_w * stride_width; 640 | 641 | if (f_w % dilation_width == 0) { 642 | f_w /= dilation_width; 643 | const int out_grad_offset = out_grad_offset_h + out_w; 644 | 645 | const int rotation_offset_fhw = rotation_offset_temp + 646 | (f_h * filter_width + f_w) * 2 * out_height * out_width + 647 | (out_h * out_width + out_w); 648 | const scalar_t rotation_ratio_h = 649 | ldg(rotation_ratio + rotation_offset_fhw); 650 | const scalar_t rotation_ratio_w = 651 | ldg(rotation_ratio + rotation_offset_fhw + out_height * out_width); 652 | 653 | const scalar_t filter_h = f_h + rotation_ratio_h + 654 | (scope_height - filter_height) / 2.0; 655 | const scalar_t filter_w = f_w + rotation_ratio_w + 656 | (scope_width - filter_width) / 2.0; 657 | sum += ldg(out_grad + out_grad_offset) * planar_bilinear( 658 | filter + filter_offset_temp, 659 | scope_height, 660 | scope_width, 661 | filter_h, 662 | filter_w); 663 | } 664 | } 665 | } 666 | } 667 | in_grad[thread_id] = sum; 668 | } 669 | } 670 | 671 | void SampleDepthwiseConv2dBackwardData( 672 | const at::Tensor out_grad, 673 | const at::Tensor rotation_ratio, 674 | const at::Tensor filter, 675 | const SampleDepthwiseArgs args, 676 | at::Tensor in_grad) { 677 | 678 | int num_kernels 
= args.batch * args.channel * args.in_height * args.in_width; 679 | 680 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 681 | out_grad.type(), "SampleDepthwiseConv2dBackwardData_GPU", ([&] { 682 | const scalar_t *out_grad_ = out_grad.data(); 683 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 684 | const scalar_t *filter_ = filter.data(); 685 | scalar_t *in_grad_ = in_grad.data(); 686 | 687 | SampleDepthwiseConv2dBackwardDataKernel 688 | <<>>( 689 | num_kernels, 690 | out_grad_, 691 | rotation_ratio_, 692 | filter_, 693 | args, 694 | in_grad_); 695 | 696 | })); 697 | 698 | cudaError_t err = cudaGetLastError(); 699 | if (err != cudaSuccess) { 700 | printf("error in SampleDepthwiseConv2dBackwardDataKernel: %s\n", cudaGetErrorString(err)); 701 | } 702 | } 703 | 704 | template 705 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardDataKernel( 706 | int n, 707 | const scalar_t* out_backprop, 708 | const scalar_t* offset, 709 | const scalar_t* rotation_ratio, 710 | const scalar_t* filter, 711 | const SampleDepthwiseArgs args, 712 | scalar_t* in_grad) { 713 | 714 | const int channel = args.channel; 715 | const int in_height = args.in_height; 716 | const int in_width = args.in_width; 717 | const int filter_height = args.filter_height; 718 | const int filter_width = args.filter_width; 719 | const int stride_height = args.stride_height; 720 | const int stride_width = args.stride_width; 721 | const int pad_height = args.pad_height; 722 | const int pad_width = args.pad_width; 723 | const int dilation_height = args.dilation_height; 724 | const int dilation_width = args.dilation_width; 725 | const int out_height = args.out_height; 726 | const int out_width = args.out_width; 727 | 728 | const int scope_height = args.scope_height; 729 | const int scope_width = args.scope_width; 730 | const int sampling_group = args.sampling_group; 731 | 732 | CUDA_KERNEL_LOOP(thread_id, n) { 733 | // Compute the indexes of this thread in the output. 734 | const int out_w = thread_id % out_width; 735 | const int out_h = (thread_id / out_width) % out_height; 736 | const int in_c = (thread_id / out_width / out_height) % channel; 737 | // NOTE(Hang Gao @ 07/26): why feed data like this? -- because 738 | const int f_w = (thread_id / out_width / out_height / channel) % filter_width; 739 | const int f_h = (thread_id / out_width / out_height / channel / filter_width) % filter_height; 740 | const int out_b = (thread_id / out_width / out_height / channel / filter_width) / filter_height; 741 | 742 | // Decide if all input is valid, if yes, we can skip the boundary checks 743 | // for each input. 744 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height; 745 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width; 746 | 747 | const int deformable_offset_temp = 748 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 * 749 | out_height * out_width + (out_h * out_width + out_w); 750 | const int group_id = in_c % sampling_group; 751 | const int rotation_offset_temp = 752 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) + 753 | (f_h * filter_width + f_w)) * 2 * out_height * out_width + 754 | (out_h * out_width + out_w); 755 | const scalar_t input_h = in_row + ldg(offset + deformable_offset_temp); 756 | const scalar_t input_w = in_col + ldg( 757 | offset + deformable_offset_temp + out_height * out_width); 758 | 759 | // Avoid repeated computation. 
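// Editor's note (not in the original source): unlike SampleDepthwiseConv2dBackwardDataKernel
// above, which gathers (one thread per input pixel), this deformable variant launches one
// thread per (batch, f_h, f_w, channel, out_h, out_w) tap -- see num_kernels in the wrapper
// below -- and scatters its contribution into in_grad through
// deformable_im2col_bilinear_backward, which spreads partial_sum over the four input pixels
// surrounding the fractional (input_h, input_w). Because several taps can land on the same
// input pixel, in_grad must be zero-filled before launch (the Python wrapper uses
// torch.zeros_like) and the scatter presumably relies on atomicAdd.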
760 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) { 761 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width); 762 | const int filter_offset_temp = in_c * scope_height * scope_width; 763 | const scalar_t out_bp = ldg( 764 | out_backprop + 765 | (out_b * channel + in_c) * (out_height * out_width) + 766 | (out_h * out_width + out_w)); 767 | 768 | const scalar_t rotation_ratio_h = ldg( 769 | rotation_ratio + rotation_offset_temp); 770 | const scalar_t rotation_ratio_w = ldg( 771 | rotation_ratio + rotation_offset_temp + out_height * out_width); 772 | 773 | scalar_t cur_weight = 0; 774 | const scalar_t filter_h = f_h + rotation_ratio_h + 775 | (scope_height - filter_height) / 2.0; 776 | const scalar_t filter_w = f_w + rotation_ratio_w + 777 | (scope_width - filter_width) / 2.0; 778 | cur_weight = planar_bilinear( 779 | filter + filter_offset_temp, 780 | scope_height, 781 | scope_width, 782 | filter_h, 783 | filter_w); 784 | 785 | const scalar_t partial_sum = cur_weight * out_bp; 786 | deformable_im2col_bilinear_backward( 787 | partial_sum, 788 | input_h, 789 | input_w, 790 | in_height, 791 | in_width, 792 | in_grad + input_offset_temp); 793 | } 794 | } 795 | } 796 | 797 | void DeformableSampleDepthwiseConv2dBackwardData( 798 | const at::Tensor out_grad, 799 | const at::Tensor offset, 800 | const at::Tensor rotation_ratio, 801 | const at::Tensor filter, 802 | const SampleDepthwiseArgs args, 803 | at::Tensor in_grad) { 804 | 805 | int num_kernels = args.batch * args.filter_height * args.filter_width * 806 | args.channel * args.out_height * args.out_width; 807 | 808 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 809 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardData_GPU", ([&] { 810 | const scalar_t *out_grad_ = out_grad.data(); 811 | const scalar_t *offset_ = offset.data(); 812 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 813 | const scalar_t *filter_ = filter.data(); 814 | scalar_t *in_grad_ = in_grad.data(); 815 | 816 | DeformableSampleDepthwiseConv2dBackwardDataKernel 817 | <<>>( 818 | num_kernels, 819 | out_grad_, 820 | offset_, 821 | rotation_ratio_, 822 | filter_, 823 | args, 824 | in_grad_); 825 | 826 | })); 827 | 828 | cudaError_t err = cudaGetLastError(); 829 | if (err != cudaSuccess) { 830 | printf( 831 | "error in DeformableSampleDepthwiseConv2dBackwardDataKernel: %s\n", 832 | cudaGetErrorString(err)); 833 | } 834 | } 835 | 836 | template 837 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardFilterKernel( 838 | int n, 839 | const scalar_t* out_backprop, 840 | const scalar_t* input, 841 | const scalar_t* rotation_ratio, 842 | const SampleDepthwiseArgs args, 843 | scalar_t* filter_backprop) { 844 | 845 | const int channel = args.channel; 846 | const int in_height = args.in_height; 847 | const int in_width = args.in_width; 848 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; 849 | const int filter_width = kFilterWidth > 0 ? 
kFilterWidth : args.filter_width; 850 | const int stride_height = args.stride_height; 851 | const int stride_width = args.stride_width; 852 | const int pad_height = args.pad_height; 853 | const int pad_width = args.pad_width; 854 | const int dilation_height = args.dilation_height; 855 | const int dilation_width = args.dilation_width; 856 | const int out_height = args.out_height; 857 | const int out_width = args.out_width; 858 | 859 | const int scope_height = args.scope_height; 860 | const int scope_width = args.scope_width; 861 | const int sampling_group = args.sampling_group; 862 | 863 | CUDA_KERNEL_LOOP(thread_id, n) { 864 | // Compute the indexes of this thread in the output. 865 | const int out_w = thread_id % out_width; 866 | const int out_h = (thread_id / out_width) % out_height; 867 | const int out_c = (thread_id / out_width / out_height) % channel; 868 | const int out_b = thread_id / out_width / out_height / channel; 869 | const int in_c = out_c; 870 | 871 | // Decide if all input is valid, if yes, we can skip the boundary checks 872 | // for each input. 873 | const int in_row_start = out_h * stride_height - pad_height; 874 | const int in_col_start = out_w * stride_width - pad_width; 875 | const int in_row_end = in_row_start + (filter_height - 1) * dilation_height; 876 | const int in_col_end = in_col_start + (filter_width - 1) * dilation_width; 877 | 878 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width); 879 | const int group_id = in_c % sampling_group; 880 | const int rotation_offset_temp = (out_b * sampling_group + group_id) * 881 | (filter_height * filter_width * 2) * out_height * out_width + 882 | (out_h * out_width + out_w); 883 | const int filter_offset_temp = in_c * scope_height * scope_width; 884 | 885 | const scalar_t out_bp = ldg(out_backprop + thread_id); 886 | if (in_row_start >= 0 && in_col_start >= 0 && 887 | in_row_end < in_height && in_col_end < in_width) { 888 | 889 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 890 | const int in_row = in_row_start + f_h * dilation_height; 891 | // Avoid repeated computation. 892 | const int input_offset_local = input_offset_temp + in_row * in_width; 893 | 894 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 895 | const int in_col = in_col_start + f_w * dilation_width; 896 | const int input_offset = input_offset_local + in_col; 897 | 898 | const int rotation_offset_fhw = 899 | rotation_offset_temp + (f_h * filter_width + f_w) * 2 * out_height * out_width; 900 | const scalar_t rotation_ratio_h = ldg( 901 | rotation_ratio + rotation_offset_fhw); 902 | const scalar_t rotation_ratio_w = ldg( 903 | rotation_ratio + rotation_offset_fhw + out_height * out_width); 904 | 905 | scalar_t partial_sum = ldg(input + input_offset) * out_bp; 906 | const scalar_t filter_h = f_h + rotation_ratio_h + 907 | (scope_height - filter_height) / 2.0; 908 | const scalar_t filter_w = f_w + rotation_ratio_w + 909 | (scope_width - filter_width) / 2.0; 910 | planar_bilinear_backward_data( 911 | partial_sum, 912 | scope_height, 913 | scope_width, 914 | filter_h, 915 | filter_w, 916 | filter_backprop + filter_offset_temp); 917 | } 918 | } 919 | } else { 920 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 921 | const int in_row = in_row_start + f_h * dilation_height; 922 | // Avoid repeated computation. 
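// Editor's note (not in the original source): planar_bilinear_backward_data is the
// adjoint of the planar_bilinear call used in the forward pass. It takes
// partial_sum = input * out_grad and accumulates it into the (up to) four integer filter
// positions surrounding the fractional (filter_h, filter_w) of this channel's
// scope_height x scope_width plane, presumably with the standard bilinear weights,
// e.g. (1 - dh) * (1 - dw) for the low/low corner where dh = filter_h - floor(filter_h).
// filter_backprop is only ever accumulated into, which is why the Python wrapper
// pre-zeroes it with torch.zeros_like(weight).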
923 | const int input_offset_local = input_offset_temp + in_row * in_width; 924 | 925 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 926 | const int in_col = in_col_start + f_w * dilation_width;; 927 | 928 | if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) { 929 | const int input_offset = input_offset_local + in_col; 930 | 931 | const int rotation_offset_fhw = rotation_offset_temp + 932 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 933 | const scalar_t rotation_ratio_h = ldg( 934 | rotation_ratio + rotation_offset_fhw); 935 | const scalar_t rotation_ratio_w = ldg( 936 | rotation_ratio + rotation_offset_fhw + out_height * out_width); 937 | 938 | scalar_t partial_sum = ldg(input + input_offset) * out_bp; 939 | const scalar_t filter_h = f_h + rotation_ratio_h + 940 | (scope_height - filter_height) / 2.0; 941 | const scalar_t filter_w = f_w + rotation_ratio_w + 942 | (scope_width - filter_width) / 2.0; 943 | planar_bilinear_backward_data( 944 | partial_sum, 945 | scope_height, 946 | scope_width, 947 | filter_h, 948 | filter_w, 949 | filter_backprop + filter_offset_temp); 950 | } 951 | } 952 | } 953 | } 954 | } 955 | } 956 | 957 | void SampleDepthwiseConv2dBackwardFilter( 958 | const at::Tensor out_grad, 959 | const at::Tensor input, 960 | const at::Tensor rotation_ratio, 961 | const SampleDepthwiseArgs args, 962 | at::Tensor filter_grad) { 963 | 964 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width; 965 | 966 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 967 | out_grad.type(), "SampleDepthwiseConv2dBackwardFilter_GPU", ([&] { 968 | const scalar_t *out_grad_ = out_grad.data(); 969 | const scalar_t *input_ = input.data(); 970 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 971 | scalar_t *filter_grad_ = filter_grad.data(); 972 | 973 | if (args.filter_height == 3 && args.filter_width == 3) { 974 | SampleDepthwiseConv2dBackwardFilterKernel 975 | <<>>( 976 | num_kernels, 977 | out_grad_, 978 | input_, 979 | rotation_ratio_, 980 | args, 981 | filter_grad_); 982 | } else { 983 | SampleDepthwiseConv2dBackwardFilterKernel 984 | <<>>( 985 | num_kernels, 986 | out_grad_, 987 | input_, 988 | rotation_ratio_, 989 | args, 990 | filter_grad_); 991 | } 992 | 993 | })); 994 | 995 | cudaError_t err = cudaGetLastError(); 996 | if (err != cudaSuccess) { 997 | printf("error in SampleDepthwiseConv2dBackwardFilterKernel: %s\n", cudaGetErrorString(err)); 998 | } 999 | } 1000 | 1001 | // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. 1002 | template 1003 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardFilterKernel( 1004 | int n, 1005 | const scalar_t* out_backprop, 1006 | const scalar_t* input, 1007 | const scalar_t* offset, 1008 | const scalar_t* rotation_ratio, 1009 | const SampleDepthwiseArgs args, 1010 | scalar_t* filter_backprop) { 1011 | 1012 | const int channel = args.channel; 1013 | const int in_height = args.in_height; 1014 | const int in_width = args.in_width; 1015 | const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; 1016 | const int filter_width = kFilterWidth > 0 ? 
kFilterWidth : args.filter_width; 1017 | const int stride_height = args.stride_height; 1018 | const int stride_width = args.stride_width; 1019 | const int pad_height = args.pad_height; 1020 | const int pad_width = args.pad_width; 1021 | const int dilation_height = args.dilation_height; 1022 | const int dilation_width = args.dilation_width; 1023 | const int out_height = args.out_height; 1024 | const int out_width = args.out_width; 1025 | 1026 | const int scope_height = args.scope_height; 1027 | const int scope_width = args.scope_width; 1028 | const int sampling_group = args.sampling_group; 1029 | 1030 | CUDA_KERNEL_LOOP(thread_id, n) { 1031 | // Compute the indexes of this thread in the output. 1032 | const int out_w = thread_id % out_width; 1033 | const int out_h = (thread_id / out_width) % out_height; 1034 | const int out_c = (thread_id / out_width / out_height) % channel; 1035 | const int out_b = thread_id / out_width / out_height / channel; 1036 | const int in_c = out_c; 1037 | 1038 | // Decide if all input is valid, if yes, we can skip the boundary checks 1039 | // for each input. 1040 | const int in_row_start = out_h * stride_height - pad_height; 1041 | const int in_col_start = out_w * stride_width - pad_width; 1042 | 1043 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width); 1044 | const int deformation_offset_temp = out_b * (filter_height * filter_width * 2) * 1045 | out_height * out_width + (out_h * out_width + out_w); 1046 | const int group_id = in_c % sampling_group; 1047 | const int rotation_offset_temp = (out_b * sampling_group + group_id) * 1048 | (filter_height * filter_width * 2) * out_height * out_width + 1049 | (out_h * out_width + out_w); 1050 | const int filter_offset_temp = in_c * scope_height * scope_width; 1051 | 1052 | const scalar_t out_bp = ldg(out_backprop + thread_id); 1053 | 1054 | CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { 1055 | const int in_row = in_row_start + f_h * dilation_height; 1056 | 1057 | // Avoid repeated computation. 
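// Editor's note (not in the original source): this kernel chains the two interpolations.
// deformable_im2col_bilinear gathers the input at the offset-shifted location
// (input_h, input_w), and planar_bilinear_backward_data then scatters input * out_grad
// into the scope plane at the rotation-shifted location (filter_h, filter_w). Apart from
// the per-tap offset lookup and the (-1, size) bounds test, the structure matches the
// non-deformable filter-gradient kernel above.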
1058 | CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { 1059 | const int in_col = in_col_start + f_w * dilation_width; 1060 | 1061 | const int deformation_offset_fhw = deformation_offset_temp + 1062 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 1063 | const int rotation_offset_fhw = rotation_offset_temp + 1064 | (f_h * filter_width + f_w) * 2 * out_height * out_width; 1065 | const scalar_t input_h = in_row + ldg( 1066 | offset + deformation_offset_fhw); 1067 | const scalar_t input_w = in_col + ldg( 1068 | offset + deformation_offset_fhw + out_height * out_width); 1069 | 1070 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) { 1071 | const scalar_t rotation_ratio_h = ldg( 1072 | rotation_ratio + rotation_offset_fhw); 1073 | const scalar_t rotation_ratio_w = ldg( 1074 | rotation_ratio + rotation_offset_fhw + out_height * out_width); 1075 | 1076 | const scalar_t partial_sum = deformable_im2col_bilinear( 1077 | input + input_offset_temp, 1078 | in_height, 1079 | in_width, 1080 | input_h, 1081 | input_w) * out_bp; 1082 | const scalar_t filter_h = f_h + rotation_ratio_h + 1083 | (scope_height - filter_height) / 2.0; 1084 | const scalar_t filter_w = f_w + rotation_ratio_w + 1085 | (scope_width - filter_width) / 2.0; 1086 | planar_bilinear_backward_data( 1087 | partial_sum, 1088 | scope_height, 1089 | scope_width, 1090 | filter_h, 1091 | filter_w, 1092 | filter_backprop + filter_offset_temp); 1093 | } 1094 | } 1095 | } 1096 | } 1097 | } 1098 | 1099 | void DeformableSampleDepthwiseConv2dBackwardFilter( 1100 | const at::Tensor out_grad, 1101 | const at::Tensor input, 1102 | const at::Tensor offset, 1103 | const at::Tensor rotation_ratio, 1104 | const SampleDepthwiseArgs args, 1105 | at::Tensor filter_grad) { 1106 | 1107 | int num_kernels = args.batch * args.channel * args.out_height * args.out_width; 1108 | 1109 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 1110 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardFilter_GPU", ([&] { 1111 | const scalar_t *out_grad_ = out_grad.data(); 1112 | const scalar_t *input_ = input.data(); 1113 | const scalar_t *offset_ = offset.data(); 1114 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 1115 | scalar_t *filter_grad_ = filter_grad.data(); 1116 | 1117 | if (args.filter_height == 3 && args.filter_width == 3) { 1118 | DeformableSampleDepthwiseConv2dBackwardFilterKernel 1119 | <<>>( 1120 | num_kernels, 1121 | out_grad_, 1122 | input_, 1123 | offset_, 1124 | rotation_ratio_, 1125 | args, 1126 | filter_grad_); 1127 | } else { 1128 | DeformableSampleDepthwiseConv2dBackwardFilterKernel 1129 | <<>>( 1130 | num_kernels, 1131 | out_grad_, 1132 | input_, 1133 | offset_, 1134 | rotation_ratio_, 1135 | args, 1136 | filter_grad_); 1137 | } 1138 | 1139 | })); 1140 | 1141 | cudaError_t err = cudaGetLastError(); 1142 | if (err != cudaSuccess) { 1143 | printf("error in DeformableSampleDepthwiseConv2dBackwardFilterKernel: %s\n", cudaGetErrorString(err)); 1144 | } 1145 | } 1146 | 1147 | template 1148 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardOffsetKernel( 1149 | int n, 1150 | const scalar_t* out_backprop, 1151 | const scalar_t* input, 1152 | const scalar_t* offset, 1153 | const scalar_t* rotation_ratio, 1154 | const scalar_t* filter, 1155 | const SampleDepthwiseArgs args, 1156 | scalar_t* offset_backprop) { 1157 | 1158 | const int channel = args.channel; 1159 | const int in_height = args.in_height; 1160 | const int in_width = args.in_width; 1161 | const int filter_height = 
args.filter_height; 1162 | const int filter_width = args.filter_width; 1163 | const int stride_height = args.stride_height; 1164 | const int stride_width = args.stride_width; 1165 | const int pad_height = args.pad_height; 1166 | const int pad_width = args.pad_width; 1167 | const int dilation_height = args.dilation_height; 1168 | const int dilation_width = args.dilation_width; 1169 | const int out_height = args.out_height; 1170 | const int out_width = args.out_width; 1171 | 1172 | const int scope_height = args.scope_height; 1173 | const int scope_width = args.scope_width; 1174 | const int sampling_group = args.sampling_group; 1175 | 1176 | CUDA_KERNEL_LOOP(thread_id, n) { 1177 | // Compute the indexes of this thread in the output. 1178 | const int out_w = thread_id % out_width; 1179 | const int out_h = (thread_id / out_width) % out_height; 1180 | const int bp_dir = (thread_id / out_width / out_height) % 2; 1181 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width; 1182 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height; 1183 | const int out_b = (thread_id / out_width / out_height / 2 / filter_width) / filter_height; 1184 | 1185 | // Decide if all input is valid, if yes, we can skip the boundary checks 1186 | // for each input. 1187 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height; 1188 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width; 1189 | 1190 | const int deformable_offset_temp = 1191 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 * 1192 | out_height * out_width + 1193 | (out_h * out_width + out_w); 1194 | const scalar_t input_h = in_row + ldg( 1195 | offset + deformable_offset_temp); 1196 | const scalar_t input_w = in_col + ldg( 1197 | offset + deformable_offset_temp + out_height * out_width); 1198 | 1199 | scalar_t coord_gradient = 0; 1200 | // Avoid repeated computation. 
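// Editor's note (not in the original source): each thread of this kernel owns a single
// offset component -- bp_dir selects the h- or w-derivative -- for tap (f_h, f_w) at
// output location (out_h, out_w) of sample out_b. Since the spatial offset is shared by
// every depthwise channel, the loop below sums over all channels; for each one,
// get_coordinate_weight (defined earlier in this file) returns the derivative of the
// bilinear input sample with respect to the chosen coordinate, which is then scaled by
// the interpolated filter weight and the incoming output gradient.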
1201 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) { 1202 | 1203 | for (int in_c = 0; in_c < channel; in_c++) { 1204 | const int group_id = in_c % sampling_group; 1205 | const int rotation_offset_temp = ((out_b * sampling_group + group_id) * 1206 | (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 * 1207 | out_height * out_width + (out_h * out_width + out_w); 1208 | const scalar_t rotation_ratio_h = ldg( 1209 | rotation_ratio + rotation_offset_temp); 1210 | const scalar_t rotation_ratio_w = ldg( 1211 | rotation_ratio + rotation_offset_temp + out_height * out_width); 1212 | scalar_t filter_h = f_h + rotation_ratio_h + 1213 | (scope_height - filter_height) / 2.0; 1214 | scalar_t filter_w = f_w + rotation_ratio_w + 1215 | (scope_width - filter_width) / 2.0; 1216 | 1217 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width); 1218 | const int filter_offset_temp = in_c * scope_height * scope_width; 1219 | const scalar_t out_bp = ldg( 1220 | out_backprop + 1221 | (out_b * channel + in_c) * (out_height * out_width) + 1222 | (out_h * out_width + out_w)); 1223 | 1224 | scalar_t cur_weight = planar_bilinear( 1225 | filter + filter_offset_temp, 1226 | scope_height, 1227 | scope_width, 1228 | filter_h, 1229 | filter_w); 1230 | scalar_t partial_sum = cur_weight * out_bp; 1231 | coord_gradient += get_coordinate_weight( 1232 | input_h, 1233 | input_w, 1234 | in_height, 1235 | in_width, 1236 | input + input_offset_temp, 1237 | bp_dir) * partial_sum; 1238 | } 1239 | } 1240 | 1241 | offset_backprop[thread_id] = coord_gradient; 1242 | } 1243 | } 1244 | 1245 | void DeformableSampleDepthwiseConv2dBackwardOffset( 1246 | const at::Tensor out_grad, 1247 | const at::Tensor input, 1248 | const at::Tensor offset, 1249 | const at::Tensor rotation_ratio, 1250 | const at::Tensor filter, 1251 | const SampleDepthwiseArgs args, 1252 | at::Tensor offset_grad) { 1253 | 1254 | int num_kernels = args.batch * args.filter_height * args.filter_width * 2 * 1255 | args.out_height * args.out_width; 1256 | 1257 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 1258 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardOffset_GPU", ([&] { 1259 | const scalar_t *out_grad_ = out_grad.data(); 1260 | const scalar_t *input_ = input.data(); 1261 | const scalar_t *offset_ = offset.data(); 1262 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 1263 | const scalar_t *filter_ = filter.data(); 1264 | scalar_t *offset_grad_ = offset_grad.data(); 1265 | 1266 | DeformableSampleDepthwiseConv2dBackwardOffsetKernel 1267 | <<>>( 1268 | num_kernels, 1269 | out_grad_, 1270 | input_, 1271 | offset_, 1272 | rotation_ratio_, 1273 | filter_, 1274 | args, 1275 | offset_grad_); 1276 | 1277 | })); 1278 | 1279 | cudaError_t err = cudaGetLastError(); 1280 | if (err != cudaSuccess) { 1281 | printf("error in DeformableSampleDepthwiseConv2dBackwardOffsetKernel: %s\n", cudaGetErrorString(err)); 1282 | } 1283 | } 1284 | 1285 | template 1286 | __global__ __launch_bounds__(1024, 2) void SampleDepthwiseConv2dBackwardRotationKernel( 1287 | int n, 1288 | const scalar_t* out_backprop, 1289 | const scalar_t* input, 1290 | const scalar_t* rotation_ratio, 1291 | const scalar_t* filter, 1292 | const SampleDepthwiseArgs args, 1293 | scalar_t* rotation_backprop) { 1294 | 1295 | const int channel = args.channel; 1296 | const int in_height = args.in_height; 1297 | const int in_width = args.in_width; 1298 | const int filter_height = args.filter_height; 1299 | const int filter_width = args.filter_width; 
1300 | const int stride_height = args.stride_height; 1301 | const int stride_width = args.stride_width; 1302 | const int pad_height = args.pad_height; 1303 | const int pad_width = args.pad_width; 1304 | const int dilation_height = args.dilation_height; 1305 | const int dilation_width = args.dilation_width; 1306 | const int out_height = args.out_height; 1307 | const int out_width = args.out_width; 1308 | 1309 | const int scope_height = args.scope_height; 1310 | const int scope_width = args.scope_width; 1311 | const int sampling_group = args.sampling_group; 1312 | 1313 | CUDA_KERNEL_LOOP(thread_id, n) { 1314 | // Compute the indexes of this thread in the output. 1315 | const int out_w = thread_id % out_width; 1316 | const int out_h = (thread_id / out_width) % out_height; 1317 | const int bp_dir = (thread_id / out_width / out_height) % 2; 1318 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width; 1319 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height; 1320 | const int group_id = (thread_id / out_width / out_height / 2 / 1321 | filter_width / filter_height) % sampling_group; 1322 | const int out_b = (thread_id / out_width / out_height / 2 / 1323 | filter_width / filter_height) / sampling_group; 1324 | 1325 | // Decide if all input is valid, if yes, we can skip the boundary checks 1326 | // for each input. 1327 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height; 1328 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width; 1329 | 1330 | const int rotation_offset_temp = 1331 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) + 1332 | (f_h * filter_width + f_w)) * 2 * out_height * out_width + 1333 | (out_h * out_width + out_w); 1334 | const scalar_t rotation_ratio_h = ldg( 1335 | rotation_ratio + rotation_offset_temp); 1336 | const scalar_t rotation_ratio_w = ldg( 1337 | rotation_ratio + rotation_offset_temp + out_height * out_width); 1338 | scalar_t filter_h = f_h + rotation_ratio_h + 1339 | (scope_height - filter_height) / 2.0; 1340 | scalar_t filter_w = f_w + rotation_ratio_w + 1341 | (scope_width - filter_width) / 2.0; 1342 | 1343 | scalar_t coord_gradient = 0; 1344 | // Avoid repeated computation. 
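// Editor's note (not in the original source): rotation ratios are shared per sampling
// group rather than per channel, so the loop below only visits
// in_c = group_id, group_id + sampling_group, ... Each contributing channel adds
// planar_bilinear_backward_coord(partial_sum, ...), which presumably plays the same role
// for the filter plane that get_coordinate_weight plays for the input plane: the
// derivative of the bilinearly interpolated filter value with respect to filter_h or
// filter_w, selected by bp_dir.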
1345 | if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) { 1346 | for (int in_c = group_id; in_c < channel; in_c += sampling_group) { 1347 | const int input_offset_temp = 1348 | (out_b * channel + in_c) * (in_height * in_width) + 1349 | (in_row * in_width + in_col); 1350 | const int filter_offset_temp = in_c * scope_height * scope_width; 1351 | const scalar_t out_bp = ldg( 1352 | out_backprop + 1353 | (out_b * channel + in_c) * (out_height * out_width) + 1354 | (out_h * out_width + out_w)); 1355 | 1356 | scalar_t partial_sum = ldg(input + input_offset_temp) * out_bp; 1357 | coord_gradient += planar_bilinear_backward_coord( 1358 | partial_sum, 1359 | filter + filter_offset_temp, 1360 | scope_height, 1361 | scope_width, 1362 | filter_h, 1363 | filter_w, 1364 | bp_dir); 1365 | } 1366 | } 1367 | 1368 | rotation_backprop[thread_id] = coord_gradient; 1369 | } 1370 | } 1371 | 1372 | void SampleDepthwiseConv2dBackwardRotation( 1373 | const at::Tensor out_grad, 1374 | const at::Tensor input, 1375 | const at::Tensor rotation_ratio, 1376 | const at::Tensor filter, 1377 | const SampleDepthwiseArgs args, 1378 | at::Tensor rotation_grad) { 1379 | 1380 | int num_kernels = args.batch * 1381 | args.sampling_group * args.filter_height * args.filter_width * 2 * 1382 | args.out_height * args.out_width; 1383 | 1384 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 1385 | out_grad.type(), "SampleDepthwiseConv2dBackwardRotation_GPU", ([&] { 1386 | const scalar_t *out_grad_ = out_grad.data(); 1387 | const scalar_t *input_ = input.data(); 1388 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 1389 | const scalar_t *filter_ = filter.data(); 1390 | scalar_t *rotation_grad_ = rotation_grad.data(); 1391 | 1392 | SampleDepthwiseConv2dBackwardRotationKernel 1393 | <<>>( 1394 | num_kernels, 1395 | out_grad_, 1396 | input_, 1397 | rotation_ratio_, 1398 | filter_, 1399 | args, 1400 | rotation_grad_); 1401 | 1402 | })); 1403 | 1404 | cudaError_t err = cudaGetLastError(); 1405 | if (err != cudaSuccess) { 1406 | printf("error in SampleDepthwiseConv2dBackwardRotationKernel: %s\n", cudaGetErrorString(err)); 1407 | } 1408 | } 1409 | 1410 | template 1411 | __global__ __launch_bounds__(1024, 2) void DeformableSampleDepthwiseConv2dBackwardRotationKernel( 1412 | int n, 1413 | const scalar_t* out_backprop, 1414 | const scalar_t* input, 1415 | const scalar_t* offset, 1416 | const scalar_t* rotation_ratio, 1417 | const scalar_t* filter, 1418 | const SampleDepthwiseArgs args, 1419 | scalar_t* rotation_backprop) { 1420 | 1421 | const int channel = args.channel; 1422 | const int in_height = args.in_height; 1423 | const int in_width = args.in_width; 1424 | const int filter_height = args.filter_height; 1425 | const int filter_width = args.filter_width; 1426 | const int stride_height = args.stride_height; 1427 | const int stride_width = args.stride_width; 1428 | const int pad_height = args.pad_height; 1429 | const int pad_width = args.pad_width; 1430 | const int dilation_height = args.dilation_height; 1431 | const int dilation_width = args.dilation_width; 1432 | const int out_height = args.out_height; 1433 | const int out_width = args.out_width; 1434 | 1435 | const int scope_height = args.scope_height; 1436 | const int scope_width = args.scope_width; 1437 | const int sampling_group = args.sampling_group; 1438 | 1439 | CUDA_KERNEL_LOOP(thread_id, n) { 1440 | // Compute the indexes of this thread in the output. 
1441 | const int out_w = thread_id % out_width; 1442 | const int out_h = (thread_id / out_width) % out_height; 1443 | const int bp_dir = (thread_id / out_width / out_height) % 2; 1444 | const int f_w = (thread_id / out_width / out_height / 2) % filter_width; 1445 | const int f_h = (thread_id / out_width / out_height / 2 / filter_width) % filter_height; 1446 | const int group_id = (thread_id / out_width / out_height / 2 / 1447 | filter_width / filter_height) % sampling_group; 1448 | const int out_b = (thread_id / out_width / out_height / 2 / 1449 | filter_width / filter_height) / sampling_group; 1450 | 1451 | // Decide if all input is valid, if yes, we can skip the boundary checks 1452 | // for each input. 1453 | const int in_row = out_h * stride_height - pad_height + f_h * dilation_height; 1454 | const int in_col = out_w * stride_width - pad_width + f_w * dilation_width; 1455 | 1456 | const int deformable_offset_temp = 1457 | (out_b * (filter_height * filter_width) + (f_h * filter_width + f_w)) * 2 * 1458 | out_height * out_width + (out_h * out_width + out_w); 1459 | const int rotation_offset_temp = 1460 | ((out_b * sampling_group + group_id) * (filter_height * filter_width) + 1461 | (f_h * filter_width + f_w)) * 2 * out_height * out_width + 1462 | (out_h * out_width + out_w); 1463 | const scalar_t input_h = in_row + ldg( 1464 | offset + deformable_offset_temp); 1465 | const scalar_t input_w = in_col + ldg( 1466 | offset + deformable_offset_temp + out_height * out_width); 1467 | 1468 | scalar_t coord_gradient = 0; 1469 | // Avoid repeated computation. 1470 | if (input_h > -1 && input_w > -1 && input_h < in_height && input_w < in_width) { 1471 | const scalar_t rotation_ratio_h = ldg( 1472 | rotation_ratio + rotation_offset_temp); 1473 | const scalar_t rotation_ratio_w = ldg( 1474 | rotation_ratio + rotation_offset_temp + out_height * out_width); 1475 | scalar_t filter_h = f_h + rotation_ratio_h + 1476 | (scope_height - filter_height) / 2.0; 1477 | scalar_t filter_w = f_w + rotation_ratio_w + 1478 | (scope_width - filter_width) / 2.0; 1479 | 1480 | for (int in_c = group_id; in_c < channel; in_c += sampling_group) { 1481 | const int input_offset_temp = (out_b * channel + in_c) * (in_height * in_width); 1482 | const int filter_offset_temp = in_c * scope_height * scope_width; 1483 | const scalar_t out_bp = ldg( 1484 | out_backprop + 1485 | (out_b * channel + in_c) * (out_height * out_width) + 1486 | (out_h * out_width + out_w)); 1487 | 1488 | scalar_t partial_sum = deformable_im2col_bilinear( 1489 | input + input_offset_temp, 1490 | in_height, 1491 | in_width, 1492 | input_h, 1493 | input_w) * out_bp; 1494 | coord_gradient += planar_bilinear_backward_coord( 1495 | partial_sum, 1496 | filter + filter_offset_temp, 1497 | scope_height, 1498 | scope_width, 1499 | filter_h, 1500 | filter_w, 1501 | bp_dir); 1502 | } 1503 | } 1504 | 1505 | rotation_backprop[thread_id] = coord_gradient; 1506 | } 1507 | } 1508 | 1509 | void DeformableSampleDepthwiseConv2dBackwardRotation( 1510 | const at::Tensor out_grad, 1511 | const at::Tensor input, 1512 | const at::Tensor offset, 1513 | const at::Tensor rotation_ratio, 1514 | const at::Tensor filter, 1515 | const SampleDepthwiseArgs args, 1516 | at::Tensor rotation_grad) { 1517 | 1518 | int num_kernels = args.batch * 1519 | args.sampling_group * args.filter_height * args.filter_width * 2 * 1520 | args.out_height * args.out_width; 1521 | 1522 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 1523 | out_grad.type(), "DeformableSampleDepthwiseConv2dBackwardRotation_GPU", ([&] { 
1524 | const scalar_t *out_grad_ = out_grad.data(); 1525 | const scalar_t *input_ = input.data(); 1526 | const scalar_t *offset_ = offset.data(); 1527 | const scalar_t *rotation_ratio_ = rotation_ratio.data(); 1528 | const scalar_t *filter_ = filter.data(); 1529 | scalar_t *rotation_grad_ = rotation_grad.data(); 1530 | 1531 | DeformableSampleDepthwiseConv2dBackwardRotationKernel 1532 | <<>>( 1533 | num_kernels, 1534 | out_grad_, 1535 | input_, 1536 | offset_, 1537 | rotation_ratio_, 1538 | filter_, 1539 | args, 1540 | rotation_grad_); 1541 | 1542 | })); 1543 | 1544 | cudaError_t err = cudaGetLastError(); 1545 | if (err != cudaSuccess) { 1546 | printf("error in DeformableSampleDepthwiseConv2dBackwardRotationKernel: %s\n", cudaGetErrorString(err)); 1547 | } 1548 | } 1549 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "nd_linear_sample_cuda.h" 7 | 8 | /*------------------------------------------------------------------------------------------------------------*/ 9 | 10 | int nd_linear_sample_forward_cuda(at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor output) { 11 | data = data.contiguous(); 12 | shape = shape.contiguous(); 13 | coord = coord.contiguous(); 14 | 15 | SampleArgs args; 16 | args.batch = coord.size(0); 17 | args.channel = data.size(0); 18 | args.spatial_dims = data.dim() - 1; 19 | args.prod_shape = data.stride(0); 20 | 21 | output = output.view({args.batch, args.channel}); 22 | 23 | NdLinearSampleForward(data, shape, coord, args, output); 24 | 25 | return 1; 26 | } 27 | 28 | 29 | int nd_linear_sample_backward_data_cuda(at::Tensor out_grad, at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor in_grad) { 30 | out_grad = out_grad.contiguous(); 31 | shape = shape.contiguous(); 32 | coord = coord.contiguous(); 33 | 34 | SampleArgs args; 35 | args.batch = coord.size(0); 36 | args.channel = data.size(0); 37 | args.spatial_dims = data.dim() - 1; 38 | args.prod_shape = data.stride(0); 39 | 40 | in_grad = in_grad.view_as(data).zero_(); 41 | 42 | NdLinearSampleBackwardData(out_grad, shape, coord, args, in_grad); 43 | 44 | return 1; 45 | } 46 | 47 | int nd_linear_sample_backward_coord_cuda(at::Tensor out_grad, at::Tensor data, at::Tensor shape, at::Tensor coord, at::Tensor coord_grad_c) { 48 | out_grad = out_grad.contiguous(); 49 | data = data.contiguous(); 50 | shape = shape.contiguous(); 51 | coord = coord.contiguous(); 52 | 53 | SampleArgs args; 54 | args.batch = coord.size(0); 55 | args.channel = data.size(0); 56 | args.spatial_dims = data.dim() - 1; 57 | args.prod_shape = data.stride(0); 58 | 59 | coord_grad_c = coord_grad_c.view({args.batch, args.spatial_dims, args.channel}).zero_(); 60 | 61 | NdLinearSampleBackwardCoord(out_grad, data, shape, coord, args, coord_grad_c); 62 | 63 | return 1; 64 | } 65 | 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("nd_linear_sample_forward_cuda", &nd_linear_sample_forward_cuda, 69 | "nd_linear_sample forward (CUDA)"); 70 | m.def("nd_linear_sample_backward_data_cuda", &nd_linear_sample_backward_data_cuda, 71 | "nd_linear_sample_backward_data (CUDA)"); 72 | m.def("nd_linear_sample_backward_coord_cuda", &nd_linear_sample_backward_coord_cuda, 73 | "nd_linear_sample_backward_coord (CUDA)"); 74 | } 75 | 
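Editor's note: the three bindings exported above take raw tensors plus a pre-sized output. The sketch below is an illustration, not part of the repository, of how the forward entry point could be driven directly from Python. The tensor shapes are inferred from the argument setup in nd_linear_sample_forward_cuda (data: (channel, *spatial), shape: the spatial sizes, coord: (batch, spatial_dims), output viewed internally as (batch, channel)); the extension module name nd_linear_sample_cuda is an assumption, mirroring the filter_sample_depthwise_cuda import used further below. The supported interface is the autograd wrapper in functions/nd_linear_sample.py.

# Editor's sketch (assumptions: module name `nd_linear_sample_cuda`, float32 tensors,
# 2-D spatial data); the public API is the wrapper in functions/nd_linear_sample.py.
import torch
from deformable_kernels.ops.deform_kernel import nd_linear_sample_cuda  # assumed name

channel, height, width, batch = 4, 8, 8, 16
data = torch.randn(channel, height, width, device='cuda')               # (C, *spatial)
shape = torch.tensor([height, width], dtype=data.dtype, device='cuda')  # spatial sizes
coord = torch.rand(batch, 2, device='cuda') * (height - 1)              # (B, spatial_dims)
output = data.new_empty(batch, channel)                                 # written in place

nd_linear_sample_cuda.nd_linear_sample_forward_cuda(data, shape, coord, output)
print(output.shape)  # torch.Size([16, 4])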
-------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda.h: -------------------------------------------------------------------------------- 1 | struct SampleArgs { 2 | // Input layer dimensions 3 | int batch; 4 | int channel; 5 | int spatial_dims; 6 | int prod_shape; 7 | }; 8 | 9 | void NdLinearSampleForward(const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor output); 10 | 11 | void NdLinearSampleBackwardData(const at::Tensor out_grad, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor in_grad); 12 | 13 | void NdLinearSampleBackwardCoord(const at::Tensor out_grad, const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor coord_grad_c); -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/csrc/nd_linear_sample_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "nd_linear_sample_cuda.h" 8 | 9 | using namespace at; 10 | 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ 13 | i += blockDim.x * gridDim.x) 14 | 15 | const int CUDA_NUM_THREADS = 1024; 16 | const int kMaxGridNum = 65535; 17 | 18 | inline int GET_BLOCKS(const int N) { 19 | return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); 20 | } 21 | 22 | #if !defined(_MSC_VER) 23 | #define CUDA_UNROLL _Pragma("unroll") 24 | #define CUDA_NOUNROLL _Pragma("nounroll") 25 | #else 26 | #define CUDA_UNROLL 27 | #define CUDA_NOUNROLL 28 | #endif 29 | 30 | template 31 | __device__ inline DType ldg(const DType* address) { 32 | #if __CUDA_ARCH__ >= 350 33 | return __ldg(address); 34 | #else 35 | return *address; 36 | #endif 37 | } 38 | 39 | /*------------------------------------------------------------------------------------------------------------*/ 40 | 41 | 42 | // data: d1, d2, ..., dn 43 | // shape: s1, s2, ..., sn 44 | // coord: x1, x2, ..., xn 45 | template 46 | inline __device__ scalar_t nd_linear(const scalar_t *data, const scalar_t* shape, const scalar_t* coord, const int dims) { 47 | for (int d = 0; d < dims; d++) { 48 | scalar_t x = ldg(coord + d); 49 | int s = static_cast(ldg(shape + d)); 50 | if (x <= -1 || x >= s) { 51 | return 0; 52 | } 53 | } 54 | 55 | const uint corners = 1 << dims; 56 | scalar_t val = 0; 57 | for (uint i = 0; i < corners; i++) { 58 | int data_offset = 0; 59 | scalar_t data_weight = 1; 60 | bool out_of_scope = false; 61 | 62 | for (uint d = 0; d < dims; d++) { 63 | scalar_t x = ldg(coord + d); 64 | int s = static_cast(ldg(shape + d)); 65 | 66 | data_offset *= s; 67 | uint offset = (i >> d) & 1; 68 | int x_low = floor(x); 69 | scalar_t w = x - x_low; 70 | 71 | if (offset == 0 && x_low >= 0) { 72 | data_offset += x_low; 73 | data_weight *= 1-w; 74 | } 75 | else if (offset == 1 && x_low + 1 <= s - 1) { 76 | data_offset += x_low + 1; 77 | data_weight *= w; 78 | } 79 | else { 80 | out_of_scope = true; 81 | break; 82 | } 83 | } 84 | if (!out_of_scope) 85 | val += data_weight * ldg(data + data_offset); 86 | } 87 | return val; 88 | } 89 | 90 | template 91 | inline __device__ void nd_linear_backward_data(const scalar_t top_gradient, const scalar_t* shape, const scalar_t* coord, const int dims, scalar_t* data_gradient) { 92 | for (int d = 0; d 
< dims; d++) { 93 | scalar_t x = ldg(coord + d); 94 | int s = static_cast(ldg(shape + d)); 95 | if (x <= -1 || x >= s) { 96 | return; 97 | } 98 | } 99 | 100 | const uint corners = 1 << dims; 101 | for (uint i = 0; i < corners; i++) { 102 | int data_offset = 0; 103 | scalar_t data_weight = 1; 104 | bool out_of_scope = false; 105 | 106 | for (uint d = 0; d < dims; d++) { 107 | scalar_t x = ldg(coord + d); 108 | int s = static_cast(ldg(shape + d)); 109 | 110 | data_offset *= s; 111 | uint offset = (i >> d) & 1; 112 | int x_low = floor(x); 113 | scalar_t w = x - x_low; 114 | 115 | if (offset == 0 && x_low >= 0) { 116 | data_offset += x_low; 117 | data_weight *= 1-w; 118 | } 119 | else if (offset == 1 && x_low + 1 <= s - 1) { 120 | data_offset += x_low + 1; 121 | data_weight *= w; 122 | } 123 | else { 124 | out_of_scope = true; 125 | break; 126 | } 127 | } 128 | if (!out_of_scope) 129 | atomicAdd(data_gradient + data_offset, data_weight * top_gradient); 130 | } 131 | } 132 | 133 | template 134 | inline __device__ scalar_t nd_linear_backward_coord(const scalar_t top_gradient, const scalar_t *data, const scalar_t* shape, const scalar_t* coord, const int dims, const int gdim) { 135 | for (int d = 0; d < dims; d++) { 136 | scalar_t x = ldg(coord + d); 137 | int s = static_cast(ldg(shape + d)); 138 | if (x <= -1 || x >= s) { 139 | return 0; 140 | } 141 | } 142 | 143 | scalar_t grad = 0; 144 | const uint corners = 1 << dims; 145 | for (uint i = 0; i < corners; i++) { 146 | int data_offset = 0; 147 | scalar_t data_weight = 1; 148 | bool out_of_scope = false; 149 | 150 | for (uint d = 0; d < dims; d++) { 151 | scalar_t x = ldg(coord + d); 152 | int s = static_cast(ldg(shape + d)); 153 | 154 | data_offset *= s; 155 | uint offset = (i >> d) & 1; 156 | int x_low = floor(x); 157 | scalar_t w = x - x_low; 158 | 159 | if (offset == 0 && x_low >= 0) { 160 | data_offset += x_low; 161 | data_weight *= (d == gdim) ? scalar_t(-1) : 1-w; 162 | } 163 | else if (offset == 1 && x_low + 1 <= s - 1) { 164 | data_offset += x_low + 1; 165 | data_weight *= (d == gdim) ? 
scalar_t(1) : w; 166 | } 167 | else { 168 | out_of_scope = true; 169 | break; 170 | } 171 | } 172 | if (!out_of_scope) { 173 | grad += top_gradient * ldg(data + data_offset) * data_weight; 174 | } 175 | } 176 | return grad; 177 | } 178 | 179 | 180 | /*------------------------------------------------------------------------------------------------------------*/ 181 | 182 | 183 | template 184 | __global__ void NdLinearSampleForwardKernel(int n, 185 | const scalar_t* data, 186 | const scalar_t *shape, 187 | const scalar_t* coord, 188 | const SampleArgs args, 189 | scalar_t* output) { 190 | 191 | //const int batch = args.batch; 192 | const int channel = args.channel; 193 | const int spatial_dims = args.spatial_dims; 194 | const int prod_shape = args.prod_shape; 195 | 196 | CUDA_KERNEL_LOOP(thread_id, n) { 197 | const int out_c = thread_id % channel; 198 | const int out_n = thread_id / channel; 199 | output[thread_id] = nd_linear(data + out_c * prod_shape, shape, coord + out_n * spatial_dims, spatial_dims); 200 | } 201 | } 202 | 203 | 204 | void NdLinearSampleForward(const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor output) { 205 | int num_kernels = args.batch * args.channel; 206 | 207 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 208 | data.type(), "NdLinearSampleForward_GPU", ([&] { 209 | const scalar_t *data_ = data.data(); 210 | const scalar_t *shape_ = shape.data(); 211 | const scalar_t *coord_ = coord.data(); 212 | scalar_t *output_ = output.data(); 213 | 214 | NdLinearSampleForwardKernel<<>>( 215 | num_kernels, data_, shape_, coord_, args, output_); 216 | })); 217 | 218 | cudaError_t err = cudaGetLastError(); 219 | if (err != cudaSuccess) { 220 | printf("error in NdLinearSampleForwardKernel: %s\n", cudaGetErrorString(err)); 221 | } 222 | } 223 | 224 | 225 | template 226 | __global__ void NdLinearSampleBackwardDataKernel(int n, 227 | const scalar_t* out_grad, 228 | const scalar_t *shape, 229 | const scalar_t* coord, 230 | const SampleArgs args, 231 | scalar_t* in_grad) { 232 | 233 | //const int batch = args.batch; 234 | const int channel = args.channel; 235 | const int spatial_dims = args.spatial_dims; 236 | const int prod_shape = args.prod_shape; 237 | 238 | CUDA_KERNEL_LOOP(thread_id, n) { 239 | const int out_c = thread_id % channel; 240 | const int out_n = thread_id / channel; 241 | nd_linear_backward_data(ldg(out_grad + thread_id), shape, coord + out_n * spatial_dims, spatial_dims, in_grad + out_c * prod_shape); 242 | } 243 | } 244 | 245 | 246 | void NdLinearSampleBackwardData(const at::Tensor out_grad, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor in_grad) { 247 | int num_kernels = args.batch * args.channel; 248 | 249 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 250 | out_grad.type(), "NdLinearSampleBackwardData_GPU", ([&] { 251 | const scalar_t *out_grad_ = out_grad.data(); 252 | const scalar_t *shape_ = shape.data(); 253 | const scalar_t *coord_ = coord.data(); 254 | scalar_t *in_grad_ = in_grad.data(); 255 | 256 | NdLinearSampleBackwardDataKernel<<>>( 257 | num_kernels, out_grad_, shape_, coord_, args, in_grad_); 258 | })); 259 | 260 | cudaError_t err = cudaGetLastError(); 261 | if (err != cudaSuccess) { 262 | printf("error in NdLinearSampleBackwardDataKernel: %s\n", cudaGetErrorString(err)); 263 | } 264 | } 265 | 266 | 267 | template 268 | __global__ void NdLinearSampleBackwardCoordKernel(int n, 269 | const scalar_t* out_grad, 270 | const scalar_t* data, 271 | const scalar_t *shape, 272 | const 
scalar_t* coord, 273 | const SampleArgs args, 274 | scalar_t* coord_grad_c) { 275 | 276 | //const int batch = args.batch; 277 | const int channel = args.channel; 278 | const int spatial_dims = args.spatial_dims; 279 | const int prod_shape = args.prod_shape; 280 | 281 | CUDA_KERNEL_LOOP(thread_id, n) { 282 | const int out_c = thread_id % channel; 283 | const int out_d = (thread_id / channel) % spatial_dims; 284 | const int out_n = (thread_id / channel) / spatial_dims; 285 | coord_grad_c[thread_id] = nd_linear_backward_coord(ldg(out_grad + out_n * channel + out_c), data + out_c * prod_shape, shape, coord + out_n * spatial_dims, spatial_dims, out_d); 286 | } 287 | } 288 | 289 | 290 | void NdLinearSampleBackwardCoord(const at::Tensor out_grad, const at::Tensor data, const at::Tensor shape, const at::Tensor coord, const SampleArgs args, at::Tensor coord_grad_c) { 291 | int num_kernels = args.batch * args.spatial_dims * args.channel; 292 | 293 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 294 | data.type(), "NdLinearSampleBackwardCoord_GPU", ([&] { 295 | const scalar_t *out_grad_ = out_grad.data(); 296 | const scalar_t *data_ = data.data(); 297 | const scalar_t *shape_ = shape.data(); 298 | const scalar_t *coord_ = coord.data(); 299 | scalar_t *coord_grad_c_ = coord_grad_c.data(); 300 | 301 | NdLinearSampleBackwardCoordKernel<<>>( 302 | num_kernels, out_grad_, data_, shape_, coord_, args, coord_grad_c_); 303 | })); 304 | 305 | cudaError_t err = cudaGetLastError(); 306 | if (err != cudaSuccess) { 307 | printf("error in NdLinearSampleBackwardCoordKernel: %s\n", cudaGetErrorString(err)); 308 | } 309 | } 310 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/functions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | from .filter_sample_depthwise import ( 11 | sample_depthwise, 12 | deform_sample_depthwise, 13 | ) 14 | from .nd_linear_sample import nd_linear_sample 15 | 16 | __all__ = [ 17 | 'sample_depthwise', 18 | 'deform_sample_depthwise', 19 | 'nd_linear_sample', 20 | ] 21 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/functions/filter_sample_depthwise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : filter_sample_depthwise.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | import torch 11 | from torch.autograd import Function 12 | from apex import amp 13 | 14 | from .. 
import filter_sample_depthwise_cuda 15 | 16 | __all__ = ['sample_depthwise', 'deform_sample_depthwise'] 17 | 18 | 19 | class SampleDepthwiseFunction(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, 23 | input, 24 | rotation, 25 | weight, 26 | kernel_size=(3, 3), 27 | stride=(1, 1), 28 | padding=(1, 1), 29 | dilation=(1, 1), 30 | rotation_groups=1, 31 | ): 32 | ctx.kernel_size = kernel_size 33 | ctx.stride = stride 34 | ctx.padding = padding 35 | ctx.dilation = dilation 36 | ctx.rotation_groups = rotation_groups 37 | 38 | ctx.save_for_backward(input, rotation, weight) 39 | output = input.new_empty( 40 | SampleDepthwiseFunction._output_size( 41 | input, kernel_size, stride, padding, dilation)) 42 | 43 | if not input.is_cuda: 44 | raise NotImplementedError 45 | else: 46 | filter_sample_depthwise_cuda.sample_depthwise_forward_cuda( 47 | input, rotation, weight, output, 48 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], 49 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 50 | weight.size(2), weight.size(3), rotation_groups) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, rotation, weight = ctx.saved_tensors 56 | 57 | grad_input = grad_rotation = grad_weight = None 58 | 59 | if not grad_output.is_cuda: 60 | raise NotImplementedError 61 | else: 62 | 63 | if ctx.needs_input_grad[0]: 64 | grad_input = torch.zeros_like(input) 65 | filter_sample_depthwise_cuda.sample_depthwise_backward_data_cuda( 66 | grad_output, input, rotation, weight, grad_input, 67 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], 68 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 69 | weight.size(2), weight.size(3), ctx.rotation_groups) 70 | 71 | if ctx.needs_input_grad[1]: 72 | grad_rotation = torch.zeros_like(rotation) 73 | filter_sample_depthwise_cuda.sample_depthwise_backward_rotation_cuda( 74 | grad_output, input, rotation, weight, grad_rotation, 75 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], 76 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 77 | weight.size(2), weight.size(3), ctx.rotation_groups) 78 | 79 | if ctx.needs_input_grad[2]: 80 | grad_weight = torch.zeros_like(weight) 81 | filter_sample_depthwise_cuda.sample_depthwise_backward_filter_cuda( 82 | grad_output, input, rotation, weight, grad_weight, 83 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], 84 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 85 | weight.size(2), weight.size(3), ctx.rotation_groups) 86 | 87 | return (grad_input, grad_rotation, grad_weight,) + (None,) * 4 88 | 89 | @staticmethod 90 | def _output_size(input, kernel_size, stride, padding, dilation): 91 | output_size = (input.size(0), input.size(1)) 92 | for d in range(input.dim() - 2): 93 | in_size = input.size(d + 2) 94 | pad = padding[d] 95 | kernel = dilation[d] * (kernel_size[d] - 1) + 1 96 | stride_ = stride[d] 97 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 98 | if not all(map(lambda s: s > 0, output_size)): 99 | raise ValueError( 100 | "convolution input is too small (output would be {})".format( 101 | 'x'.join(map(str, output_size)))) 102 | return output_size 103 | 104 | 105 | class DeformableSampleDepthwiseFunction(Function): 106 | @staticmethod 107 | def forward( 108 | ctx, 109 | input, 110 | offset, 111 | rotation, 112 | weight, 113 | kernel_size=(3, 3), 114 | stride=(1, 1), 115 | padding=(1, 1), 116 | dilation=(1, 1), 117 | 
rotation_groups=1, 118 | ): 119 | ctx.kernel_size = kernel_size 120 | ctx.stride = stride 121 | ctx.padding = padding 122 | ctx.dilation = dilation 123 | ctx.rotation_groups = rotation_groups 124 | 125 | ctx.save_for_backward(input, offset, rotation, weight) 126 | output = input.new_empty( 127 | DeformableSampleDepthwiseFunction._output_size( 128 | input, kernel_size, stride, padding, dilation)) 129 | 130 | if not input.is_cuda: 131 | raise NotImplementedError 132 | else: 133 | filter_sample_depthwise_cuda. \ 134 | deformable_sample_depthwise_forward_cuda( 135 | input, offset, rotation, weight, output, 136 | ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], 137 | ctx.stride[1], ctx.padding[0], ctx.padding[1], 138 | ctx.dilation[0], ctx.dilation[1], weight.size(2), 139 | weight.size(3), rotation_groups) 140 | return output 141 | 142 | @staticmethod 143 | def backward(ctx, grad_output): 144 | input, offset, rotation, weight = ctx.saved_tensors 145 | 146 | grad_input = grad_offset = grad_rotation = grad_weight = None 147 | 148 | if not grad_output.is_cuda: 149 | raise NotImplementedError 150 | else: 151 | if ctx.needs_input_grad[0]: 152 | grad_input = torch.zeros_like(input) 153 | filter_sample_depthwise_cuda. \ 154 | deformable_sample_depthwise_backward_data_cuda( 155 | grad_output, input, offset, rotation, weight, 156 | grad_input, ctx.kernel_size[0], ctx.kernel_size[1], 157 | ctx.stride[0], ctx.stride[1], ctx.padding[0], 158 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 159 | weight.size(2), weight.size(3), ctx.rotation_groups) 160 | 161 | if ctx.needs_input_grad[1]: 162 | grad_offset = torch.zeros_like(offset) 163 | filter_sample_depthwise_cuda. \ 164 | deformable_sample_depthwise_backward_offset_cuda( 165 | grad_output, input, offset, rotation, weight, 166 | grad_offset, ctx.kernel_size[0], ctx.kernel_size[1], 167 | ctx.stride[0], ctx.stride[1], ctx.padding[0], 168 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 169 | weight.size(2), weight.size(3), ctx.rotation_groups) 170 | 171 | if ctx.needs_input_grad[2]: 172 | grad_rotation = torch.zeros_like(rotation) 173 | filter_sample_depthwise_cuda. \ 174 | deformable_sample_depthwise_backward_rotation_cuda( 175 | grad_output, input, offset, rotation, weight, 176 | grad_rotation, ctx.kernel_size[0], ctx.kernel_size[1], 177 | ctx.stride[0], ctx.stride[1], ctx.padding[0], 178 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 179 | weight.size(2), weight.size(3), ctx.rotation_groups) 180 | 181 | if ctx.needs_input_grad[3]: 182 | grad_weight = torch.zeros_like(weight) 183 | filter_sample_depthwise_cuda. 
\ 184 | deformable_sample_depthwise_backward_filter_cuda( 185 | grad_output, input, offset, rotation, weight, 186 | grad_weight, ctx.kernel_size[0], ctx.kernel_size[1], 187 | ctx.stride[0], ctx.stride[1], ctx.padding[0], 188 | ctx.padding[1], ctx.dilation[0], ctx.dilation[1], 189 | weight.size(2), weight.size(3), ctx.rotation_groups) 190 | 191 | return (grad_input, grad_offset, grad_rotation, grad_weight,) + (None,) * 4 192 | 193 | @staticmethod 194 | def _output_size(input, kernel_size, stride, padding, dilation): 195 | output_size = (input.size(0), input.size(1)) 196 | for d in range(input.dim() - 2): 197 | in_size = input.size(d + 2) 198 | pad = padding[d] 199 | kernel = dilation[d] * (kernel_size[d] - 1) + 1 200 | stride_ = stride[d] 201 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 202 | if not all(map(lambda s: s > 0, output_size)): 203 | raise ValueError( 204 | "convolution input is too small (output would be {})".format( 205 | 'x'.join(map(str, output_size)))) 206 | return output_size 207 | 208 | 209 | # register as fp32 functions. 210 | sample_depthwise = amp.float_function(SampleDepthwiseFunction.apply) 211 | deform_sample_depthwise = amp.float_function(DeformableSampleDepthwiseFunction.apply) 212 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/functions/nd_linear_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : nd_linear_sample.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | import torch 11 | from torch.autograd import Function 12 | 13 | from .. import nd_linear_sample_cuda 14 | 15 | __all__ = ['nd_linear_sample'] 16 | 17 | 18 | class NdLinearSampleFunction(Function): 19 | @staticmethod 20 | def forward(ctx, input, coord): 21 | ctx.save_for_backward(input, coord) 22 | shape = torch.tensor(input.shape[1:], dtype=input.dtype, device=input.device) 23 | output = input.new_empty(NdLinearSampleFunction._output_size(input, coord)) 24 | 25 | if not input.is_cuda: 26 | raise NotImplementedError 27 | else: 28 | nd_linear_sample_cuda.nd_linear_sample_forward_cuda( 29 | input, shape, coord, output 30 | ) 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, coord = ctx.saved_tensors 36 | shape = torch.tensor(input.shape[1:], dtype=input.dtype, device=input.device) 37 | grad_input = grad_coord = None 38 | 39 | if not grad_output.is_cuda: 40 | raise NotImplementedError 41 | else: 42 | if ctx.needs_input_grad[0]: 43 | grad_input = input.new_empty(*input.size()) 44 | nd_linear_sample_cuda.nd_linear_sample_backward_data_cuda( 45 | grad_output, input, shape, coord, grad_input 46 | ) 47 | 48 | if ctx.needs_input_grad[1]: 49 | grad_coord_c = coord.new_empty( 50 | coord.size(0), coord.size(1), input.size(0) 51 | ) 52 | nd_linear_sample_cuda.nd_linear_sample_backward_coord_cuda( 53 | grad_output, input, shape, coord, grad_coord_c 54 | ) 55 | grad_coord = grad_coord_c.sum(2) 56 | 57 | return (grad_input, grad_coord) 58 | 59 | @staticmethod 60 | def _output_size(input, coord): 61 | output_size = (coord.size(0), input.size(0)) 62 | return output_size 63 | 64 | 65 | nd_linear_sample = NdLinearSampleFunction.apply 66 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | from .filter_sample_depthwise import ( 11 | SampleDepthwise, 12 | DeformableSampleDepthwise, 13 | ) 14 | 15 | __all__ = [ 16 | 'SampleDepthwise', 17 | 'DeformableSampleDepthwise', 18 | ] 19 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/modules/filter_sample_depthwise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : filter_sample_depthwise.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn.modules.utils import _pair 15 | 16 | from ..functions.filter_sample_depthwise import ( 17 | sample_depthwise, 18 | deform_sample_depthwise, 19 | ) 20 | 21 | __all__ = [ 22 | 'SampleDepthwise', 23 | 'DeformableSampleDepthwise', 24 | ] 25 | 26 | 27 | class SampleDepthwise(nn.Module): 28 | def __init__(self, 29 | scope_size, 30 | in_channels, 31 | out_channels, 32 | kernel_size=3, 33 | stride=1, 34 | padding=1, 35 | dilation=1, 36 | groups=1, 37 | rotation_groups=1, 38 | bias=False): 39 | super(SampleDepthwise, self).__init__() 40 | self.in_channels = in_channels 41 | assert in_channels == out_channels and groups == in_channels 42 | assert in_channels % rotation_groups == 0 and \ 43 | out_channels % rotation_groups == 0 44 | self.scope_size = scope_size 45 | self.kernel_size = _pair(kernel_size) 46 | self.stride = _pair(stride) 47 | self.padding = _pair(padding) 48 | self.dilation = _pair(dilation) 49 | self.rotation_groups = rotation_groups 50 | self.bias = bias 51 | 52 | self.weight = nn.Parameter(torch.Tensor(self.in_channels, 1, *self.scope_size)) 53 | if not self.bias: 54 | self.bias = None 55 | else: 56 | self.bias = nn.Parameter(torch.Tensor(self.in_channels)) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | n = 1 61 | for k in self.kernel_size: 62 | n *= k 63 | stdv = 1. 
/ math.sqrt(n) 64 | self.weight.data.uniform_(-stdv, stdv) 65 | if self.bias is not None: 66 | self.bias.data.zero_() 67 | 68 | def forward(self, input, rotation=None): 69 | if rotation is None: 70 | output_size = self._output_size(input, self.weight) 71 | rotation = input.new_zeros( 72 | input.size(0), 73 | self.rotation_groups * self.kernel_size[0] * 74 | self.kernel_size[1] * 2, 75 | output_size[2], 76 | output_size[3]) 77 | out = sample_depthwise( 78 | input, rotation, self.weight, self.kernel_size, self.stride, 79 | self.padding, self.dilation, self.rotation_groups) 80 | if self.bias is not None: 81 | out += self.bias.view(1, self.in_channels, 1, 1) 82 | return out 83 | 84 | def _output_size(self, input, weight): 85 | channels = weight.size(0) 86 | 87 | output_size = (input.size(0), channels) 88 | for d in range(input.dim() - 2): 89 | in_size = input.size(d + 2) 90 | pad = self.padding[d] 91 | kernel = self.dilation[d] * (self.kernel_size[d] - 1) + 1 92 | stride = self.stride[d] 93 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) 94 | if not all(map(lambda s: s > 0, output_size)): 95 | raise ValueError( 96 | "convolution input is too small (output would be {})".format( 97 | 'x'.join(map(str, output_size)))) 98 | return output_size 99 | 100 | def extra_repr(self): 101 | s = ('scope_size={scope_size}, in_channels={in_channels}, ' 102 | 'kernel_size={kernel_size}, stride={stride}') 103 | if self.padding != (0,) * len(self.padding): 104 | s += ', padding={padding}' 105 | if self.dilation != (1,) * len(self.dilation): 106 | s += ', dilation={dilation}' 107 | if self.rotation_groups != 1: 108 | s += ', rotation_groups={rotation_groups}' 109 | if self.bias is None: 110 | s += ', bias=False' 111 | return s.format(**self.__dict__) 112 | 113 | 114 | class DeformableSampleDepthwise(nn.Module): 115 | def __init__(self, 116 | scope_size, 117 | in_channels, 118 | out_channels, 119 | kernel_size=3, 120 | stride=1, 121 | padding=1, 122 | dilation=1, 123 | groups=1, 124 | rotation_groups=1, 125 | bias=False): 126 | super(DeformableSampleDepthwise, self).__init__() 127 | self.in_channels = in_channels 128 | assert in_channels == out_channels and groups == in_channels 129 | assert in_channels % rotation_groups == 0 and \ 130 | out_channels % rotation_groups == 0 131 | self.scope_size = scope_size 132 | self.kernel_size = _pair(kernel_size) 133 | self.stride = _pair(stride) 134 | self.padding = _pair(padding) 135 | self.dilation = _pair(dilation) 136 | self.rotation_groups = rotation_groups 137 | self.bias = bias 138 | 139 | self.weight = nn.Parameter(torch.Tensor(self.in_channels, 1, *self.scope_size)) 140 | if not self.bias: 141 | self.bias = None 142 | else: 143 | self.bias = nn.Parameter(torch.Tensor(self.in_channels)) 144 | self.reset_parameters() 145 | 146 | def reset_parameters(self): 147 | n = 1 148 | for k in self.kernel_size: 149 | n *= k 150 | stdv = 1. 
/ math.sqrt(n) 151 | self.weight.data.uniform_(-stdv, stdv) 152 | if self.bias is not None: 153 | self.bias.data.zero_() 154 | 155 | def forward(self, input, offset=None, rotation=None): 156 | if offset is None: 157 | output_size = self._output_size(input, self.weight) 158 | offset = input.new_zeros( 159 | input.size(0), 160 | self.kernel_size[0] * self.kernel_size[1] * 2, 161 | output_size[2], 162 | output_size[3]) 163 | if rotation is None: 164 | output_size = self._output_size(input, self.weight) 165 | rotation = input.new_zeros( 166 | input.size(0), 167 | self.rotation_groups * self.kernel_size[0] * 168 | self.kernel_size[1] * 2, 169 | output_size[2], 170 | output_size[3]) 171 | 172 | out = deform_sample_depthwise( 173 | input, offset, rotation, self.weight, self.kernel_size, 174 | self.stride, self.padding, self.dilation, self.rotation_groups) 175 | if self.bias is not None: 176 | out += self.bias.view(1, self.in_channels, 1, 1) 177 | return out 178 | 179 | def _output_size(self, input, weight): 180 | channels = weight.size(0) 181 | 182 | output_size = (input.size(0), channels) 183 | for d in range(input.dim() - 2): 184 | in_size = input.size(d + 2) 185 | pad = self.padding[d] 186 | kernel = self.dilation[d] * (self.kernel_size[d] - 1) + 1 187 | stride = self.stride[d] 188 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) 189 | if not all(map(lambda s: s > 0, output_size)): 190 | raise ValueError( 191 | "convolution input is too small (output would be {})".format( 192 | 'x'.join(map(str, output_size)))) 193 | return output_size 194 | 195 | def extra_repr(self): 196 | s = ('scope_size={scope_size}, in_channels={in_channels}, ' 197 | 'kernel_size={kernel_size}, stride={stride}') 198 | if self.padding != (0,) * len(self.padding): 199 | s += ', padding={padding}' 200 | if self.dilation != (1,) * len(self.dilation): 201 | s += ', dilation={dilation}' 202 | if self.rotation_groups != 1: 203 | s += ', rotation_groups={rotation_groups}' 204 | if self.bias is None: 205 | s += ', bias=False' 206 | return s.format(**self.__dict__) 207 | -------------------------------------------------------------------------------- /deformable_kernels/ops/deform_kernel/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : setup.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # Date : 01/17/2020 7 | # 8 | # Distributed under terms of the MIT license. 
9 | 10 | from setuptools import setup 11 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 12 | 13 | setup( 14 | name='filter_sample_depthwise', 15 | ext_modules=[ 16 | CUDAExtension( 17 | 'filter_sample_depthwise_cuda', 18 | [ 19 | 'csrc/filter_sample_depthwise_cuda.cpp', 20 | 'csrc/filter_sample_depthwise_cuda_kernel.cu', 21 | ] 22 | ), 23 | ], 24 | cmdclass={'build_ext': BuildExtension} 25 | ) 26 | 27 | setup( 28 | name="nd_linear_sample", 29 | ext_modules=[ 30 | CUDAExtension( 31 | "nd_linear_sample_cuda", 32 | [ 33 | "csrc/nd_linear_sample_cuda.cpp", 34 | "csrc/nd_linear_sample_cuda_kernel.cu", 35 | ], 36 | ) 37 | ], 38 | cmdclass={"build_ext": BuildExtension}, 39 | ) 40 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deformable-kernels 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.6 7 | - pytorch=1.3.1 8 | - torchvision=0.4.2 9 | - cudatoolkit=10 10 | --------------------------------------------------------------------------------
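
For reference, below is a minimal usage sketch of the modules and the sampling function defined above. It assumes the two CUDA extensions declared in setup.py have already been built (for example via `python setup.py build_ext --inplace` run inside deformable_kernels/ops/deform_kernel/), that NVIDIA apex is installed (the autograd functions are wrapped with amp.float_function), and that a CUDA device is available; the tensor shapes are purely illustrative.

import torch

from deformable_kernels.ops.deform_kernel.modules import (
    SampleDepthwise,
    DeformableSampleDepthwise,
)
from deformable_kernels.ops.deform_kernel.functions import nd_linear_sample

# Depthwise filter sampling: the module asserts in_channels == out_channels == groups;
# scope_size is the spatial extent of the stored kernel that gets sampled.
conv = SampleDepthwise(
    scope_size=(4, 4),
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    groups=64,
).cuda()
x = torch.randn(2, 64, 32, 32, device='cuda')
y = conv(x)       # rotation defaults to a zero field when omitted
print(y.shape)    # torch.Size([2, 64, 32, 32]) with the default stride/padding

# Deformable variant: additionally accepts a per-location offset field,
# which also defaults to zeros when omitted.
dconv = DeformableSampleDepthwise(
    scope_size=(4, 4),
    in_channels=64,
    out_channels=64,
    groups=64,
).cuda()
y = dconv(x)

# N-d linear sampling: `data` is (channels, *spatial_shape) and each row of
# `coord` is one fractional coordinate; the output is (num_points, channels).
data = torch.randn(8, 16, 16, device='cuda')
coord = torch.rand(5, 2, device='cuda') * 15
samples = nd_linear_sample(data, coord)
print(samples.shape)   # torch.Size([5, 8])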