├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── depthwise_conv3d.py ├── grad_test.py ├── setup.cfg ├── setup.py └── src ├── deformable_conv.cu └── warp.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xin Qiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/config.h 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerated 3D Depthwise Convolution 2 | 3 | This is a separate repo for my pull request (Accelerated 3D Depthwise Convolution), which is part of PyTorch 1.9. 4 | This repo aims to support people who want to use the module without upgrading to the latest cuDNN or PyTorch.
5 | 6 | ## Installation 7 | 8 | Prerequisites: 9 | 10 | - PyTorch >= 1.6 11 | - Python 3 12 | 13 | ```bash 14 | python setup.py install 15 | ``` 16 | 17 | ## Usage 18 | 19 | ```python 20 | 21 | import torch 22 | from depthwise_conv3d import DepthwiseConv3d 23 | 24 | dtype = torch.float 25 | conv = DepthwiseConv3d(2, 2, kernel_size=3, groups=2).to("cuda", dtype) 26 | input = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() 27 | output = conv(input) 28 | 29 | ``` -------------------------------------------------------------------------------- /depthwise_conv3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import DWCONV_CUDA 4 | import torch 5 | from torch import nn 6 | from torch.cuda.amp import custom_bwd, custom_fwd 7 | from torch.nn.modules.utils import _triple 8 | 9 | 10 | class DepthwiseConv3dFunction(torch.autograd.Function): 11 | @staticmethod 12 | @custom_fwd 13 | def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, 14 | groups=1): 15 | ctx.stride = _triple(stride) 16 | ctx.padding = _triple(padding) 17 | ctx.dilation = _triple(dilation) 18 | ctx.kernel_size = tuple(weight.shape[2:])  # support non-cubic kernels 19 | ctx.groups = groups 20 | ctx.with_bias = bias is not None 21 | if not ctx.with_bias: 22 | bias = input.new_empty(0) # fake tensor 23 | if not input.is_cuda: 24 | raise NotImplementedError 25 | if weight.requires_grad or input.requires_grad: 26 | ctx.save_for_backward(input, weight, bias) 27 | weight = weight.to(input.dtype) 28 | bias = bias.to(input.dtype) 29 | output = DWCONV_CUDA.conv_depthwise3d_cuda( 30 | input, weight, ctx.kernel_size, bias, 31 | ctx.stride, 32 | ctx.padding, 33 | ctx.dilation) 34 | 35 | return output 36 | 37 | @staticmethod 38 | @custom_bwd 39 | def backward(ctx, grad_output): 40 | grad_output = grad_output.contiguous() 41 | if not grad_output.is_cuda: 42 | raise NotImplementedError 43 | input, weight, bias = ctx.saved_tensors 44 | # The CUDA op allocates and returns grad_input/grad_weight itself; 45 | # it must receive the saved input and weight, not zero-filled buffers. 46 | grad_input, grad_weight, grad_bias = DWCONV_CUDA.conv_depthwise3d_backward_cuda( 47 | grad_output, input, weight, 48 | ctx.kernel_size, 49 | ctx.stride, 50 | ctx.padding, 51 | ctx.dilation, (True, True, True)) 52 | return grad_input, grad_weight, grad_bias, None, None, None, None, None, None 53 | 54 | 55 | depthwise_conv3d = DepthwiseConv3dFunction.apply 56 | 57 | 58 | class DepthwiseConv3d(nn.Module): 59 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 60 | groups=1, bias=True): 61 | super(DepthwiseConv3d, self).__init__() 62 | assert in_channels % groups == 0, \ 63 | 'in_channels {} is not divisible by groups {}'.format( 64 | in_channels, groups) 65 | assert out_channels % groups == 0, \ 66 | 'out_channels {} is not divisible by groups {}'.format( 67 | out_channels, groups) 68 | 69 | self.in_channels = in_channels 70 | self.out_channels = out_channels 71 | self.kernel_size = _triple(kernel_size) 72 | self.stride = _triple(stride) 73 | self.padding = _triple(padding) 74 | self.dilation = _triple(dilation) 75 | self.groups = groups 76 | 77 | self.weight = nn.Parameter( 78 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) 79 | self.with_bias = bias 80 | if self.with_bias: 81 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 82 | else: 83 | self.bias = None 84 | 85 | self.reset_parameters() 86 | 87 | def reset_parameters(self): 88 | n = self.in_channels 89 |
for k in self.kernel_size: 90 | n *= k 91 | stdv = 1. / math.sqrt(n) 92 | self.weight.data.uniform_(-stdv, stdv) 93 | if self.with_bias: 94 | self.bias.data.fill_(0) 95 | def forward(self, x): 96 | return depthwise_conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, 97 | self.groups, ) 98 | -------------------------------------------------------------------------------- /grad_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing._internal.common_utils import TestCase 3 | from torch.testing._internal.common_utils import dtype2prec_DONTUSE 4 | 5 | from depthwise_conv3d import DepthwiseConv3d 6 | 7 | 8 | class TestConv(TestCase): 9 | def test_Conv3d_depthwise_naive_groups_cuda(self, dtype=torch.float): 10 | for depth_multiplier in [1, 2]: 11 | m = DepthwiseConv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to("cuda", dtype) 12 | i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() 13 | output = m(i) 14 | grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device="cuda", dtype=dtype) / 2 15 | output.backward(grad_output) 16 | 17 | offset = 1 * depth_multiplier 18 | 19 | m1 = DepthwiseConv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype) 20 | m1.weight.data = m.weight.data[:offset].clone() 21 | m1.bias.data = m.bias.data[:offset].clone() 22 | i1 = i.detach()[:, :1].clone().requires_grad_() 23 | output1 = m1(i1) 24 | output1.backward(grad_output[:, :offset].contiguous()) 25 | 26 | m2 = DepthwiseConv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype) 27 | m2.weight.data.copy_(m.weight.data[offset:]) 28 | m2.bias.data.copy_(m.bias.data[offset:]) 29 | i2 = i.detach()[:, 1:].clone().requires_grad_() 30 | output2 = m2(i2) 31 | output2.backward(grad_output[:, offset:].contiguous()) 32 | 33 | self.assertEqual(output, torch.cat([output1, output2], 1), 34 | atol=dtype2prec_DONTUSE[dtype], rtol=0) 35 | self.assertEqual(i.grad.data, 36 | torch.cat([i1.grad.data, i2.grad.data], 1), 37 | atol=dtype2prec_DONTUSE[dtype], rtol=0) 38 | self.assertEqual(m.bias.grad.data, 39 | torch.cat([m1.bias.grad.data, 40 | m2.bias.grad.data], 0), 41 | atol=dtype2prec_DONTUSE[dtype], rtol=0) 42 | self.assertEqual(m.weight.grad.data, 43 | torch.cat([m1.weight.grad.data, 44 | m2.weight.grad.data], 0), 45 | atol=dtype2prec_DONTUSE[dtype], rtol=0) 46 | 47 | 48 | if __name__ == '__main__': 49 | test = TestConv() 50 | # compare the grouped depthwise module against two single-group convolutions 51 | test.test_Conv3d_depthwise_naive_groups_cuda() 52 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | 6 | def make_cuda_ext(name, sources, includes): 7 | return CUDAExtension( 8 | name='{}'.format(name), 9 | sources=[p for p in sources], 10 | include_dirs=[i for i in includes], 11 | extra_compile_args={ 12 | 'cxx': [], 13 | 'nvcc': [ 14 | '-D__CUDA_NO_HALF_OPERATORS__', 15 | '-D__CUDA_NO_HALF_CONVERSIONS__', 16 | '-D__CUDA_NO_HALF2_OPERATORS__', 17 | ]}) 18 | 19 | 20 | # -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__
-D_GLIBCXX_USE_CXX11_ABI=1 21 | sources = [] 22 | sources.extend(glob.glob('src/*.cu')) 23 | sources.extend(glob.glob('src/*.cpp')) 24 | 25 | setup( 26 | name='depthwise_conv3d', 27 | version='1.0.3', 28 | author='gungui98', 29 | author_email='phi.nguyen.uet@gmail.com', 30 | url='https://www.github.com', 31 | description="cuda implementation of 3d depthwise convolution", 32 | ext_modules=[ 33 | make_cuda_ext(name='DWCONV_CUDA', 34 | sources=sources, 35 | includes=['src']) 36 | ], 37 | py_modules=['depthwise_conv3d'], 38 | classifiers=( 39 | 'Development Status :: 3 - Alpha', 40 | 'Operating System :: POSIX :: Linux', 41 | 'Intended Audience :: Developers', 42 | 43 | 'License :: OSI Approved :: MIT License', 44 | 45 | 'Programming Language :: Python :: 3.6', 46 | ), 47 | install_requires=['torch>=1.6'], 48 | keywords=["pytorch", "cuda", "depthwise convolution"], 49 | cmdclass={'build_ext': BuildExtension}, zip_safe=False) 50 | -------------------------------------------------------------------------------- /src/deformable_conv.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace at; 13 | using namespace native; 14 | 15 | template 18 | __global__ void conv_depthwise3d_cuda_kernel( 19 | const PackedTensorAccessor32 input, 20 | PackedTensorAccessor32 output, 21 | const PackedTensorAccessor32 kernel, 22 | const scalar_t* bias, 23 | int strideT, int strideH, int strideW, 24 | int paddingT, int paddingH, int paddingW, 25 | int dilationT_, int dilationH_, int dilationW_) 26 | { 27 | const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2); 28 | const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3); 29 | const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4); 30 | const int oC = output.size(1); 31 | const int oT = output.size(2); 32 | const int oH = output.size(3); 33 | const int oW = output.size(4); 34 | const int iC = input.size(1); 35 | const int iT = input.size(2); 36 | const int iH = input.size(3); 37 | const int iW = input.size(4); 38 | const int channel_multiplier = oC / iC; 39 | const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_; 40 | const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_; 41 | const int dilationW = kKnownDilationW > 0 ? 
kKnownDilationW : dilationW_; 42 | const int num_output = output.size(0) * output.stride(0); 43 | 44 | CUDA_KERNEL_LOOP(index, num_output) { 45 | const int out_col = index % oW; 46 | const int out_row = (index / oW) % oH; 47 | const int out_frame = (index / oW / oH) % oT; 48 | const int out_channel = (index / oW / oH / oT) % oC; 49 | const int batch = index / oW / oH / oT / oC; 50 | 51 | const int in_channel = out_channel / channel_multiplier; 52 | 53 | const int in_col_start = out_col * strideW - paddingW; 54 | const int in_row_start = out_row * strideH - paddingH; 55 | const int in_frame_start = out_frame * strideT - paddingT; 56 | 57 | accscalar_t sum = 0; 58 | const scalar_t *kernel_ptr = kernel[out_channel].data(); 59 | const scalar_t *input_ptr = 60 | &input[batch][in_channel][in_frame_start][in_row_start][in_col_start]; 61 | for (int k_frame = 0; k_frame < kT; ++k_frame) { 62 | const int in_frame = in_frame_start + k_frame * dilationT; 63 | for (int k_row = 0; k_row < kH; ++k_row) { 64 | const int in_row = in_row_start + k_row * dilationH; 65 | for (int k_col = 0; k_col < kW; ++k_col) { 66 | const accscalar_t op1 = *(kernel_ptr++); 67 | const int in_col = in_col_start + k_col * dilationW; 68 | if (in_frame >= 0 && in_row >= 0 && in_col >= 0 && 69 | in_frame < iT && in_row < iH && in_col < iW) { 70 | sum += op1 * *(input_ptr); 71 | } 72 | input_ptr += dilationW; 73 | } 74 | input_ptr += iW * dilationH - kW * dilationW; 75 | } 76 | input_ptr += iW * (iH * dilationT - kH * dilationH); 77 | } 78 | if (bias != NULL) { 79 | sum += bias[out_channel]; 80 | } 81 | 82 | output[batch][out_channel][out_frame][out_row][out_col] = sum; 83 | } 84 | } 85 | 86 | template 90 | __global__ void 91 | conv_depthwise3d_cuda_backward_input_kernel( 92 | const PackedTensorAccessor32 grad_output, 93 | PackedTensorAccessor32 grad_input, 94 | const PackedTensorAccessor32 kernel, 95 | int strideT_, int strideH_, int strideW_, 96 | int paddingT, int paddingH, int paddingW, 97 | int dilationT_, int dilationH_, int dilationW_) { 98 | const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2); 99 | const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3); 100 | const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4); 101 | const int oC = grad_output.size(1); 102 | const int oT = grad_output.size(2); 103 | const int oH = grad_output.size(3); 104 | const int oW = grad_output.size(4); 105 | const int iC = grad_input.size(1); 106 | const int iT = grad_input.size(2); 107 | const int iH = grad_input.size(3); 108 | const int iW = grad_input.size(4); 109 | const int channel_multiplier = oC / iC; 110 | const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_; 111 | const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_; 112 | const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_; 113 | const int strideT = kKnownStrideT > 0 ? kKnownStrideT : strideT_; 114 | const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_; 115 | const int strideW = kKnownStrideW > 0 ? 
kKnownStrideW : strideW_; 116 | const int num_input = grad_input.size(0) * grad_input.stride(0); 117 | 118 | CUDA_KERNEL_LOOP(index, num_input) { 119 | const int in_col = index % iW; 120 | const int in_row = (index / iW) % iH; 121 | const int in_frame = (index / iW / iH) % iT; 122 | const int in_channel = (index / iW / iH / iT) % iC; 123 | const int batch = index / iW / iH / iT / iC; 124 | 125 | const int out_col_end = in_col + paddingW; 126 | const int out_row_end = in_row + paddingH; 127 | const int out_frame_end = in_frame + paddingT; 128 | 129 | const scalar_t* kernel_ptr = kernel[in_channel * channel_multiplier].data(); 130 | accscalar_t sum = 0; 131 | 132 | for (int k_chn = in_channel * channel_multiplier; 133 | k_chn < (in_channel + 1) * channel_multiplier; 134 | ++k_chn) { 135 | const scalar_t* gout_ptr = grad_output[batch][k_chn].data(); 136 | 137 | for (int k_frame = 0; k_frame < kT; ++k_frame) { 138 | const int out_frame_raw = out_frame_end - k_frame * dilationT; 139 | const int out_frame = out_frame_raw / strideT; 140 | for (int k_row = 0; k_row < kH; ++k_row) { 141 | const int out_row_raw = out_row_end - k_row * dilationH; 142 | const int out_row = out_row_raw / strideH; 143 | for (int k_col = 0; k_col < kW; ++k_col) { 144 | const accscalar_t op1 = *(kernel_ptr++); 145 | const int out_col_raw = out_col_end - k_col * dilationW; 146 | const int out_col = out_col_raw / strideW; 147 | 148 | const int out_offs = (out_frame * oH + out_row) * oW + out_col; 149 | 150 | accscalar_t op2 = (accscalar_t)0; 151 | if (out_col >= 0 && out_row >= 0 && out_frame >= 0 && 152 | out_col < oW && out_row < oH && out_frame < oT) { 153 | op2 = *(gout_ptr + out_offs); 154 | } 155 | if (out_frame * strideT == out_frame_raw && 156 | out_row * strideH == out_row_raw && 157 | out_col * strideW == out_col_raw) { 158 | sum += op1 * op2; 159 | } 160 | } 161 | } 162 | } 163 | } 164 | 165 | grad_input[batch][in_channel][in_frame][in_row][in_col] = sum; 166 | } 167 | } 168 | 169 | template 171 | __global__ void 172 | conv_depthwise3d_cuda_backward_weight_kernel( 173 | const PackedTensorAccessor32 grad_output, 174 | const PackedTensorAccessor32 input, 175 | PackedTensorAccessor32 grad_kernel, 176 | int strideT, int strideH_, int strideW_, 177 | int paddingT, int paddingH, int paddingW, 178 | int dilationT, int dilationH, int dilationW) { 179 | const int kC = grad_kernel.size(0); 180 | const int kT = grad_kernel.size(2); 181 | const int kH = grad_kernel.size(3); 182 | const int kW = grad_kernel.size(4); 183 | 184 | const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_; 185 | const int strideW = kKnownStrideW > 0 ? 
kKnownStrideW : strideW_; 186 | 187 | const int k_col = blockIdx.x % kW; 188 | const int k_row = (blockIdx.x / kW) % kH; 189 | const int k_frame = (blockIdx.x / kW / kH) % kT; 190 | const int k_channel = blockIdx.x / kW / kH / kT; 191 | scalar_t *result = &grad_kernel[k_channel][0][k_frame][k_row][k_col]; 192 | 193 | const int oT = grad_output.size(2); 194 | const int oH = grad_output.size(3); 195 | const int oW = grad_output.size(4); 196 | const int iT = input.size(2); 197 | const int iH = input.size(3); 198 | const int iW = input.size(4); 199 | const int channel_multiplier = grad_output.size(1) / input.size(1); 200 | const int in_channel = k_channel / channel_multiplier; 201 | 202 | extern __shared__ int sdata_raw[]; 203 | scalar_t* sdata = reinterpret_cast(sdata_raw); 204 | 205 | if (k_channel >= kC) { 206 | return; 207 | } 208 | 209 | const int laneid = threadIdx.x % C10_WARP_SIZE; 210 | const int warpid = threadIdx.x / C10_WARP_SIZE; 211 | const int nwarps = blockDim.x / C10_WARP_SIZE; 212 | 213 | accscalar_t grad = 0; 214 | int batch = warpid / oT; 215 | int gout_frame = warpid - batch * oT; 216 | for (int outer_pos = warpid; outer_pos < input.size(0) * oT; 217 | outer_pos += nwarps, gout_frame += nwarps) { 218 | while (gout_frame >= oT) { gout_frame -= oT; batch ++; } 219 | 220 | const int in_frame = (gout_frame * strideT) + (k_frame * dilationT) - paddingT; 221 | 222 | if (in_frame < 0 || in_frame >= iT) { 223 | continue; 224 | } 225 | 226 | const scalar_t* gout_ptr = grad_output[batch][k_channel][gout_frame].data() + laneid; 227 | const scalar_t* input_ptr = input[batch][in_channel][in_frame].data(); 228 | 229 | int gout_row = laneid / oW; 230 | int gout_col = laneid - gout_row * oW; 231 | 232 | for (; gout_row < oH; ) { 233 | const accscalar_t op1 = *(gout_ptr); 234 | gout_ptr += C10_WARP_SIZE; 235 | 236 | const int in_col = (gout_col * strideW) + (k_col * dilationW) - paddingW; 237 | const int in_row = (gout_row * strideH) + (k_row * dilationH) - paddingH; 238 | const int in_pos = in_row * iW + in_col; 239 | 240 | accscalar_t op2 = (accscalar_t)0; 241 | if (in_col >= 0 && in_col < iW && in_row >= 0 && in_row < iH) { 242 | op2 = *(input_ptr + in_pos); 243 | } 244 | 245 | gout_col += C10_WARP_SIZE; 246 | while (gout_col >= oW) { 247 | gout_col -= oW; gout_row ++; 248 | } 249 | 250 | grad += op1 * op2; 251 | } 252 | } 253 | 254 | sdata[threadIdx.x] = grad; 255 | __syncthreads(); 256 | 257 | CUDA_KERNEL_ASSERT(__popc(blockDim.x) == 1); 258 | #pragma unroll 259 | for (int i = blockDim.x / 2; i >= 1; i >>= 1) { 260 | if (threadIdx.x < i) { 261 | sdata[threadIdx.x] += sdata[threadIdx.x + i]; 262 | } 263 | __syncthreads(); 264 | } 265 | 266 | if (threadIdx.x == 0) { 267 | *result = sdata[0]; 268 | } 269 | } 270 | 271 | template 272 | void conv_depthwise_shape_check( 273 | const Tensor& input, 274 | const Tensor& weight, 275 | const Tensor& bias, 276 | const Tensor& grad_output, 277 | IntArrayRef kernel_size, 278 | IntArrayRef stride, 279 | IntArrayRef padding, 280 | IntArrayRef dilation) { 281 | TORCH_CHECK(kernel_size.size() == dim, 282 | "kernel size length should be ", dim, ", but got ", kernel_size.size()); 283 | TORCH_CHECK(stride.size() == dim, 284 | "stride length should be ", dim, ", but got ", stride.size()); 285 | TORCH_CHECK(padding.size() == dim, 286 | "padding length should be ", dim, ", but got ", padding.size()); 287 | TORCH_CHECK(dilation.size() == dim, 288 | "dilation length should be ", dim, ", but got ", dilation.size()); 289 | 290 | TORCH_CHECK(weight.defined(), 291 
| "Weight must be defined."); 292 | TORCH_CHECK(input.dim() == dim + 1 || input.dim() == dim + 2, 293 | "Input dimension should be ", 294 | dim + 1, "D or ", dim + 2, "D, got ", 295 | input.dim(), "D"); 296 | TORCH_CHECK(weight.dim() == dim + 2, 297 | "Weight dimension should be ", dim + 2, "D, got ", weight.dim(), "D"); 298 | TORCH_CHECK(weight.size(1) == 1, 299 | "Depthwise weight should have in_channels=1, got ", weight.size(1)); 300 | TORCH_CHECK(weight.size(0) % input.size(-dim - 1) == 0, 301 | "Depthwise out channels should be a multiple of in channels, got ", 302 | weight.size(0), " and ", input.size(-dim - 1)); 303 | for (int i = 0; i < dim; ++i) { 304 | TORCH_CHECK(weight.size(i + 2) == kernel_size[i], 305 | "kernel size and weight size mismatch, got ", 306 | kernel_size, " and ", weight.sizes()); 307 | TORCH_CHECK(stride[i] >= 1, 308 | "stride should be at least 1, got ", stride); 309 | TORCH_CHECK(padding[i] >= 0, 310 | "padding should be non-negative, got ", padding); 311 | TORCH_CHECK(dilation[i] >= 1, 312 | "dilation should be at least 1, got ", dilation); 313 | } 314 | 315 | if (bias.defined()) { 316 | TORCH_CHECK(bias.dim() == 1, 317 | "Bias should be 1D tensor, got ", bias.dim(), "D"); 318 | TORCH_CHECK(bias.size(0) == weight.size(0), 319 | "Bias length should be equal to out_channels, got ", 320 | bias.size(0), " and ", weight.size(0)); 321 | } 322 | 323 | if (grad_output.defined()) { 324 | auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(), 325 | padding, stride, dilation); 326 | TORCH_CHECK(grad_output.dim() == expected_output_size.size(), 327 | "Expect grad_output to be ", 328 | expected_output_size.size(), "D, got ", 329 | grad_output.dim(), "D."); 330 | for (int i = 0; i < grad_output.dim(); ++i) { 331 | TORCH_CHECK(grad_output.size(i) == expected_output_size[i], 332 | "Expect grad_output to be of same shape as output, got ", 333 | grad_output.size(i), " and ", expected_output_size[i], 334 | " at dimension ", i); 335 | } 336 | } 337 | } 338 | 339 | 340 | #define NODEF_OR_EQUAL(x, y) ((y) < 0 || (x) == (y)) 341 | #define NODEF_OR_EQUAL_3(x, y1, y2, y3) \ 342 | (NODEF_OR_EQUAL(x[0], y1) && \ 343 | NODEF_OR_EQUAL(x[1], y2) && \ 344 | NODEF_OR_EQUAL(x[2], y3)) 345 | 346 | #define DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(kt, kh, kw, dilt, dilh, dilw) \ 347 | if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) && \ 348 | NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw))) { \ 349 | using accscalar_t = acc_type; \ 350 | conv_depthwise3d_cuda_kernel \ 351 | \ 352 | <<>>( \ 353 | input_.packed_accessor32(), \ 354 | output_.packed_accessor32(), \ 355 | weight_.packed_accessor32(), \ 356 | bias_ptr, \ 357 | stride[0], stride[1], stride[2], \ 358 | padding[0], padding[1], padding[2], \ 359 | dilation[0], dilation[1], dilation[2]); \ 360 | } else 361 | 362 | #define DWCONV3D_FORWARD_DISPATCH_OTHERS \ 363 | { \ 364 | using accscalar_t = acc_type; \ 365 | conv_depthwise3d_cuda_kernel \ 366 | \ 367 | <<>>( \ 368 | input_.packed_accessor32(), \ 369 | output_.packed_accessor32(), \ 370 | weight_.packed_accessor32(), \ 371 | bias_ptr, \ 372 | stride[0], stride[1], stride[2], \ 373 | padding[0], padding[1], padding[2], \ 374 | dilation[0], dilation[1], dilation[2]); \ 375 | } 376 | 377 | Tensor conv_depthwise3d_cuda( 378 | const Tensor& input, 379 | const Tensor& weight, 380 | IntArrayRef kernel_size, 381 | const Tensor& bias, 382 | IntArrayRef stride, 383 | IntArrayRef padding, 384 | IntArrayRef dilation) { 385 | TORCH_CHECK(input.device() == weight.device(), 
"expects input and weight tensors to be on the same device."); 386 | if (bias.defined()) { 387 | TORCH_CHECK(input.device() == bias.device(), "expects input and bias tensors to be on the same device."); 388 | } 389 | 390 | conv_depthwise_shape_check<3>(input, weight, bias, Tensor() /* undefined */, 391 | kernel_size, stride, padding, dilation); 392 | 393 | Tensor input_ = input.contiguous(); 394 | 395 | if (input.dim() == 4 /* no batch */) { 396 | input_ = input.unsqueeze(0); 397 | } 398 | 399 | auto output_size = conv_output_size(input_.sizes(), weight.sizes(), 400 | padding, stride, dilation); 401 | for (size_t i = 0; i < output_size.size(); ++i) { 402 | TORCH_CHECK(output_size[i] > 0, 403 | "Output size should be positive, got ", output_size[i], " at dim ", i); 404 | } 405 | Tensor output = at::empty(output_size, input.options()); 406 | Tensor output_ = output; 407 | Tensor weight_ = weight.contiguous(); 408 | Tensor bias_ = bias.defined() ? bias.contiguous() : bias; 409 | 410 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 411 | input.scalar_type(), 412 | "conv_depthwise3d", 413 | [&]{ 414 | int64_t num_outputs = output_.numel(); 415 | int64_t block = 256; 416 | int64_t grid = std::min((num_outputs - 1) / block + 1, (int64_t)65536); 417 | int64_t smem = 0; 418 | 419 | const scalar_t* bias_ptr = 420 | bias_.defined() ? bias_.data_ptr() : NULL; 421 | 422 | // Range check to avoid overflow in CUDA kernels. 423 | TORCH_CHECK(input_.numel() <= std::numeric_limits::max(), 424 | "Input tensor is too large."); 425 | TORCH_CHECK(output_.numel() <= std::numeric_limits::max(), 426 | "Output tensor is too large."); 427 | TORCH_CHECK(weight_.numel() <= std::numeric_limits::max(), 428 | "Weight tensor is too large."); 429 | for (int i = 0; i < 3; ++i) { 430 | TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits::max(), 431 | "Padded input tensor is too large."); 432 | } 433 | 434 | DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(3, 3, 3, 1, 1, 1) 435 | DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(-1, -1, -1, 1, 1, 1) 436 | DWCONV3D_FORWARD_DISPATCH_OTHERS 437 | } 438 | ); 439 | 440 | return output; 441 | } 442 | 443 | #undef DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION 444 | #undef DWCONV3D_FORWARD_DISPATCH_OTHERS 445 | 446 | #define DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( \ 447 | kt, kh, kw, dilt, dilh, dilw, dt, dh, dw) \ 448 | if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) && \ 449 | NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw)) && \ 450 | NODEF_OR_EQUAL_3(stride, (dt), (dh), (dw))) { \ 451 | using accscalar_t = acc_type; \ 452 | conv_depthwise3d_cuda_backward_input_kernel \ 453 | \ 454 | <<>>( \ 455 | grad_output_.packed_accessor32(), \ 456 | grad_input_.packed_accessor32(), \ 457 | weight_.packed_accessor32(), \ 458 | stride[0], stride[1], stride[2], \ 459 | padding[0], padding[1], padding[2], \ 460 | dilation[0], dilation[1], dilation[2]); \ 461 | } else 462 | 463 | #define DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS \ 464 | { \ 465 | using accscalar_t = acc_type; \ 466 | conv_depthwise3d_cuda_backward_input_kernel \ 467 | \ 468 | <<>>( \ 469 | grad_output_.packed_accessor32(), \ 470 | grad_input_.packed_accessor32(), \ 471 | weight_.packed_accessor32(), \ 472 | stride[0], stride[1], stride[2], \ 473 | padding[0], padding[1], padding[2], \ 474 | dilation[0], dilation[1], dilation[2]); \ 475 | } 476 | 477 | #define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(dh, dw) \ 478 | if (NODEF_OR_EQUAL_3(stride, -1, (dh), (dw))) { \ 479 | using accscalar_t = acc_type; \ 480 | 
conv_depthwise3d_cuda_backward_weight_kernel \ 481 | \ 482 | <<>>( \ 483 | grad_output_.packed_accessor32(), \ 484 | input_.packed_accessor32(), \ 485 | grad_weight.packed_accessor32(), \ 486 | stride[0], stride[1], stride[2], \ 487 | padding[0], padding[1], padding[2], \ 488 | dilation[0], dilation[1], dilation[2]); \ 489 | } else 490 | 491 | #define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS \ 492 | { \ 493 | using accscalar_t = acc_type; \ 494 | conv_depthwise3d_cuda_backward_weight_kernel \ 495 | \ 496 | <<>>( \ 497 | grad_output_.packed_accessor32(), \ 498 | input_.packed_accessor32(), \ 499 | grad_weight.packed_accessor32(), \ 500 | stride[0], stride[1], stride[2], \ 501 | padding[0], padding[1], padding[2], \ 502 | dilation[0], dilation[1], dilation[2]); \ 503 | } 504 | 505 | std::tuple _depthwise_3d_backward_cuda_out( 506 | Tensor& grad_input, 507 | Tensor& grad_weight, 508 | Tensor& grad_bias, 509 | const Tensor& grad_output, 510 | const Tensor& input, 511 | const Tensor& weight, 512 | IntArrayRef kernel_size, 513 | IntArrayRef stride, 514 | IntArrayRef padding, 515 | IntArrayRef dilation, 516 | const std::array output_mask) 517 | { 518 | 519 | TORCH_CHECK(grad_output.device() == input.device() && 520 | input.device() == weight.device(), 521 | "expects input, weight and grad_output to be on the same device."); 522 | conv_depthwise_shape_check<3>( 523 | input, weight, Tensor() /* undefined */, grad_output, 524 | kernel_size, stride, padding, dilation); 525 | 526 | const Tensor grad_output_ = grad_output.contiguous(); 527 | const Tensor input_ = input.contiguous(); 528 | const Tensor weight_ = weight.contiguous(); 529 | 530 | Tensor grad_input_ = 531 | (output_mask[0] ? grad_input 532 | : Tensor()); 533 | 534 | if (output_mask[0]) { 535 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 536 | grad_output.scalar_type(), 537 | "conv_depthwise3d", 538 | [&] { 539 | int64_t num_inputs = grad_input_.numel(); 540 | int64_t block = 256; 541 | int64_t grid = std::min((num_inputs - 1) / block + 1, (int64_t)65536); 542 | 543 | // Range check to avoid overflow in CUDA kernels. 
544 | TORCH_CHECK(grad_input_.numel() <= std::numeric_limits::max(), 545 | "Input tensor is too large."); 546 | TORCH_CHECK(grad_output_.numel() <= std::numeric_limits::max(), 547 | "Output tensor is too large."); 548 | TORCH_CHECK(weight_.numel() <= std::numeric_limits::max(), 549 | "Weight tensor is too large."); 550 | for (int i = 0; i < 3; ++i) { 551 | TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits::max(), 552 | "Padded input tensor is too large."); 553 | } 554 | 555 | DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( 556 | 3, 3, 3, 1, 1, 1, 1, 1, 1) 557 | DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( 558 | 3, 3, 3, 1, 1, 1, -1, -1, -1) 559 | DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( 560 | 3, 3, 3, -1, -1, -1, 1, 1, 1) 561 | DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( 562 | 3, 3, 3, -1, -1, -1, -1, -1, -1) 563 | DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS 564 | } 565 | ); 566 | } 567 | 568 | if (output_mask[1]) { 569 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 570 | grad_output.scalar_type(), 571 | "conv_depthwise3d", 572 | [&] { 573 | int64_t grid = grad_weight.numel(); 574 | int64_t block = 256; 575 | int64_t smem = sizeof(scalar_t) * block; 576 | 577 | const int64_t int_max = std::numeric_limits::max(); 578 | TORCH_CHECK(grad_input_.numel() <= int_max, 579 | "Input tensor is too large."); 580 | TORCH_CHECK(grad_output_.numel() <= int_max, 581 | "Output tensor is too large."); 582 | TORCH_CHECK(weight_.numel() <= int_max, 583 | "Weight tensor is too large."); 584 | for (int i = 0; i < 3; ++i) { 585 | TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= int_max, 586 | "Padded input tensor is too large."); 587 | } 588 | TORCH_CHECK(grad_output_.size(0) * grad_output_.size(2) < int_max - block / C10_WARP_SIZE && 589 | grad_output_.size(3) <= int_max - C10_WARP_SIZE && 590 | grad_output_.size(4) <= int_max - C10_WARP_SIZE, 591 | "Output size is too large."); 592 | 593 | DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(1, 1) 594 | DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(2, 2) 595 | DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS 596 | } 597 | ); 598 | } 599 | 600 | if (output_mask[2]) { 601 | grad_bias = grad_output.sum({0, 2, 3, 4}); 602 | } 603 | 604 | return std::tie(grad_input, grad_weight, grad_bias); 605 | 606 | } 607 | 608 | 609 | std::tuple conv_depthwise3d_backward_cuda_out( 610 | Tensor& grad_input, 611 | Tensor& grad_weight, 612 | Tensor& grad_bias, 613 | const Tensor& grad_output, 614 | const Tensor& input, 615 | const Tensor& weight, 616 | IntArrayRef kernel_size, 617 | IntArrayRef stride, 618 | IntArrayRef padding, 619 | IntArrayRef dilation) { 620 | if (grad_weight.defined()) { 621 | grad_weight.resize_(weight.sizes()); 622 | grad_weight.zero_(); 623 | } 624 | 625 | return _depthwise_3d_backward_cuda_out( 626 | grad_input, 627 | grad_weight, 628 | grad_bias, 629 | grad_output, 630 | input, 631 | weight, 632 | kernel_size, 633 | stride, 634 | padding, 635 | dilation, 636 | {true,true,true}); 637 | } 638 | 639 | std::tuple conv_depthwise3d_backward_cuda( 640 | const Tensor& grad_output, 641 | const Tensor& input, 642 | const Tensor& weight, 643 | IntArrayRef kernel_size, 644 | IntArrayRef stride, 645 | IntArrayRef padding, 646 | IntArrayRef dilation, 647 | const std::array output_mask) { 648 | 649 | auto options = grad_output.options(); 650 | Tensor grad_input = 651 | (output_mask[0] ? at::empty(input.sizes(), options) : Tensor()); 652 | Tensor grad_weight = 653 | (output_mask[1] ? 
at::empty(weight.sizes(), options) : Tensor()); 654 | Tensor grad_bias; /* undefined temporarily */ 655 | 656 | return _depthwise_3d_backward_cuda_out( 657 | grad_input, 658 | grad_weight, 659 | grad_bias, 660 | grad_output, 661 | input, 662 | weight, 663 | kernel_size, 664 | stride, 665 | padding, 666 | dilation, 667 | output_mask 668 | ); 669 | 670 | } 671 | 672 | #undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION 673 | #undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS 674 | 675 | #undef NODEF_OR_EQUAL_3 676 | #undef NODEF_OR_EQUAL -------------------------------------------------------------------------------- /src/warp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace at; 10 | std::tuple conv_depthwise3d_backward_cuda( 11 | const Tensor& grad_output, 12 | const Tensor& input, 13 | const Tensor& weight, 14 | IntArrayRef kernel_size, 15 | IntArrayRef stride, 16 | IntArrayRef padding, 17 | IntArrayRef dilation, 18 | const std::array output_mask); 19 | 20 | Tensor conv_depthwise3d_cuda( 21 | const Tensor& input, 22 | const Tensor& weight, 23 | IntArrayRef kernel_size, 24 | const Tensor& bias, 25 | IntArrayRef stride, 26 | IntArrayRef padding, 27 | IntArrayRef dilation); 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("conv_depthwise3d_backward_cuda", &conv_depthwise3d_backward_cuda, 31 | "conv_depthwise3d_backward_cuda"); 32 | m.def("conv_depthwise3d_cuda", &conv_depthwise3d_cuda, 33 | "conv_depthwise3d_cuda"); 34 | } 35 | --------------------------------------------------------------------------------
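
A quick way to sanity-check the built extension is to compare `DepthwiseConv3d` against `torch.nn.Conv3d` with `groups == in_channels`, which computes the same depthwise convolution. The sketch below is not part of the repository; it assumes the `DWCONV_CUDA` extension has been installed via `python setup.py install` and that a CUDA device is available.

```python
# Minimal correctness sketch (assumptions: extension installed, CUDA available).
# A depthwise DepthwiseConv3d should agree with nn.Conv3d when groups == in_channels.
import torch
from torch import nn
from depthwise_conv3d import DepthwiseConv3d

torch.manual_seed(0)
dtype = torch.double  # double precision keeps the numerical comparison tight

custom = DepthwiseConv3d(4, 4, kernel_size=3, padding=1, groups=4).to("cuda", dtype)
reference = nn.Conv3d(4, 4, kernel_size=3, padding=1, groups=4).to("cuda", dtype)

# Copy parameters so both modules compute the same function.
with torch.no_grad():
    reference.weight.copy_(custom.weight)
    reference.bias.copy_(custom.bias)

x = torch.randn(2, 4, 8, 8, 8, device="cuda", dtype=dtype, requires_grad=True)
y_custom = custom(x)
y_ref = reference(x)
print("forward match:", torch.allclose(y_custom, y_ref, atol=1e-6))

# Compare the input gradients produced by the custom backward kernel.
g = torch.randn_like(y_custom)
gi_custom, = torch.autograd.grad(y_custom, x, g, retain_graph=True)
gi_ref, = torch.autograd.grad(y_ref, x, g)
print("backward match:", torch.allclose(gi_custom, gi_ref, atol=1e-6))
```

The same pattern extends to weight and bias gradients (as `grad_test.py` does for the grouped case); running `torch.autograd.gradcheck` on small double-precision inputs is another option.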