├── requirements.txt ├── README.md ├── LICENSE ├── .gitignore └── conv4d.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-conv4d 2 | 3 | This repository contains a simple PyTorch port of the [conv4d for TensorFlow repository](https://github.com/funkey/conv4d) by Jan Funke. It consists essentially of a single class, `Conv4d`, which provides a (still rather rudimentary) PyTorch layer for 4-dimensional convolutions. Like the original, it works by performing and stacking several 3D convolutions (see the original repository for a more detailed explanation). 4 | 5 | This implementation is still a work in progress (hence it comes with no warranties whatsoever), and pull requests or advice for improvements are very much welcome! :) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Timothy Gebhard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

from typing import Callable, Optional, Tuple

import numpy as np
import torch


# -----------------------------------------------------------------------------
# CLASS DEFINITIONS
# -----------------------------------------------------------------------------

class Conv4d(torch.nn.Module):
    """
    4D convolution layer assembled from a stack of 3D convolutions.

    The 4D kernel of size ``(l_k, d_k, h_k, w_k)`` is split along its first
    axis into ``l_k`` frames. Every kernel frame is a bias-free
    ``torch.nn.Conv3d`` with kernel ``(d_k, h_k, w_k)``; `forward` applies
    kernel frame ``i`` to input frame ``j`` and accumulates the result into
    output frame ``j - i + padding``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel extent along the four convolved axes
            ``(l, d, h, w)``.
        stride: Only ``1`` is supported.
        padding: Zero padding applied symmetrically to each of the four
            convolved axes.
        dilation: Only ``1`` is supported.
        groups: Only ``1`` is supported.
        bias: If ``True``, one learnable bias per output channel is added
            exactly once to every output element.
        bias_initializer: Optional callable applied in place to the bias
            tensor of shape ``(out_channels,)``.
        kernel_initializer: Optional callable applied in place to the weight
            tensor of every 3D convolution.

    Attributes:
        conv3d_layers: ``ModuleList`` holding one ``Conv3d`` per kernel
            frame (index = position along the first kernel axis).
        bias_term: The shared bias ``Parameter``, or ``None`` if
            ``bias=False``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, int, int, int],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 bias_initializer: Optional[Callable] = None,
                 kernel_initializer: Optional[Callable] = None):

        super().__init__()

        # ---------------------------------------------------------------------
        # Assertions for constructor arguments
        # ---------------------------------------------------------------------

        assert len(kernel_size) == 4, \
            '4D kernel size expected!'
        assert stride == 1, \
            'Strides other than 1 not yet implemented!'
        assert dilation == 1, \
            'Dilation rate other than 1 not yet implemented!'
        assert groups == 1, \
            'Groups other than 1 not yet implemented!'

        # ---------------------------------------------------------------------
        # Store constructor arguments
        # ---------------------------------------------------------------------

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.bias_initializer = bias_initializer
        self.kernel_initializer = kernel_initializer

        # ---------------------------------------------------------------------
        # Construct 3D convolutional layers
        # ---------------------------------------------------------------------

        # Shortcut for kernel dimensions
        (l_k, d_k, h_k, w_k) = self.kernel_size

        # Use a ModuleList so that the weights of the 3D convolutions are
        # registered as parameters of this module
        self.conv3d_layers = torch.nn.ModuleList()

        for _ in range(l_k):

            # The 3D layers carry no bias of their own: with padding, the
            # number of kernel frames contributing to an output frame varies,
            # so per-layer biases would be added a varying number of times.
            conv3d_layer = torch.nn.Conv3d(in_channels=self.in_channels,
                                           out_channels=self.out_channels,
                                           kernel_size=(d_k, h_k, w_k),
                                           padding=self.padding,
                                           bias=False)

            # Apply initializer function to the weight tensor
            if self.kernel_initializer is not None:
                self.kernel_initializer(conv3d_layer.weight)

            # Store the layer
            self.conv3d_layers.append(conv3d_layer)

        # ---------------------------------------------------------------------
        # Construct the (single, shared) bias
        # ---------------------------------------------------------------------

        if self.bias:
            self.bias_term = torch.nn.Parameter(torch.empty(out_channels))
            if self.bias_initializer is not None:
                self.bias_initializer(self.bias_term)
            else:
                # Default: uniform in +-1/sqrt(fan_in), where fan_in counts
                # the inputs feeding one output element of the 4D kernel
                fan_in = in_channels * l_k * d_k * h_k * w_k
                bound = fan_in ** -0.5 if fan_in > 0 else 0.0
                torch.nn.init.uniform_(self.bias_term, -bound, bound)
        else:
            self.register_parameter('bias_term', None)

    # -------------------------------------------------------------------------

    def forward(self, input):
        """
        Apply the 4D convolution.

        Args:
            input: Tensor of shape ``(batch, in_channels, l, d, h, w)``.

        Returns:
            Tensor of shape ``(batch, out_channels, l_o, d_o, h_o, w_o)``
            where every output extent is ``size + 2 * padding - kernel + 1``.

        Raises:
            RuntimeError: If the (padded) first convolved axis of ``input``
                is shorter than the kernel, or has no frames at all.
        """

        # Only the first convolved axis is handled here; the remaining three
        # (including their padding) are handled by the Conv3d layers
        l_i = input.shape[2]
        l_k = self.kernel_size[0]
        l_o = l_i + 2 * self.padding - l_k + 1

        if l_o < 1:
            raise RuntimeError(
                f'Kernel size {l_k} exceeds the padded input size '
                f'{l_i + 2 * self.padding} along the first convolved axis!')

        # Output tensors for each 3D frame
        frame_results = [None] * l_o

        # Convolve each kernel frame i with each input frame j
        for i, conv3d_layer in enumerate(self.conv3d_layers):

            for j in range(l_i):

                # Input frame j sits at index j + padding of the padded
                # input, and kernel frame i of output frame o reads padded
                # index o + i, hence o = j - i + padding
                out_frame = j - i + self.padding
                if out_frame < 0 or out_frame >= l_o:
                    continue

                frame_conv3d = conv3d_layer(input[:, :, j])

                if frame_results[out_frame] is None:
                    frame_results[out_frame] = frame_conv3d
                else:
                    frame_results[out_frame] = \
                        frame_results[out_frame] + frame_conv3d

        # Output frames that only see zero padding received no contribution
        template = next((frame for frame in frame_results
                         if frame is not None), None)
        if template is None:
            raise RuntimeError('Input has no frames along the first '
                               'convolved axis!')
        frame_results = [torch.zeros_like(template) if frame is None
                         else frame
                         for frame in frame_results]

        output = torch.stack(frame_results, dim=2)

        # Add the bias exactly once per output element
        if self.bias_term is not None:
            output = output + self.bias_term.view(1, -1, 1, 1, 1, 1)

        return output


# -----------------------------------------------------------------------------
# MAIN CODE (TO TEST CONV4D)
# -----------------------------------------------------------------------------

if __name__ == "__main__":

    print()
    print('TEST PYTORCH CONV4D LAYER IMPLEMENTATION')
    print('\n' + 80 * '-' + '\n')

    # -------------------------------------------------------------------------
    # Generate random input 4D tensor (+ batch dimension, + channel dimension)
    # -------------------------------------------------------------------------

    np.random.seed(42)

    input_numpy = np.round(np.random.random((1, 1, 10, 11, 12, 13)) * 100)
    input_torch = torch.from_numpy(input_numpy).float()

    # -------------------------------------------------------------------------
    # Convolve with a randomly initialized kernel
    # -------------------------------------------------------------------------

    print('Randomly Initialized Kernels:\n')

    # Initialize the 4D convolutional layer with random kernels
    conv4d_layer = \
        Conv4d(in_channels=1,
               out_channels=1,
               kernel_size=(3, 3, 3, 3),
               bias_initializer=lambda x: torch.nn.init.constant_(x, 0))

    # Pass the input tensor through that layer
    output = conv4d_layer(input_torch).detach().numpy()

    # Select the 3D kernels for the manual computation and comparison
    kernels = [conv3d_layer.weight.detach().numpy().flatten()
               for conv3d_layer in conv4d_layer.conv3d_layers]

    # Compare the conv4d_layer result and the manual convolution computation
    # at 3 randomly chosen locations
    for _ in range(3):

        # Randomly choose a location and select the conv4d_layer output
        loc = [np.random.randint(0, output.shape[2] - 2),
               np.random.randint(0, output.shape[3] - 2),
               np.random.randint(0, output.shape[4] - 2),
               np.random.randint(0, output.shape[5] - 2)]
        conv4d = output[0, 0, loc[0], loc[1], loc[2], loc[3]]

        # Select slices from the input tensor and compute manual convolution
        slices = [input_numpy[0, 0, loc[0] + j, loc[1]:loc[1] + 3,
                              loc[2]:loc[2] + 3, loc[3]:loc[3] + 3].flatten()
                  for j in range(3)]
        manual = np.sum([slices[j] * kernels[j] for j in range(3)])

        # Print comparison
        print(f'At {tuple(loc)}:')
        print(f'\tconv4d:\t{conv4d}')
        print(f'\tmanual:\t{manual}')

    print('\n' + 80 * '-' + '\n')

    # -------------------------------------------------------------------------
    # Convolve with a kernel initialized to be all ones
    # -------------------------------------------------------------------------

    print('Constant Kernels (all 1):\n')

    conv4d_layer = \
        Conv4d(in_channels=1,
               out_channels=1,
               kernel_size=(3, 3, 3, 3),
               padding=1,
               kernel_initializer=lambda x: torch.nn.init.constant_(x, 1),
               bias_initializer=lambda x: torch.nn.init.constant_(x, 0))
    output = conv4d_layer(input_torch)

    # Define relu(x) = max(x, 0) for simplified indexing below
    def relu(x: int) -> int:
        return x * (x > 0)

    # Compare the conv4d_layer result and the manual convolution computation
    # at 3 randomly chosen locations
    for _ in range(3):

        # Randomly choose a location and select the conv4d_layer output
        loc = [np.random.randint(0, output.shape[2] - 2),
               np.random.randint(0, output.shape[3] - 2),
               np.random.randint(0, output.shape[4] - 2),
               np.random.randint(0, output.shape[5] - 2)]
        conv4d = output[0, 0, loc[0], loc[1], loc[2], loc[3]]

        # For a kernel that is all 1s, we only need to sum up the elements of
        # the input (the ReLU clamps the lower slice bound at the padding;
        # the upper bound is clipped by the slicing itself)
        manual = input_numpy[0, 0,
                             relu(loc[0] - 1):loc[0] + 2,
                             relu(loc[1] - 1):loc[1] + 2,
                             relu(loc[2] - 1):loc[2] + 2,
                             relu(loc[3] - 1):loc[3] + 2].sum()

        # Print comparison
        print(f'At {tuple(loc)}:')
        print(f'\tconv4d:\t{conv4d}')
        print(f'\tmanual:\t{manual}')

    print()