├── .gitignore ├── LICENSE ├── README.md ├── examples.py ├── involution ├── __init__.py └── involution.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Generated files 14 | .idea/**/contentModel.xml 15 | 16 | # Sensitive or high-churn files 17 | .idea/**/dataSources/ 18 | .idea/**/dataSources.ids 19 | .idea/**/dataSources.local.xml 20 | .idea/**/sqlDataSources.xml 21 | .idea/**/dynamic.xml 22 | .idea/**/uiDesigner.xml 23 | .idea/**/dbnavigator.xml 24 | 25 | # Gradle 26 | .idea/**/gradle.xml 27 | .idea/**/libraries 28 | 29 | # Gradle and Maven with auto-import 30 | # When using Gradle or Maven with auto-import, you should exclude module files, 31 | # since they will be recreated, and may cause churn. Uncomment if using 32 | # auto-import. 
33 | # .idea/artifacts 34 | # .idea/compiler.xml 35 | # .idea/jarRepositories.xml 36 | # .idea/modules.xml 37 | # .idea/*.iml 38 | # .idea/modules 39 | # *.iml 40 | # *.ipr 41 | 42 | # CMake 43 | cmake-build-*/ 44 | 45 | # Mongo Explorer plugin 46 | .idea/**/mongoSettings.xml 47 | 48 | # File-based project format 49 | *.iws 50 | 51 | # IntelliJ 52 | out/ 53 | 54 | # mpeltonen/sbt-idea plugin 55 | .idea_modules/ 56 | 57 | # JIRA plugin 58 | atlassian-ide-plugin.xml 59 | 60 | # Cursive Clojure plugin 61 | .idea/replstate.xml 62 | 63 | # Crashlytics plugin (for Android Studio and IntelliJ) 64 | com_crashlytics_export_strings.xml 65 | crashlytics.properties 66 | crashlytics-build.properties 67 | fabric.properties 68 | 69 | # Editor-based Rest Client 70 | .idea/httpRequests 71 | 72 | # Android studio 3.1+ serialized cache file 73 | .idea/caches/build_file_checksums.ser 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Christoph Reich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Involution: Inverting the Inherence of Convolution for Visual Recognition 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/ChristophReich1996/Involution/blob/master/LICENSE) 4 | 5 | Unofficial **PyTorch** reimplementation of the paper [Involution: Inverting the Inherence of Convolution for Visual Recognition](https://arxiv.org/pdf/2103.06255.pdf) 6 | by Duo Li, Jie Hu, Changhu Wang et al., published at CVPR 2021. 7 | 8 | **This repository includes a pure PyTorch implementation of a 2D and 3D involution.** 9 | 10 | Please note that the [official implementation](https://github.com/d-li14/involution) provides a more memory-efficient 11 | CuPy implementation of the 2D involution. Additionally, [shikishima-TasakiLab](https://github.com/shikishima-TasakiLab) provides a fast and memory-efficient [CUDA implementation](https://github.com/shikishima-TasakiLab/Involution-PyTorch) of the 2D involution. 12 | 13 | ## Installation 14 | The 2D and 3D involution can be easily installed using `pip`. 15 | ````shell script 16 | pip install git+https://github.com/ChristophReich1996/Involution 17 | ```` 18 | 19 | ## Example Usage 20 | Additional examples, such as strided involutions or involutions used like transposed convolutions, can be found in the 21 | [examples.py](examples.py) file.
22 | 23 | The 2D involution can be used as an `nn.Module` as follows: 24 | ````python 25 | import torch 26 | from involution import Involution2d 27 | 28 | involution = Involution2d(in_channels=32, out_channels=64) 29 | output = involution(torch.rand(1, 32, 128, 128)) 30 | ```` 31 | 32 | The 2D involution takes the following parameters. 33 | 34 | | Parameter | Description | Type | 35 | | ------------- | ------------- | ------------- | 36 | | in_channels | Number of input channels | int | 37 | | out_channels | Number of output channels | int | 38 | | sigma_mapping | Non-linear mapping as introduced in the paper. If None, BN + ReLU is used (default=None) | Optional[nn.Module] | 39 | | kernel_size | Kernel size to be used (default=(7, 7)) | Union[int, Tuple[int, int]] | 40 | | stride | Stride factor to be utilized (default=(1, 1)) | Union[int, Tuple[int, int]] | 41 | | groups | Number of groups to be employed (default=1) | int | 42 | | reduce_ratio | Reduction ratio of the involution channels (default=1) | int | 43 | | dilation | Dilation in unfold to be employed (default=(1, 1)) | Union[int, Tuple[int, int]] | 44 | | padding | Padding to be used in unfold operation (default=(3, 3)) | Union[int, Tuple[int, int]] | 45 | | bias | If true, a bias is utilized in each convolution layer (default=False) | bool | 46 | | force_shape_match | If true, a potential shape mismatch is resolved by performing adaptive average pooling (default=False) | bool | 47 | | **kwargs | Unused additional keyword arguments | Any | 48 | 49 | The 3D involution can be used as an `nn.Module` as follows: 50 | ````python 51 | import torch 52 | from involution import Involution3d 53 | 54 | involution = Involution3d(in_channels=8, out_channels=16) 55 | output = involution(torch.rand(1, 8, 32, 32, 32)) 56 | ```` 57 | 58 | The 3D involution takes the following parameters.
59 | 60 | | Parameter | Description | Type | 61 | | ------------- | ------------- | ------------- | 62 | | in_channels | Number of input channels | int | 63 | | out_channels | Number of output channels | int | 64 | | sigma_mapping | Non-linear mapping as introduced in the paper. If None, BN + ReLU is used (default=None) | Optional[nn.Module] | 65 | | kernel_size | Kernel size to be used (default=(7, 7, 7)) | Union[int, Tuple[int, int, int]] | 66 | | stride | Stride factor to be utilized (default=(1, 1, 1)) | Union[int, Tuple[int, int, int]] | 67 | | groups | Number of groups to be employed (default=1) | int | 68 | | reduce_ratio | Reduction ratio of the involution channels (default=1) | int | 69 | | dilation | Dilation in unfold to be employed (default=(1, 1, 1)) | Union[int, Tuple[int, int, int]] | 70 | | padding | Padding to be used in unfold operation (default=(3, 3, 3)) | Union[int, Tuple[int, int, int]] | 71 | | bias | If true, a bias is utilized in each convolution layer (default=False) | bool | 72 | | force_shape_match | If true, a potential shape mismatch is resolved by performing adaptive average pooling (default=False) | bool | 73 | | **kwargs | Unused additional keyword arguments | Any | 74 | 75 | 76 | ## Reference 77 | 78 | ````bibtex 79 | @inproceedings{Li2021, 80 | author = {Li, Duo and Hu, Jie and Wang, Changhu and Li, Xiangtai and She, Qi and Zhu, Lei and Zhang, Tong and Chen, Qifeng}, 81 | title = {Involution: Inverting the Inherence of Convolution for Visual Recognition}, 82 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 83 | month = {June}, 84 | year = {2021} 85 | } 86 | ```` 87 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from involution import Involution2d, Involution3d 3 | 4 | if __name__ == '__main__': 5 | # 2D involution example 6 | involution_2d = Involution2d(in_channels=4, out_channels=8)
7 | input = torch.rand(2, 4, 64, 64) 8 | output = involution_2d(input) 9 | 10 | # 2D involution as transposed convolution 11 | involution_2d = Involution2d(in_channels=6, out_channels=12) 12 | input_ = torch.rand(2, 6, 4, 4) 13 | input = torch.zeros(2, 6, 8, 8) 14 | input[..., ::2, ::2] = input_ 15 | output = involution_2d(input) 16 | 17 | # 2D involution with stride 18 | involution_2d = Involution2d(in_channels=4, out_channels=8, stride=2, kernel_size=2, padding=0) 19 | input = torch.rand(2, 4, 32, 32) 20 | output = involution_2d(input) 21 | 22 | # 3D involution example 23 | involution_3d = Involution3d(in_channels=8, out_channels=16) 24 | input = torch.rand(1, 8, 32, 32, 32) 25 | output = involution_3d(input) 26 | 27 | # 3D involution with stride 28 | involution_3d = Involution3d(in_channels=8, out_channels=16, stride=2, kernel_size=2, padding=0) 29 | input = torch.rand(1, 8, 16, 16, 16) 30 | output = involution_3d(input) -------------------------------------------------------------------------------- /involution/__init__.py: -------------------------------------------------------------------------------- 1 | # Import involution 2D and 3D 2 | from .involution import Involution2d, Involution3d 3 | 4 | __all__ = ["Involution2d", "Involution3d"] -------------------------------------------------------------------------------- /involution/involution.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | class Involution2d(nn.Module): 8 | """ 9 | This class implements the 2d involution proposed in: 10 | https://arxiv.org/pdf/2103.06255.pdf 11 | """ 12 | 13 | def __init__(self, 14 | in_channels: int, 15 | out_channels: int, 16 | sigma_mapping: Optional[nn.Module] = None, 17 | kernel_size: Union[int, Tuple[int, int]] = (7, 7), 18 | stride: Union[int, Tuple[int, int]] = (1, 1), 19 | groups: int = 1, 20 
| reduce_ratio: int = 1, 21 | dilation: Union[int, Tuple[int, int]] = (1, 1), 22 | padding: Union[int, Tuple[int, int]] = (3, 3), 23 | bias: bool = False, 24 | force_shape_match: bool = False, 25 | **kwargs) -> None: 26 | """ 27 | Constructor method 28 | :param in_channels: (int) Number of input channels 29 | :param out_channels: (int) Number of output channels 30 | :param sigma_mapping: (nn.Module) Non-linear mapping as introduced in the paper. If none BN + ReLU is utilized 31 | :param kernel_size: (Union[int, Tuple[int, int]]) Kernel size to be used 32 | :param stride: (Union[int, Tuple[int, int]]) Stride factor to be utilized 33 | :param groups: (int) Number of groups to be employed 34 | :param reduce_ratio: (int) Reduce ration of involution channels 35 | :param dilation: (Union[int, Tuple[int, int]]) Dilation in unfold to be employed 36 | :param padding: (Union[int, Tuple[int, int]]) Padding to be used in unfold operation 37 | :param bias: (bool) If true bias is utilized in each convolution layer 38 | :param force_shape_match: (bool) If true potential shape mismatch is solved by performing avg pool 39 | :param **kwargs: Unused additional key word arguments 40 | """ 41 | # Call super constructor 42 | super(Involution2d, self).__init__() 43 | # Check parameters 44 | assert isinstance(in_channels, int) and in_channels > 0, "in channels must be a positive integer." 45 | assert in_channels % groups == 0, "out_channels must be divisible by groups" 46 | assert isinstance(out_channels, int) and out_channels > 0, "out channels must be a positive integer." 47 | assert out_channels % groups == 0, "out_channels must be divisible by groups" 48 | assert isinstance(sigma_mapping, nn.Module) or sigma_mapping is None, \ 49 | "Sigma mapping must be an nn.Module or None to utilize the default mapping (BN + ReLU)." 50 | assert isinstance(kernel_size, int) or isinstance(kernel_size, tuple), \ 51 | "kernel size must be an int or a tuple of ints." 
52 | assert isinstance(stride, int) or isinstance(stride, tuple), \ 53 | "stride must be an int or a tuple of ints." 54 | assert isinstance(groups, int), "groups must be a positive integer." 55 | assert isinstance(reduce_ratio, int) and reduce_ratio > 0, "reduce ratio must be a positive integer." 56 | assert isinstance(dilation, int) or isinstance(dilation, tuple), \ 57 | "dilation must be an int or a tuple of ints." 58 | assert isinstance(padding, int) or isinstance(padding, tuple), \ 59 | "padding must be an int or a tuple of ints." 60 | assert isinstance(bias, bool), "bias must be a bool" 61 | assert isinstance(force_shape_match, bool), "force shape match flag must be a bool" 62 | # Save parameters 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) 66 | self.stride = stride if isinstance(stride, tuple) else (stride, stride) 67 | self.groups = groups 68 | self.reduce_ratio = reduce_ratio 69 | self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) 70 | self.padding = padding if isinstance(padding, tuple) else (padding, padding) 71 | self.bias = bias 72 | self.force_shape_match = force_shape_match 73 | # Init modules 74 | self.sigma_mapping = sigma_mapping if sigma_mapping is not None else nn.Sequential( 75 | nn.BatchNorm2d(num_features=self.out_channels // self.reduce_ratio, momentum=0.3), nn.ReLU()) 76 | self.initial_mapping = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 77 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), 78 | bias=bias) if self.in_channels != self.out_channels else nn.Identity() 79 | self.o_mapping = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 80 | self.reduce_mapping = nn.Conv2d(in_channels=self.in_channels, 81 | out_channels=self.out_channels // self.reduce_ratio, kernel_size=(1, 1), 82 | stride=(1, 1), padding=(0, 0), bias=bias) 83 | 
self.span_mapping = nn.Conv2d(in_channels=self.out_channels // self.reduce_ratio, 84 | out_channels=self.kernel_size[0] * self.kernel_size[1] * self.groups, 85 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=bias) 86 | self.unfold = nn.Unfold(kernel_size=self.kernel_size, dilation=dilation, padding=padding, stride=stride) 87 | 88 | def __repr__(self) -> str: 89 | """ 90 | Method returns information about the module 91 | :return: (str) Info string 92 | """ 93 | return ("{}({}, {}, kernel_size=({}, {}), stride=({}, {}), padding=({}, {}), " 94 | "groups={}, reduce_ratio={}, dilation=({}, {}), bias={}, sigma_mapping={})".format( 95 | self.__class__.__name__, 96 | self.in_channels, 97 | self.out_channels, 98 | self.kernel_size[0], 99 | self.kernel_size[1], 100 | self.stride[0], 101 | self.stride[1], 102 | self.padding[0], 103 | self.padding[1], 104 | self.groups, 105 | self.reduce_mapping, 106 | self.dilation[0], 107 | self.dilation[1], 108 | self.bias, 109 | str(self.sigma_mapping) 110 | )) 111 | 112 | def forward(self, input: torch.Tensor) -> torch.Tensor: 113 | """ 114 | Forward pass 115 | :param input: (torch.Tensor) Input tensor of the shape [batch size, in channels, height, width] 116 | :return: (torch.Tensor) Output tensor of the shape [batch size, out channels, height, width] (w/ same padding) 117 | """ 118 | # Check input dimension of input tensor 119 | assert input.ndimension() == 4, \ 120 | "Input tensor to involution must be 4d but {}d tensor is given".format(input.ndimension()) 121 | # Save input shape and compute output shapes 122 | batch_size, _, in_height, in_width = input.shape 123 | out_height = (in_height + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) \ 124 | // self.stride[0] + 1 125 | out_width = (in_width + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) \ 126 | // self.stride[1] + 1 127 | # Unfold and reshape input tensor 128 | input_unfolded = self.unfold(self.initial_mapping(input)) 
129 | input_unfolded = input_unfolded.view(batch_size, self.groups, self.out_channels // self.groups, 130 | self.kernel_size[0] * self.kernel_size[1], 131 | out_height, out_width) 132 | # Reshape input to avoid shape mismatch problems 133 | if self.force_shape_match: 134 | input = F.adaptive_avg_pool2d(input,(out_height,out_width)) 135 | # Generate kernel 136 | kernel = self.span_mapping(self.sigma_mapping(self.reduce_mapping(self.o_mapping(input)))) 137 | kernel = kernel.view(batch_size, self.groups, self.kernel_size[0] * self.kernel_size[1], 138 | kernel.shape[-2], kernel.shape[-1]).unsqueeze(dim=2) 139 | # Apply kernel to produce output 140 | output = (kernel * input_unfolded).sum(dim=3) 141 | # Reshape output 142 | output = output.view(batch_size, -1, output.shape[-2], output.shape[-1]) 143 | return output 144 | 145 | 146 | class Involution3d(nn.Module): 147 | """ 148 | This class implements the 3d involution. 149 | """ 150 | 151 | def __init__(self, 152 | in_channels: int, 153 | out_channels: int, 154 | sigma_mapping: Optional[nn.Module] = None, 155 | kernel_size: Union[int, Tuple[int, int, int]] = (7, 7, 7), 156 | stride: Union[int, Tuple[int, int, int]] = (1, 1, 1), 157 | groups: int = 1, 158 | reduce_ratio: int = 1, 159 | dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1), 160 | padding: Union[int, Tuple[int, int, int]] = (3, 3, 3), 161 | bias: bool = False, 162 | force_shape_match: bool = False, 163 | **kwargs) -> None: 164 | """ 165 | Constructor method 166 | :param in_channels: (int) Number of input channels 167 | :param out_channels: (int) Number of output channels 168 | :param sigma_mapping: (nn.Module) Non-linear mapping as introduced in the paper. 
If none BN + ReLU is utilized 169 | :param kernel_size: (Union[int, Tuple[int, int, int]]) Kernel size to be used 170 | :param stride: (Union[int, Tuple[int, int, int]]) Stride factor to be utilized 171 | :param groups: (int) Number of groups to be employed 172 | :param reduce_ratio: (int) Reduce ration of involution channels 173 | :param dilation: (Union[int, Tuple[int, int, int]]) Dilation in unfold to be employed 174 | :param padding: (Union[int, Tuple[int, int, int]]) Padding to be used in unfold operation 175 | :param bias: (bool) If true bias is utilized in each convolution layer 176 | :param force_shape_match: (bool) If true potential shape mismatch is solved by performing avg pool 177 | :param **kwargs: Unused additional key word arguments 178 | """ 179 | # Call super constructor 180 | super(Involution3d, self).__init__() 181 | # Check parameters 182 | assert isinstance(in_channels, int) and in_channels > 0, "in channels must be a positive integer." 183 | assert in_channels % groups == 0, "out_channels must be divisible by groups" 184 | assert isinstance(out_channels, int) and out_channels > 0, "out channels must be a positive integer." 185 | assert out_channels % groups == 0, "out_channels must be divisible by groups" 186 | assert isinstance(sigma_mapping, nn.Module) or sigma_mapping is None, \ 187 | "Sigma mapping must be an nn.Module or None to utilize the default mapping (BN + ReLU)." 188 | assert isinstance(kernel_size, int) or isinstance(kernel_size, tuple), \ 189 | "kernel size must be an int or a tuple of ints." 190 | assert isinstance(stride, int) or isinstance(stride, tuple), \ 191 | "stride must be an int or a tuple of ints." 192 | assert isinstance(groups, int), "groups must be a positive integer." 193 | assert isinstance(reduce_ratio, int) and reduce_ratio > 0, "reduce ratio must be a positive integer." 194 | assert isinstance(dilation, int) or isinstance(dilation, tuple), \ 195 | "dilation must be an int or a tuple of ints." 
196 | assert isinstance(padding, int) or isinstance(padding, tuple), \ 197 | "padding must be an int or a tuple of ints." 198 | assert isinstance(bias, bool), "bias must be a bool" 199 | assert isinstance(force_shape_match, bool), "force shape match flag must be a bool" 200 | # Save parameters 201 | self.in_channels = in_channels 202 | self.out_channels = out_channels 203 | self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) 204 | self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) 205 | self.groups = groups 206 | self.reduce_ratio = reduce_ratio 207 | self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) 208 | self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) 209 | self.bias = bias 210 | self.force_shape_match = force_shape_match 211 | # Init modules 212 | self.sigma_mapping = sigma_mapping if sigma_mapping is not None else nn.Sequential( 213 | nn.BatchNorm3d(num_features=self.out_channels // self.reduce_ratio, momentum=0.3), nn.ReLU()) 214 | self.initial_mapping = nn.Conv3d( 215 | in_channels=self.in_channels, out_channels=self.out_channels, 216 | kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), 217 | bias=bias) if self.in_channels != self.out_channels else nn.Identity() 218 | self.o_mapping = nn.AvgPool3d(kernel_size=self.stride, stride=self.stride) 219 | self.reduce_mapping = nn.Conv3d( 220 | in_channels=self.in_channels, 221 | out_channels=self.out_channels // self.reduce_ratio, kernel_size=(1, 1, 1), 222 | stride=(1, 1, 1), padding=(0, 0, 0), bias=bias) 223 | self.span_mapping = nn.Conv3d( 224 | in_channels=self.out_channels // self.reduce_ratio, 225 | out_channels=self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] * self.groups, 226 | kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=bias) 227 | self.pad = nn.ConstantPad3d(padding=(self.padding[0], 
self.padding[0], 228 | self.padding[1], self.padding[1], 229 | self.padding[2], self.padding[2]), value=0.) 230 | 231 | def __repr__(self) -> str: 232 | """ 233 | Method returns information about the module 234 | :return: (str) Info string 235 | """ 236 | return ("{}({}, {}, kernel_size=({}, {}, {}), stride=({}, {}, {}), padding=({}, {}, {}), " 237 | "groups={}, reduce_ratio={}, dilation=({}, {}, {}), bias={}, sigma_mapping={})".format( 238 | self.__class__.__name__, 239 | self.in_channels, 240 | self.out_channels, 241 | self.kernel_size[0], 242 | self.kernel_size[1], 243 | self.kernel_size[2], 244 | self.stride[0], 245 | self.stride[1], 246 | self.stride[2], 247 | self.padding[0], 248 | self.padding[1], 249 | self.padding[2], 250 | self.groups, 251 | self.reduce_mapping, 252 | self.dilation[0], 253 | self.dilation[1], 254 | self.dilation[2], 255 | self.bias, 256 | str(self.sigma_mapping) 257 | )) 258 | 259 | def forward(self, input: torch.Tensor) -> torch.Tensor: 260 | """ 261 | Forward pass 262 | :param input: (torch.Tensor) Input tensor of the shape [batch size, in channels, depth, height, width] 263 | :return: (torch.Tensor) Output tensor of the shape [batch size, out channels, depth, height, width] (w/ same padding) 264 | """ 265 | # Check input dimension of input tensor 266 | assert input.ndimension() == 5, \ 267 | "Input tensor to involution must be 5d but {}d tensor is given".format(input.ndimension()) 268 | # Save input shapes and compute output shapes 269 | batch_size, _, in_depth, in_height, in_width = input.shape 270 | out_depth = (in_depth + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) \ 271 | // self.stride[0] + 1 272 | out_height = (in_height + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) \ 273 | // self.stride[1] + 1 274 | out_width = (in_width + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) \ 275 | // self.stride[2] + 1 276 | # Unfold and reshape input tensor 277 | 
input_initial = self.initial_mapping(input) 278 | input_unfolded = self.pad(input_initial) \ 279 | .unfold(dimension=2, size=self.kernel_size[0], step=self.stride[0]) \ 280 | .unfold(dimension=3, size=self.kernel_size[1], step=self.stride[1]) \ 281 | .unfold(dimension=4, size=self.kernel_size[2], step=self.stride[2]) 282 | input_unfolded = input_unfolded.reshape(batch_size, self.groups, self.out_channels // self.groups, 283 | self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], -1) 284 | input_unfolded = input_unfolded.reshape(tuple(input_unfolded.shape[:-1]) 285 | + (out_depth, out_height, out_width)) 286 | # Reshape input to avoid shape mismatch problems 287 | if self.force_shape_match: 288 | input = F.adaptive_avg_pool3d(input, (out_depth, out_height, out_width)) 289 | # Generate kernel 290 | kernel = self.span_mapping(self.sigma_mapping(self.reduce_mapping(self.o_mapping(input)))) 291 | kernel = kernel.view( 292 | batch_size, self.groups, self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], 293 | kernel.shape[-3], kernel.shape[-2], kernel.shape[-1]).unsqueeze(dim=2) 294 | # Apply kernel to produce output 295 | output = (kernel * input_unfolded).sum(dim=3) 296 | # Reshape output 297 | output = output.view(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]) 298 | return output 299 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="involution", 5 | version="0.2", 6 | url="https://github.com/ChristophReich1996/Involution", 7 | license="MIT License", 8 | author="Christoph Reich", 9 | author_email="ChristophReich@gmx.net", 10 | description="PyTorch 
2D/3D Involution", 11 | packages=["involution",], 12 | install_requires=["torch>=1.7.0"], 13 | ) 14 | --------------------------------------------------------------------------------