├── .gitignore ├── involution ├── __init__.py └── involution2d.py ├── docker └── run-docker.sh ├── include ├── involution2d_cpu.h ├── involution2d_cuda.cuh └── involution2d_wrapper.h ├── src ├── pytorch_wrapper.cpp ├── involution2d_cpu.cpp └── involution2d_cuda.cu ├── setup.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build/ 3 | -------------------------------------------------------------------------------- /involution/__init__.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | from torch import ops 5 | 6 | _LIB_PATH = glob(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'involution.*.so'))[0] 7 | ops.load_library(_LIB_PATH) 8 | 9 | from .involution2d import Involution2d 10 | -------------------------------------------------------------------------------- /docker/run-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN_DIR=$(dirname $(readlink -f $0)) 3 | 4 | DOCKER_VOLUME="${DOCKER_VOLUME} -v $(dirname ${RUN_DIR}):/workspace/involution:rw" 5 | 6 | docker run \ 7 | -it \ 8 | --rm \ 9 | --gpus '"device=0"' \ 10 | ${DOCKER_VOLUME} \ 11 | --name Involution-PyTorch \ 12 | pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel bash 13 | # pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel bash 14 | # pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel bash 15 | # pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash 16 | # nvcr.io/nvidia/pytorch:21.05-py3 17 | # nvcr.io/nvidia/pytorch:20.08-py3 18 | -------------------------------------------------------------------------------- /include/involution2d_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace involution { 7 | namespace cpu { 8 | 9 | at::Tensor involution2d_forward( 10 | const at::Tensor& input, 11 | const at::Tensor& weight, 12 | const std::vector& kernel_size, 13 | const std::vector& stride, 14 | const std::vector& padding, 15 | const std::vector& dilation, 16 | const int64_t groups 17 | ); 18 | 19 | at::Tensor involution2d_backward_grad_input( 20 | const at::Tensor& grad, 21 | const at::Tensor& weight, 22 | const std::vector& input_shape, 23 | const std::vector& kernel_size, 24 | const std::vector& stride, 25 | const std::vector& padding, 26 | const std::vector& dilation, 27 | const int64_t groups 28 | ); 29 | 30 | at::Tensor involution2d_backward_grad_weight( 31 | const at::Tensor& grad, 32 | const at::Tensor& input, 33 | const std::vector& weight_shape, 34 | const std::vector& kernel_size, 35 | const std::vector& stride, 36 | const std::vector& padding, 37 | const std::vector& dilation, 38 | const int64_t groups 39 | ); 40 | 41 | std::vector involution2d_backward( 42 | const at::Tensor& grad, 43 | const at::Tensor& weight, 44 | const at::Tensor& input, 45 | const std::vector& kernel_size, 46 | const std::vector& stride, 47 | const std::vector& padding, 48 | const std::vector& dilation, 49 | const int64_t groups 50 | ); 51 | 52 | } // namespace cpu 53 | } // namespace involution 54 | -------------------------------------------------------------------------------- /include/involution2d_cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace involution { 8 | namespace cuda { 9 | 10 | #define CUDA_MAX_THREADS 1024u 11 | 12 | #define CUDA_KERNEL_LOOP(i, n) \ 13 | for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) 14 | 15 | at::Tensor involution2d_forward( 16 | const at::Tensor& input, 17 | const at::Tensor& weight, 18 | const std::vector& kernel_size, 19 | const std::vector& stride, 20 | const std::vector& padding, 21 | const std::vector& dilation, 22 | const int64_t groups 23 | ); 24 | 25 | at::Tensor involution2d_backward_grad_input( 26 | const at::Tensor& grad, 27 | const at::Tensor& weight, 28 | const std::vector& input_shape, 29 | const std::vector& kernel_size, 30 | const std::vector& stride, 31 | const std::vector& padding, 32 | const std::vector& dilation, 33 | const int64_t groups 34 | ); 35 | 36 | at::Tensor involution2d_backward_grad_weight( 37 | const at::Tensor& grad, 38 | const at::Tensor& input, 39 | const std::vector& weight_shape, 40 | const std::vector& kernel_size, 41 | const std::vector& stride, 42 | const std::vector& padding, 43 | const std::vector& dilation, 44 | const int64_t groups 45 | ); 46 | 47 | std::vector involution2d_backward( 48 | const at::Tensor& grad, 49 | const at::Tensor& weight, 50 | const at::Tensor& input, 51 | const std::vector& kernel_size, 52 | const std::vector& stride, 53 | const std::vector& padding, 54 | const std::vector& dilation, 55 | const int64_t groups 56 | ); 57 | 58 | } // namespace cuda 59 | } // namespace involution 60 | -------------------------------------------------------------------------------- /src/pytorch_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "involution2d_wrapper.h" 3 | 4 | TORCH_LIBRARY(involution, m) { 5 | m.def("involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); 6 | m.def("_involution2d_backward_grad_input(Tensor grad, Tensor weight, int[] input_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); 7 | m.def("_involution2d_backward_grad_weight(Tensor grad, Tensor input, int[] weight_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); 8 | m.def("_involution2d_backward(Tensor grad, Tensor weight, Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor[]"); 9 | } 10 | 11 | TORCH_LIBRARY_IMPL(involution, CPU, m) { 12 | m.impl("involution2d", involution::cpu::involution2d_forward); 13 | m.impl("_involution2d_backward_grad_input", involution::cpu::involution2d_backward_grad_input); 14 | m.impl("_involution2d_backward_grad_weight", involution::cpu::involution2d_backward_grad_weight); 15 | m.impl("_involution2d_backward", involution::cpu::involution2d_backward); 16 | } 17 | 18 | #ifdef USE_CUDA 19 | TORCH_LIBRARY_IMPL(involution, CUDA, m) { 20 | m.impl("involution2d", involution::cuda::involution2d_forward); 21 | m.impl("_involution2d_backward_grad_input", involution::cuda::involution2d_backward_grad_input); 22 | m.impl("_involution2d_backward_grad_weight", involution::cuda::involution2d_backward_grad_weight); 23 | m.impl("_involution2d_backward", involution::cuda::involution2d_backward); 24 | } 25 | #endif 26 | 27 | TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) { 28 | m.impl("involution2d", involution::cpu::involution2d_autograd); 29 | } 30 | 31 | #ifdef USE_CUDA 32 | TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) { 33 | m.impl("involution2d", involution::cuda::involution2d_autograd); 34 | } 35 | 36 | TORCH_LIBRARY_IMPL(involution, Autocast, m) { 37 | m.impl("involution2d", involution::cuda::involution2d_autocast); 38 | } 39 | #endif 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import abspath, dirname, join 3 | from setuptools import setup, find_packages 4 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension 5 | 6 | INCLUDE_DIR = join(dirname(abspath(__file__)), 'include') 7 | EXTRA_COMPILE_ARGS = ['-O3'] 8 | 9 | EXTENSION = [] 10 | 11 | CC = ['52', '53', '60', '61', '62', '70', '72', '75', '80'] 12 | 13 | if os.getenv('USE_OPENMP', '1') == '1': 14 | EXTRA_COMPILE_ARGS.append('-fopenmp') 15 | 16 | if os.getenv('USE_CUDA', '1') == '1': 17 | EXTRA_COMPILE_ARGS.append('-DUSE_CUDA') 18 | 19 | GENERATE_CODES = [] 20 | 21 | for cc in CC: 22 | GENERATE_CODES.append('--generate-code') 23 | GENERATE_CODES.append(f'arch=compute_{cc},code=compute_{cc}') 24 | 25 | EXTENSION.append( 26 | CUDAExtension( 27 | name='involution', 28 | sources=[ 29 | 'src/involution2d_cpu.cpp', 30 | 'src/involution2d_cuda.cu', 31 | 'src/pytorch_wrapper.cpp', 32 | ], 33 | include_dirs=[ 34 | INCLUDE_DIR 35 | ], 36 | extra_compile_args={ 37 | 'cxx': EXTRA_COMPILE_ARGS, 38 | 'nvcc': ['-O3'] + GENERATE_CODES, 39 | } 40 | ) 41 | ) 42 | else: 43 | EXTENSION.append( 44 | CppExtension( 45 | name='involution', 46 | sources=[ 47 | 'src/involution2d_cpu.cpp', 48 | 'src/pytorch_wrapper.cpp', 49 | ], 50 | include_dirs=[ 51 | INCLUDE_DIR 52 | ], 53 | extra_compile_args=EXTRA_COMPILE_ARGS 54 | ) 55 | ) 56 | 57 | setup( 58 | name='involution-pytorch', 59 | version="0.1.0", 60 | url="https://github.com/shikishima-TasakiLab/Involution-PyTorch", 61 | license="MIT License", 62 | author="Junya Shikishima", 63 | author_email="160442065@ccalumni.meijo-u.ac.jp", 64 | description="PyTorch Involution", 65 | packages=find_packages(), 66 | ext_modules=EXTENSION, 67 | cmdclass={ 68 | 'build_ext': BuildExtension, 69 | } 70 | ) 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Involution: Inverting the Inherence of Convolution for Visual Recognition 2 | 3 | Unofficial PyTorch reimplemention of the paper [Involution: Inverting the Inherence of Convolution for Visual Recognition](https://arxiv.org/pdf/2103.06255.pdf) by Duo Li, Jie Hu, Changhu Wang et al. published at CVPR 2021. 4 | 5 | **This repository includes a PyTorch implementation of 2D Involution using C++/OpenMP/CUDA.** 6 | 7 | ## Installation 8 | 9 | - Default (Use CUDA and OpenMP) 10 | 11 | ```bash 12 | pip install git+https://github.com/shikishima-TasakiLab/Involution-PyTorch 13 | ``` 14 | 15 | - No CUDA 16 | 17 | ```bash 18 | USE_CUDA=0 pip install git+https://github.com/shikishima-TasakiLab/Involution-PyTorch 19 | ``` 20 | 21 | - No OpenMP 22 | 23 | ```bash 24 | USE_OPENMP=0 pip install git+https://github.com/shikishima-TasakiLab/Involution-PyTorch 25 | ``` 26 | 27 | ## Example Usage 28 | 29 | The 2D involution can be used as a `nn.Module` as follows: 30 | 31 | ```python 32 | import torch 33 | import torch.nn as nn 34 | from involution import Involution2d 35 | 36 | if torch.cuda.is_available(): 37 | device = torch.device("cuda:0") 38 | else: 39 | device = torch.device("cpu") 40 | 41 | inv2d: nn.Module = Involution2d(in_channels=4, out_channels=8).to(device) 42 | 43 | x: torch.Tensor = torch.rand(2, 4, 8, 8).to(device) 44 | 45 | y: torch.Tensor = inv2d(x) 46 | ``` 47 | 48 | The 2D involution takes the following parameters: 49 | 50 | |Parameter |Description |Type |Default| 51 | |---------------|-------------------------------------------------------------------------------|-------------------|-------| 52 | |`in_channels` |Number of input channels. |`int` | - | 53 | |`out_channels` |Number of output channels. |`int` | - | 54 | |`kernel_size` |Kernel size to be used. |`int`, `(int, int)`|`7` | 55 | |`stride` |Stride factor to be utilized. |`int`, `(int, int)`|`1` | 56 | |`padding` |Padding to be used in unfold operation. |`int`, `(int, int)`|`3` | 57 | |`dilation` |Dilation in unfold to be employed. |`int`, `(int, int)`|`1` | 58 | |`groups` |Number of groups to be employed. |`int` |`1` | 59 | |`bias` |If true bias is utilized in each convolution layer. |`bool` |`False`| 60 | |`sigma_mapping`|Non-linear mapping as introduced in the paper. If none BN + ReLU is utilized. |`torch.nn.Module` |`None` | 61 | |`reduce_ratio` |Reduce ration of involution channels. |`int` |`1` | 62 | 63 | ## Reference 64 | 65 | ```bibtex 66 | @inproceedings{Li2021, 67 | author = {Li, Duo and Hu, Jie and Wang, Changhu and Li, Xiangtai and She, Qi and Zhu, Lei and Zhang, Tong and Chen, Qifeng}, 68 | title = {Involution: Inverting the Inherence of Convolution for Visual Recognition}, 69 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 70 | month = {June}, 71 | year = {2021} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junya Shikishima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 25 | 26 | The main implementation of this module is based on Duo Li's "involution" and Christoph Reich's "Involution", which are subject to the same license. 27 | The original copyright notice can be found here. 28 | 29 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 30 | 31 | MIT License 32 | 33 | Copyright (c) 2021 Duo Li 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining a copy 36 | of this software and associated documentation files (the "Software"), to deal 37 | in the Software without restriction, including without limitation the rights 38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 39 | copies of the Software, and to permit persons to whom the Software is 40 | furnished to do so, subject to the following conditions: 41 | 42 | The above copyright notice and this permission notice shall be included in all 43 | copies or substantial portions of the Software. 44 | 45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 51 | SOFTWARE. 52 | 53 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 54 | 55 | MIT License 56 | 57 | Copyright (c) 2021 Christoph Reich 58 | 59 | Permission is hereby granted, free of charge, to any person obtaining a copy 60 | of this software and associated documentation files (the "Software"), to deal 61 | in the Software without restriction, including without limitation the rights 62 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 63 | copies of the Software, and to permit persons to whom the Software is 64 | furnished to do so, subject to the following conditions: 65 | 66 | The above copyright notice and this permission notice shall be included in all 67 | copies or substantial portions of the Software. 68 | 69 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 70 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 71 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 72 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 73 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 74 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 75 | SOFTWARE. 76 | -------------------------------------------------------------------------------- /involution/involution2d.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.modules.utils import _pair 5 | from torch import ops 6 | 7 | def _involution2d( 8 | input: torch.Tensor, 9 | weight: torch.Tensor, 10 | kernel_size: Union[int, Tuple[int, int]] = 7, 11 | stride: Union[int, Tuple[int, int]] = 1, 12 | padding: Union[int, Tuple[int, int]] = 0, 13 | dilation: Union[int, Tuple[int, int]] = 1, 14 | groups: int = 1, 15 | bias: torch.Tensor = None, 16 | ) -> torch.Tensor: 17 | kernel_size_ = _pair(kernel_size) 18 | stride_ = _pair(stride) 19 | padding_ = _pair(padding) 20 | dilation_ = _pair(dilation) 21 | 22 | output: torch.Tensor = ops.involution.involution2d(input, weight, kernel_size_, stride_, padding_, dilation_, groups) 23 | 24 | if bias is not None: 25 | output += bias.view(1, -1, 1, 1) 26 | 27 | return output 28 | 29 | class Involution2d(nn.Module): 30 | def __init__(self, 31 | in_channels: int, 32 | out_channels: int, 33 | kernel_size: Union[int, Tuple[int, int]] = 7, 34 | stride: Union[int, Tuple[int, int]] = 1, 35 | padding: Union[int, Tuple[int, int]] = 3, 36 | dilation: Union[int, Tuple[int, int]] = 1, 37 | groups: int = 1, 38 | bias: bool = False, 39 | sigma_mapping: Optional[nn.Module] = None, 40 | reduce_ratio: int = 1, 41 | ) -> None: 42 | """2D Involution: https://arxiv.org/pdf/2103.06255.pdf 43 | Args: 44 | in_channels (int): Number of input channels 45 | out_channels (int): Number of output channels 46 | kernel_size (Union[int, Tuple[int, int]], optional): Kernel size to be used. Defaults to 7. 47 | stride (Union[int, Tuple[int, int]], optional): Stride factor to be utilized. Defaults to 1. 48 | padding (Union[int, Tuple[int, int]], optional): Padding to be used in unfold operation. Defaults to 3. 49 | dilation (Union[int, Tuple[int, int]], optional): Dilation in unfold to be employed. Defaults to 1. 50 | groups (int, optional): Number of groups to be employed. Defaults to 1. 51 | bias (bool, optional): If true bias is utilized in each convolution layer. Defaults to False. 52 | sigma_mapping (Optional[nn.Module], optional): Non-linear mapping as introduced in the paper. If none BN + ReLU is utilized 53 | reduce_ratio (int, optional): Reduce ration of involution channels. Defaults to 1. 54 | """ 55 | super(Involution2d, self).__init__() 56 | 57 | assert isinstance(in_channels, int) and in_channels > 0, \ 58 | '"in_channels" must be a positive integer.' 59 | assert isinstance(out_channels, int) and out_channels > 0, \ 60 | '"out_channels" must be a positive integer.' 61 | assert isinstance(kernel_size, (int, tuple)), \ 62 | '"kernel_size" must be an int or a tuple of ints.' 63 | assert isinstance(stride, (int, tuple)), \ 64 | '"stride" must be an int or a tuple of ints.' 65 | assert isinstance(padding, (int, tuple)), \ 66 | '"padding" must be an int or a tuple of ints.' 67 | assert isinstance(dilation, (int, tuple)), \ 68 | '"dilation" must be an int or a tuple of ints.' 69 | assert isinstance(groups, int) and groups > 0, \ 70 | '"groups" must be a positive integer.' 71 | assert in_channels % groups == 0, '"in_channels" must be divisible by "groups".' 72 | assert out_channels % groups == 0, '"out_channels" must be divisible by "groups".' 73 | assert isinstance(bias, bool), '"bias" must be a bool.' 74 | assert isinstance(sigma_mapping, nn.Module) or sigma_mapping is None, \ 75 | '"sigma_mapping" muse be an int or a tuple of ints.' 76 | assert isinstance(reduce_ratio, int) and reduce_ratio > 0, \ 77 | '"reduce_ratio" must be a positive integer.' 78 | 79 | self.in_channels: int = in_channels 80 | self.out_channels: int = out_channels 81 | self.kernel_size: Tuple[int, int] = _pair(kernel_size) 82 | self.stride: Tuple[int, int] = _pair(stride) 83 | self.padding: Tuple[int, int] = _pair(padding) 84 | self.dilation: Tuple[int, int] = _pair(dilation) 85 | self.groups: int = groups 86 | self.bias: bool = bias 87 | self.reduce_ratio: int = reduce_ratio 88 | 89 | self.sigma_mapping = sigma_mapping if isinstance(sigma_mapping, nn.Module) else nn.Sequential( 90 | nn.BatchNorm2d(num_features=self.out_channels // 91 | self.reduce_ratio, momentum=0.3), 92 | nn.ReLU() 93 | ) 94 | self.initial_mapping = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, bias=bias) \ 95 | if self.in_channels != self.out_channels else nn.Identity() 96 | self.o_mapping = nn.AvgPool2d( 97 | kernel_size=self.stride) if self.stride[0] > 1 or self.stride[1] > 1 else nn.Identity() 98 | self.reduce_mapping = nn.Conv2d( 99 | in_channels=self.in_channels, out_channels=self.out_channels // self.reduce_ratio, kernel_size=1, bias=bias) 100 | self.span_mapping = nn.Conv2d(in_channels=self.out_channels // self.reduce_ratio, 101 | out_channels=self.kernel_size[0] * self.kernel_size[1] * self.groups, kernel_size=1, bias=bias) 102 | 103 | def __repr__(self) -> str: 104 | """Method returns information about the module 105 | Returns: 106 | str: Info string 107 | """ 108 | return (f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, kernel_size=({self.kernel_size[0]}, {self.kernel_size[1]}), ' 109 | f'stride=({self.stride[0]}, {self.stride[1]}), padding=({self.padding[0]}, {self.padding[1]}), dilation=({self.dilation[0], self.dilation[1]}), ' 110 | f'groups={self.groups}, bias={self.bias}, reduce_ratio={self.reduce_ratio}, sigma_mapping={str(self.sigma_mapping)}' 111 | ) 112 | 113 | def forward(self, input: torch.Tensor) -> torch.Tensor: 114 | """Forward pass 115 | Args: 116 | input (torch.Tensor): Input tensor of the shape [batch size, in channels, height, width] 117 | Returns: 118 | torch.Tensor: Output tensor of the shape [batch size, out channels, height, width] (w/ same padding) 119 | """ 120 | weight: torch.Tensor = self.span_mapping(self.sigma_mapping(self.reduce_mapping(self.o_mapping(input)))) 121 | input_init: torch.Tensor = self.initial_mapping(input) 122 | 123 | return _involution2d(input_init, weight, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 124 | -------------------------------------------------------------------------------- /include/involution2d_wrapper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "involution2d_cpu.h" 8 | 9 | #ifdef USE_CUDA 10 | # include "involution2d_cuda.cuh" 11 | #endif 12 | 13 | namespace involution { 14 | 15 | at::Tensor involution2d( 16 | const at::Tensor& input, 17 | const at::Tensor& weight, 18 | const std::vector& stride, 19 | const std::vector& padding, 20 | const std::vector& dilation 21 | ) { 22 | static auto op = at::Dispatcher::singleton() 23 | .findSchemaOrThrow("involution::involution2d", "") 24 | .typed(); 25 | 26 | return op.call(input, weight, stride, padding, dilation); 27 | } 28 | 29 | at::Tensor involution2d_autocast( 30 | const at::Tensor& input, 31 | const at::Tensor& weight, 32 | const std::vector& stride, 33 | const std::vector& padding, 34 | const std::vector& dilation 35 | ) { 36 | c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); 37 | auto exec_type = at::autocast::promote_type(at::kFloat, input, weight); 38 | return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), stride, padding, dilation) 39 | .to(input.scalar_type()); 40 | } 41 | 42 | at::Tensor _involution2d_backward_grad_input( 43 | const at::Tensor& grad, 44 | const at::Tensor& weight, 45 | const std::vector& input_shape, 46 | const std::vector& stride, 47 | const std::vector& padding, 48 | const std::vector& dilation 49 | ) { 50 | static auto op = at::Dispatcher::singleton() 51 | .findSchemaOrThrow("involution2d::_involution2d_backward_grad_input", "") 52 | .typed(); 53 | 54 | return op.call(grad, weight, input_shape, stride, padding, dilation); 55 | } 56 | 57 | at::Tensor _involution2d_backward_grad_weight( 58 | const at::Tensor& grad, 59 | const at::Tensor& input, 60 | const std::vector& weight_shape, 61 | const std::vector& stride, 62 | const std::vector& padding, 63 | const std::vector& dilation 64 | ) { 65 | static auto op = at::Dispatcher::singleton() 66 | .findSchemaOrThrow("involution2d::_involution2d_backward_grad_weight", "") 67 | .typed(); 68 | 69 | return op.call(grad, input, weight_shape, stride, padding, dilation); 70 | } 71 | 72 | namespace cpu { 73 | 74 | class Involution2dFunctionCPU : public torch::autograd::Function 75 | { 76 | public: 77 | 78 | static torch::autograd::variable_list forward( 79 | torch::autograd::AutogradContext* ctx, 80 | const torch::autograd::Variable& input, 81 | const torch::autograd::Variable& weight, 82 | const std::vector& kernel_size, 83 | const std::vector& stride, 84 | const std::vector& padding, 85 | const std::vector& dilation, 86 | const int64_t groups 87 | ) { 88 | ctx->saved_data["kernel_size"] = kernel_size; 89 | ctx->saved_data["stride"] = stride; 90 | ctx->saved_data["padding"] = padding; 91 | ctx->saved_data["dilation"] = dilation; 92 | ctx->saved_data["groups"] = groups; 93 | ctx->save_for_backward({input, weight}); 94 | 95 | auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups); 96 | 97 | return {output}; 98 | } 99 | 100 | static torch::autograd::variable_list backward( 101 | torch::autograd::AutogradContext* ctx, 102 | const torch::autograd::variable_list grad_output 103 | ) { 104 | torch::autograd::variable_list saved = ctx->get_saved_variables(); 105 | torch::autograd::Variable input = saved[0]; 106 | torch::autograd::Variable weight = saved[1]; 107 | 108 | auto kernel_size = ctx->saved_data["kernel_size"].toIntVector(); 109 | auto stride = ctx->saved_data["stride"].toIntVector(); 110 | auto padding = ctx->saved_data["padding"].toIntVector(); 111 | auto dilation = ctx->saved_data["dilation"].toIntVector(); 112 | auto groups = ctx->saved_data["groups"].toInt(); 113 | 114 | auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups); 115 | 116 | return { 117 | grads[0], 118 | grads[1], 119 | torch::autograd::Variable(), 120 | torch::autograd::Variable(), 121 | torch::autograd::Variable(), 122 | torch::autograd::Variable(), 123 | torch::autograd::Variable() 124 | }; 125 | } 126 | }; 127 | 128 | at::Tensor involution2d_autograd( 129 | const torch::autograd::Variable& input, 130 | const torch::autograd::Variable& weight, 131 | const std::vector& kernel_size, 132 | const std::vector& stride, 133 | const std::vector& padding, 134 | const std::vector& dilation, 135 | const int64_t groups 136 | ) { 137 | return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0]; 138 | } 139 | 140 | } // namespace cpu 141 | 142 | #ifdef USE_CUDA 143 | 144 | namespace cuda { 145 | 146 | class Involution2dFunctionCUDA : public torch::autograd::Function 147 | { 148 | public: 149 | 150 | static torch::autograd::variable_list forward( 151 | torch::autograd::AutogradContext* ctx, 152 | const torch::autograd::Variable& input, 153 | const torch::autograd::Variable& weight, 154 | const std::vector& kernel_size, 155 | const std::vector& stride, 156 | const std::vector& padding, 157 | const std::vector& dilation, 158 | const int64_t groups 159 | ) { 160 | ctx->saved_data["kernel_size"] = kernel_size; 161 | ctx->saved_data["stride"] = stride; 162 | ctx->saved_data["padding"] = padding; 163 | ctx->saved_data["dilation"] = dilation; 164 | ctx->saved_data["groups"] = groups; 165 | ctx->save_for_backward({input, weight}); 166 | 167 | auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups); 168 | 169 | return {output}; 170 | } 171 | 172 | static torch::autograd::variable_list backward( 173 | torch::autograd::AutogradContext* ctx, 174 | const torch::autograd::variable_list grad_output 175 | ) { 176 | torch::autograd::variable_list saved = ctx->get_saved_variables(); 177 | torch::autograd::Variable input = saved[0]; 178 | torch::autograd::Variable weight = saved[1]; 179 | 180 | auto kernel_size = ctx->saved_data["kernel_size"].toIntVector(); 181 | auto stride = ctx->saved_data["stride"].toIntVector(); 182 | auto padding = ctx->saved_data["padding"].toIntVector(); 183 | auto dilation = ctx->saved_data["dilation"].toIntVector(); 184 | auto groups = ctx->saved_data["groups"].toInt(); 185 | 186 | auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups); 187 | 188 | return { 189 | grads[0], 190 | grads[1], 191 | torch::autograd::Variable(), 192 | torch::autograd::Variable(), 193 | torch::autograd::Variable(), 194 | torch::autograd::Variable(), 195 | torch::autograd::Variable() 196 | }; 197 | } 198 | }; 199 | 200 | at::Tensor involution2d_autograd( 201 | const torch::autograd::Variable& input, 202 | const torch::autograd::Variable& weight, 203 | const std::vector& kernel_size, 204 | const std::vector& stride, 205 | const std::vector& padding, 206 | const std::vector& dilation, 207 | const int64_t groups 208 | ) { 209 | return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0]; 210 | } 211 | 212 | at::Tensor involution2d_autocast( 213 | const torch::autograd::Variable& input, 214 | const torch::autograd::Variable& weight, 215 | const std::vector& kernel_size, 216 | const std::vector& stride, 217 | const std::vector& padding, 218 | const std::vector& dilation, 219 | const int64_t groups 220 | ) { 221 | c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); 222 | auto exec_type = at::autocast::promote_type(at::kFloat, input, weight); 223 | return involution2d_autograd( 224 | at::autocast::cached_cast(exec_type, input), 225 | at::autocast::cached_cast(exec_type, weight), 226 | kernel_size, stride, padding, dilation, groups 227 | ); 228 | } 229 | 230 | } // namespace cuda 231 | 232 | #endif 233 | 234 | } // namespace involution 235 | -------------------------------------------------------------------------------- /src/involution2d_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "involution2d_cpu.h" 2 | 3 | namespace involution { 4 | namespace cpu { 5 | 6 | template 7 | static void involution2d_forward_frame( 8 | const at::Tensor& in_data, 9 | const at::Tensor& weight_data, 10 | at::Tensor& out_data, 11 | const at::IntArrayRef& kernel_size, 12 | const at::IntArrayRef& padding, 13 | const at::IntArrayRef& stride, 14 | const at::IntArrayRef& dilation 15 | ) { 16 | auto num_elements = out_data.numel(); 17 | const auto groups = weight_data.size(1); 18 | const auto channels = in_data.size(1); 19 | const auto in_height = in_data.size(2); 20 | const auto in_width = in_data.size(3); 21 | const auto out_height = out_data.size(2); 22 | const auto out_width = out_data.size(3); 23 | 24 | auto in_data_a = in_data.accessor(); 25 | auto weight_data_a = weight_data.accessor(); 26 | auto* out_data_p = out_data.data_ptr(); 27 | 28 | #pragma omp parallel for 29 | for (int64_t idx = 0l; idx < num_elements; idx++) { 30 | const int64_t w = idx % out_width; 31 | const int64_t h = (idx / out_width) % out_height; 32 | int64_t divisor = out_width * out_height; 33 | const int64_t c = (idx / divisor) % channels; 34 | divisor *= channels; 35 | const int64_t n = idx / divisor; 36 | const int64_t g = c / (channels / groups); 37 | 38 | scalar_t value = 0; 39 | 40 | for (int64_t kh = 0l; kh < kernel_size[0]; kh++) { 41 | const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0]; 42 | 43 | if ((0l <= h_in) && (h_in < in_height)) { 44 | for (int64_t kw = 0l; kw < kernel_size[1]; kw++) { 45 | const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1]; 46 | 47 | if ((0l <= w_in) && (w_in < in_width)) { 48 | value += weight_data_a[n][g][kh][kw][h][w] * in_data_a[n][c][h_in][w_in]; 49 | } 50 | } 51 | } 52 | } 53 | out_data_p[idx] = value; 54 | } 55 | } 56 | 57 | at::Tensor involution2d_forward( 58 | const at::Tensor& input, 59 | const at::Tensor& weight, 60 | const std::vector& kernel_size, 61 | const std::vector& stride, 62 | const std::vector& padding, 63 | const std::vector& dilation, 64 | const int64_t groups 65 | ) { 66 | AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor."); 67 | AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor."); 68 | 69 | at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2}; 70 | 71 | at::CheckedFrom c = __func__; 72 | at::checkAllSameType(c, {input_t, weight_t}); 73 | 74 | const auto batch_size = input.size(0); 75 | const auto channels = input.size(1); 76 | const auto in_height = input.size(2); 77 | const auto in_width = input.size(3); 78 | 79 | const auto weight_height = weight.size(2); 80 | const auto weight_width = weight.size(3); 81 | 82 | const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); 83 | 84 | const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1; 85 | const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1; 86 | 87 | at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options()); 88 | 89 | if (output.numel() == 0) { 90 | return output; 91 | } 92 | 93 | AT_DISPATCH_FLOATING_TYPES_AND2( 94 | at::kHalf, 95 | at::kBFloat16, 96 | input.scalar_type(), 97 | "involution2d_forward_kernel", [&] { 98 | involution2d_forward_frame( 99 | input, 100 | weight_, 101 | output, 102 | kernel_size, 103 | padding, 104 | stride, 105 | dilation 106 | ); 107 | } 108 | ); 109 | return output; 110 | } 111 | 112 | template 113 | static void involution2d_backward_grad_input_frame( 114 | const at::Tensor& out_diff, 115 | const at::Tensor& weight_data, 116 | at::Tensor& in_diff, 117 | const at::IntArrayRef& kernel_size, 118 | const at::IntArrayRef& padding, 119 | const at::IntArrayRef& stride, 120 | const at::IntArrayRef& dilation 121 | ) { 122 | auto num_elements = in_diff.numel(); 123 | const auto groups = weight_data.size(1); 124 | const auto channels = in_diff.size(1); 125 | const auto in_height = in_diff.size(2); 126 | const auto in_width = in_diff.size(3); 127 | const auto out_height = out_diff.size(2); 128 | const auto out_width = out_diff.size(3); 129 | 130 | auto out_diff_a = out_diff.accessor(); 131 | auto weight_data_a = weight_data.accessor(); 132 | auto* in_diff_p = in_diff.data_ptr(); 133 | 134 | #pragma omp parallel for 135 | for (int64_t idx = 0l; idx < num_elements; idx++) { 136 | const int64_t w = idx % in_width; 137 | const int64_t h = (idx / in_width) % in_height; 138 | int64_t divisor = in_width * in_height; 139 | const int64_t c = (idx / divisor) % channels; 140 | divisor *= channels; 141 | const int64_t n = idx / divisor; 142 | const int64_t g = c / (channels / groups); 143 | 144 | scalar_t value = 0; 145 | 146 | for (int64_t kh = 0l; kh < kernel_size[0]; kh++) { 147 | const int64_t h_out_s = h + padding[0] - kh * dilation[0]; 148 | 149 | for (int64_t kw = 0l; kw < kernel_size[1]; kw++) { 150 | const int64_t w_out_s = w + padding[1] - kw * dilation[1]; 151 | 152 | if (((h_out_s % stride[0]) == 0) && ((w_out_s % stride[1]) == 0)) { 153 | const int64_t h_out = h_out_s / stride[0]; 154 | const int64_t w_out = h_out_s / stride[1]; 155 | 156 | if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) { 157 | value += weight_data_a[n][g][kh][kw][h_out][w_out] * out_diff_a[n][c][h_out][w_out]; 158 | } 159 | } 160 | } 161 | } 162 | in_diff_p[idx] = value; 163 | } 164 | } 165 | 166 | at::Tensor involution2d_backward_grad_input( 167 | const at::Tensor& grad, 168 | const at::Tensor& weight, 169 | const std::vector& input_shape, 170 | const std::vector& kernel_size, 171 | const std::vector& stride, 172 | const std::vector& padding, 173 | const std::vector& dilation, 174 | const int64_t groups 175 | ) { 176 | AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor."); 177 | AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor."); 178 | 179 | at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2}; 180 | 181 | at::CheckedFrom c = __func__; 182 | at::checkAllSameType(c, {grad_t, weight_t}); 183 | 184 | const auto batch_size = input_shape[0]; 185 | 186 | const auto weight_height = weight.size(2); 187 | const auto weight_width = weight.size(3); 188 | 189 | const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); 190 | 191 | at::Tensor grad_input = at::zeros(input_shape, grad.options()); 192 | 193 | if (grad_input.numel() == 0) { 194 | return grad_input; 195 | } 196 | 197 | AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, grad.scalar_type(), "involution2d_backward_grad_input_frame", [&] { 198 | involution2d_backward_grad_input_frame( 199 | grad, 200 | weight_, 201 | grad_input, 202 | kernel_size, 203 | padding, 204 | stride, 205 | dilation 206 | ); 207 | }); 208 | 209 | return grad_input; 210 | } 211 | 212 | template 213 | static void involution2d_backward_grad_weight_frame( 214 | const at::Tensor& out_diff, 215 | const at::Tensor& in_data, 216 | at::Tensor& weight_diff, 217 | const at::IntArrayRef& kernel_size, 218 | const at::IntArrayRef& padding, 219 | const at::IntArrayRef& stride, 220 | const at::IntArrayRef& dilation 221 | ) { 222 | auto num_elements = weight_diff.numel(); 223 | const auto groups = weight_diff.size(1); 224 | const auto batch_size = in_data.size(0); 225 | const auto channels = in_data.size(1); 226 | const auto in_height = in_data.size(2); 227 | const auto in_width = in_data.size(3); 228 | const auto out_height = out_diff.size(2); 229 | const auto out_width = out_diff.size(3); 230 | const auto channels_per_group = channels / groups; 231 | 232 | auto out_diff_a = out_diff.accessor(); 233 | auto in_data_a = in_data.accessor(); 234 | auto* weight_diff_p = weight_diff.data_ptr(); 235 | 236 | #pragma omp parallel for 237 | for (int64_t idx = 0l; idx < num_elements; idx++) { 238 | const int64_t w = idx % out_width; 239 | const int64_t h = (idx / out_width) % out_height; 240 | int64_t divisor = out_width * out_height; 241 | const int64_t kw = (idx / divisor) % kernel_size[1]; 242 | divisor *= kernel_size[1]; 243 | const int64_t kh = (idx / divisor) % kernel_size[0]; 244 | 245 | const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0]; 246 | const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1]; 247 | 248 | if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) { 249 | divisor *= kernel_size[0]; 250 | const int64_t g = (idx / divisor) % groups; 251 | divisor *= groups; 252 | const int64_t n = (idx / divisor) % batch_size; 253 | 254 | scalar_t value = 0; 255 | 256 | for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) { 257 | value += out_diff_a[n][c][h][w] * in_data_a[n][c][h_in][w_in]; 258 | } 259 | weight_diff_p[idx] = value; 260 | } 261 | else { 262 | weight_diff_p[idx] = 0; 263 | } 264 | } 265 | } 266 | 267 | at::Tensor involution2d_backward_grad_weight( 268 | const at::Tensor& grad, 269 | const at::Tensor& input, 270 | const std::vector& weight_shape, 271 | const std::vector& kernel_size, 272 | const std::vector& stride, 273 | const std::vector& padding, 274 | const std::vector& dilation, 275 | const int64_t groups 276 | ) { 277 | AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor."); 278 | AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor."); 279 | 280 | at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2}; 281 | 282 | at::CheckedFrom c = __func__; 283 | at::checkAllSameType(c, {grad_t, input_t}); 284 | 285 | const auto batch_size = input.size(0); 286 | 287 | at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options()); 288 | 289 | if (grad_weight.numel() == 0) { 290 | return grad_weight.view(weight_shape); 291 | } 292 | 293 | AT_DISPATCH_FLOATING_TYPES_AND2( 294 | at::kHalf, 295 | at::kBFloat16, 296 | grad.scalar_type(), 297 | "involution2d_backward_grad_weight_kernel", [&] { 298 | involution2d_backward_grad_weight_frame( 299 | grad, 300 | input, 301 | grad_weight, 302 | kernel_size, 303 | padding, 304 | stride, 305 | dilation 306 | ); 307 | } 308 | ); 309 | return grad_weight.view(weight_shape); 310 | } 311 | 312 | std::vector involution2d_backward( 313 | const at::Tensor& grad, 314 | const at::Tensor& weight, 315 | const at::Tensor& input, 316 | const std::vector& kernel_size, 317 | const std::vector& stride, 318 | const std::vector& padding, 319 | const std::vector& dilation, 320 | const int64_t groups 321 | ) { 322 | auto input_sizes = input.sizes(); 323 | std::vector input_size; 324 | std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size)); 325 | 326 | auto grad_input = involution2d_backward_grad_input( 327 | grad, 328 | weight, 329 | input_size, 330 | kernel_size, 331 | stride, 332 | padding, 333 | dilation, 334 | groups 335 | ); 336 | 337 | auto weight_sizes = weight.sizes(); 338 | std::vector weight_size; 339 | std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size)); 340 | 341 | auto grad_weight = involution2d_backward_grad_weight( 342 | grad, 343 | input, 344 | weight_size, 345 | kernel_size, 346 | stride, 347 | padding, 348 | dilation, 349 | groups 350 | ); 351 | 352 | // std::vector output{grad_input, grad_weight}; 353 | 354 | // return output; 355 | return {grad_input, grad_weight}; 356 | } 357 | 358 | } // namespace cpu 359 | } // namespace involution 360 | -------------------------------------------------------------------------------- /src/involution2d_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace involution { 4 | namespace cuda { 5 | 6 | static u_int32_t ceildiv(u_int32_t num_elements, u_int32_t threads) { 7 | return (num_elements + threads - 1) / threads; 8 | } 9 | 10 | template 11 | C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) 12 | __global__ static void involution2d_forward_kernel( 13 | const at::GenericPackedTensorAccessor in_data, 14 | const at::GenericPackedTensorAccessor weight_data, 15 | scalar_t* const out_data, 16 | const int64_t num_elements, 17 | const int64_t channels, 18 | const int64_t groups, 19 | const int64_t in_height, const int64_t in_width, 20 | const int64_t out_height, const int64_t out_width, 21 | const int64_t kernel_height, const int64_t kernel_width, 22 | const int64_t pad_h, const int64_t pad_w, 23 | const int64_t stride_h, const int64_t stride_w, 24 | const int64_t dilation_h, const int64_t dilation_w 25 | ) { 26 | CUDA_KERNEL_LOOP(idx, num_elements) { 27 | const int64_t w = idx % out_width; 28 | const int64_t h = (idx / out_width) % out_height; 29 | int64_t divisor = out_width * out_height; 30 | const int64_t c = (idx / divisor) % channels; 31 | divisor *= channels; 32 | const int64_t n = idx / divisor; 33 | const int64_t g = c / (channels / groups); 34 | 35 | scalar_t value = 0; 36 | 37 | for (int64_t kh = 0l; kh < kernel_height; kh++) { 38 | const int64_t h_in = h * stride_h + kh * dilation_h - pad_h; 39 | 40 | if ((0l <= h_in) && (h_in < in_height)) { 41 | for (int64_t kw = 0l; kw < kernel_width; kw++) { 42 | const int64_t w_in = w * stride_w + kw * dilation_w - pad_w; 43 | 44 | if ((0l <= w_in) && (w_in < in_width)) { 45 | value += weight_data[n][g][kh][kw][h][w] * in_data[n][c][h_in][w_in]; 46 | } 47 | } 48 | } 49 | } 50 | 51 | out_data[idx] = value; 52 | } 53 | } 54 | 55 | at::Tensor involution2d_forward( 56 | const at::Tensor& input, 57 | const at::Tensor& weight, 58 | const std::vector& kernel_size, 59 | const std::vector& stride, 60 | const std::vector& padding, 61 | const std::vector& dilation, 62 | const int64_t groups 63 | ) { 64 | AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor."); 65 | AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor."); 66 | 67 | at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2}; 68 | 69 | at::CheckedFrom c = __func__; 70 | at::checkAllSameGPU(c, {input_t, weight_t}); 71 | at::checkAllSameType(c, {input_t, weight_t}); 72 | 73 | at::cuda::CUDAGuard device_guard(input.device()); 74 | 75 | const auto batch_size = input.size(0); 76 | const auto channels = input.size(1); 77 | const auto in_height = input.size(2); 78 | const auto in_width = input.size(3); 79 | 80 | const auto weight_height = weight.size(2); 81 | const auto weight_width = weight.size(3); 82 | 83 | const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); 84 | 85 | const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1; 86 | const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1; 87 | 88 | at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options()); 89 | const auto num_elements = output.numel(); 90 | 91 | if (num_elements == 0) { 92 | AT_CUDA_CHECK(cudaGetLastError()); 93 | return output; 94 | } 95 | 96 | const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); 97 | const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); 98 | const dim3 threads_per_block(threads, 1u, 1u); 99 | 100 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 101 | 102 | AT_DISPATCH_FLOATING_TYPES_AND2( 103 | at::kHalf, 104 | at::kBFloat16, 105 | input.scalar_type(), 106 | "involution2d_forward_kernel", [&] { 107 | involution2d_forward_kernel<<>>( 108 | input.generic_packed_accessor(), 109 | weight_.generic_packed_accessor(), 110 | output.data_ptr(), 111 | num_elements, 112 | channels, 113 | groups, 114 | in_height, in_width, 115 | out_height, out_width, 116 | kernel_size[0], kernel_size[1], 117 | padding[0], padding[1], 118 | stride[0], stride[1], 119 | dilation[0], dilation[1] 120 | ); 121 | } 122 | ); 123 | AT_CUDA_CHECK(cudaGetLastError()); 124 | return output; 125 | } 126 | 127 | template 128 | C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) 129 | __global__ static void involution2d_backward_grad_input_kernel( 130 | const at::GenericPackedTensorAccessor out_diff, 131 | const at::GenericPackedTensorAccessor weight_data, 132 | scalar_t* const in_diff, 133 | const int64_t num_elements, 134 | const int64_t channels, 135 | const int64_t groups, 136 | const int64_t in_height, const int64_t in_width, 137 | const int64_t out_height, const int64_t out_width, 138 | const int64_t kernel_height, const int64_t kernel_width, 139 | const int64_t pad_h, const int64_t pad_w, 140 | const int64_t stride_h, const int64_t stride_w, 141 | const int64_t dilation_h, const int64_t dilation_w 142 | ) { 143 | CUDA_KERNEL_LOOP(idx, num_elements) { 144 | const int64_t w = idx % in_width; 145 | const int64_t h = (idx / in_width) % in_height; 146 | int64_t divisor = in_width * in_height; 147 | const int64_t c = (idx / divisor) % channels; 148 | divisor *= channels; 149 | const int64_t n = idx / divisor; 150 | const int64_t g = c / (channels / groups); 151 | 152 | scalar_t value = 0; 153 | 154 | for (int64_t kh = 0l; kh < kernel_height; kh++) { 155 | const int64_t h_out_s = h + pad_h - kh * dilation_h; 156 | 157 | for (int64_t kw = 0l; kw < kernel_width; kw++) { 158 | const int64_t w_out_s = w + pad_w - kw * dilation_w; 159 | 160 | if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { 161 | const int64_t h_out = h_out_s / stride_h; 162 | const int64_t w_out = h_out_s / stride_w; 163 | 164 | if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) { 165 | value += weight_data[n][g][kh][kw][h_out][w_out] * out_diff[n][c][h_out][w_out]; 166 | } 167 | } 168 | } 169 | } 170 | in_diff[idx] = value; 171 | } 172 | } 173 | 174 | at::Tensor involution2d_backward_grad_input( 175 | const at::Tensor& grad, 176 | const at::Tensor& weight, 177 | const std::vector& input_shape, 178 | const std::vector& kernel_size, 179 | const std::vector& stride, 180 | const std::vector& padding, 181 | const std::vector& dilation, 182 | const int64_t groups 183 | ) { 184 | AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor."); 185 | AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor."); 186 | 187 | at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2}; 188 | 189 | at::CheckedFrom c = __func__; 190 | at::checkAllSameGPU(c, {grad_t, weight_t}); 191 | at::checkAllSameType(c, {grad_t, weight_t}); 192 | 193 | at::cuda::CUDAGuard device_guard(grad.device()); 194 | 195 | const auto batch_size = input_shape[0]; 196 | const auto channels = input_shape[1]; 197 | const auto in_height = input_shape[2]; 198 | const auto in_width = input_shape[3]; 199 | 200 | const auto weight_height = weight.size(2); 201 | const auto weight_width = weight.size(3); 202 | 203 | const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); 204 | 205 | const auto out_height = grad.size(2); 206 | const auto out_width = grad.size(3); 207 | 208 | at::Tensor grad_input = at::zeros(input_shape, grad.options()); 209 | const auto num_elements = grad_input.numel(); 210 | 211 | if (num_elements == 0) { 212 | AT_CUDA_CHECK(cudaGetLastError()); 213 | return grad_input; 214 | } 215 | 216 | const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); 217 | const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); 218 | const dim3 threads_per_block(threads, 1u, 1u); 219 | 220 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 221 | 222 | AT_DISPATCH_FLOATING_TYPES_AND2( 223 | at::kHalf, 224 | at::kBFloat16, 225 | grad.scalar_type(), 226 | "involution2d_backward_grad_input_kernel", [&] { 227 | involution2d_backward_grad_input_kernel<<>>( 228 | grad.generic_packed_accessor(), 229 | weight_.generic_packed_accessor(), 230 | grad_input.data_ptr(), 231 | num_elements, 232 | channels, 233 | groups, 234 | in_height, in_width, 235 | out_height, out_width, 236 | kernel_size[0], kernel_size[1], 237 | padding[0], padding[1], 238 | stride[0], stride[1], 239 | dilation[0], dilation[1] 240 | ); 241 | } 242 | ); 243 | AT_CUDA_CHECK(cudaGetLastError()); 244 | return grad_input; 245 | } 246 | 247 | template 248 | C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) 249 | __global__ static void involution2d_backward_grad_weight_kernel( 250 | const at::GenericPackedTensorAccessor out_diff, 251 | const at::GenericPackedTensorAccessor in_data, 252 | scalar_t* const weight_diff, 253 | const int64_t num_elements, 254 | const int64_t batch_size, 255 | const int64_t channels_per_group, 256 | const int64_t groups, 257 | const int64_t in_height, const int64_t in_width, 258 | const int64_t out_height, const int64_t out_width, 259 | const int64_t kernel_height, const int64_t kernel_width, 260 | const int64_t pad_h, const int64_t pad_w, 261 | const int64_t stride_h, const int64_t stride_w, 262 | const int64_t dilation_h, const int64_t dilation_w 263 | ) { 264 | CUDA_KERNEL_LOOP(idx, num_elements) { 265 | const int64_t w = idx % out_width; 266 | const int64_t h = (idx / out_width) % out_height; 267 | int64_t divisor = out_width * out_height; 268 | const int64_t kw = (idx / divisor) % kernel_width; 269 | divisor *= kernel_width; 270 | const int64_t kh = (idx / divisor) % kernel_height; 271 | 272 | const int64_t h_in = -pad_h + h * stride_h + kh * dilation_h; 273 | const int64_t w_in = -pad_w + w * stride_w + kw * dilation_w; 274 | 275 | if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) { 276 | divisor *= kernel_height; 277 | const int64_t g = (idx / divisor) % groups; 278 | divisor *= groups; 279 | const int64_t n = (idx / divisor) % batch_size; 280 | 281 | scalar_t value = 0; 282 | 283 | for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) { 284 | value += out_diff[n][c][h][w] * in_data[n][c][h_in][w_in]; 285 | } 286 | weight_diff[idx] = value; 287 | } 288 | else { 289 | weight_diff[idx] = 0; 290 | } 291 | } 292 | } 293 | 294 | at::Tensor involution2d_backward_grad_weight( 295 | const at::Tensor& grad, 296 | const at::Tensor& input, 297 | const std::vector& weight_shape, 298 | const std::vector& kernel_size, 299 | const std::vector& stride, 300 | const std::vector& padding, 301 | const std::vector& dilation, 302 | const int64_t groups 303 | ) { 304 | AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor."); 305 | AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor."); 306 | 307 | at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2}; 308 | 309 | at::CheckedFrom c = __func__; 310 | at::checkAllSameGPU(c, {grad_t, input_t}); 311 | at::checkAllSameType(c, {grad_t, input_t}); 312 | 313 | at::cuda::CUDAGuard device_guard(grad.device()); 314 | 315 | const auto batch_size = input.size(0); 316 | const auto channels = input.size(1); 317 | const auto in_height = input.size(2); 318 | const auto in_width = input.size(3); 319 | 320 | const auto out_height = grad.size(2); 321 | const auto out_width = grad.size(3); 322 | 323 | at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options()); 324 | const auto num_elements = grad_weight.numel(); 325 | 326 | if (num_elements == 0) { 327 | AT_CUDA_CHECK(cudaGetLastError()); 328 | return grad_weight.view(weight_shape); 329 | } 330 | 331 | const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); 332 | const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); 333 | const dim3 threads_per_block(threads, 1u, 1u); 334 | 335 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 336 | 337 | AT_DISPATCH_FLOATING_TYPES_AND2( 338 | at::kHalf, 339 | at::kBFloat16, 340 | grad.scalar_type(), 341 | "involution2d_backward_grad_weight_kernel", [&] { 342 | involution2d_backward_grad_weight_kernel<<>>( 343 | grad.generic_packed_accessor(), 344 | input.generic_packed_accessor(), 345 | grad_weight.data_ptr(), 346 | num_elements, 347 | batch_size, 348 | channels / groups, 349 | groups, 350 | in_height, in_width, 351 | out_height, out_width, 352 | kernel_size[0], kernel_size[1], 353 | padding[0], padding[1], 354 | stride[0], stride[1], 355 | dilation[0], dilation[1] 356 | ); 357 | } 358 | ); 359 | AT_CUDA_CHECK(cudaGetLastError()); 360 | return grad_weight.view(weight_shape); 361 | } 362 | 363 | std::vector involution2d_backward( 364 | const at::Tensor& grad, 365 | const at::Tensor& weight, 366 | const at::Tensor& input, 367 | const std::vector& kernel_size, 368 | const std::vector& stride, 369 | const std::vector& padding, 370 | const std::vector& dilation, 371 | const int64_t groups 372 | ) { 373 | auto input_sizes = input.sizes(); 374 | std::vector input_size; 375 | std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size)); 376 | 377 | auto grad_input = involution2d_backward_grad_input( 378 | grad, 379 | weight, 380 | input_size, 381 | kernel_size, 382 | stride, 383 | padding, 384 | dilation, 385 | groups 386 | ); 387 | 388 | auto weight_sizes = weight.sizes(); 389 | std::vector weight_size; 390 | std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size)); 391 | 392 | auto grad_weight = involution2d_backward_grad_weight( 393 | grad, 394 | input, 395 | weight_size, 396 | kernel_size, 397 | stride, 398 | padding, 399 | dilation, 400 | groups 401 | ); 402 | 403 | return {grad_input, grad_weight}; 404 | } 405 | 406 | } // namespace cuda 407 | } // namespace involution 408 | --------------------------------------------------------------------------------