├── figures
│   ├── kernel_size+stride.png
│   ├── violating_symmetry.png
│   └── explicit_next_neighbour_conv.png
├── LICENSE
├── notebooks
│   ├── addressing_utils.py
│   ├── hexagdly_tools.py
│   └── example_utils.py
├── tests
│   ├── test_MaxPool2d.py
│   ├── test_Conv2d_Custom.py
│   ├── test_MaxPool3d.py
│   ├── test_Conv2d.py
│   ├── test_Conv2d_CustomKernel.py
│   ├── test_Conv3d.py
│   └── test_Conv3d_CustomKernel.py
├── README.md
└── src
    └── hexagdly.py

--------------------------------------------------------------------------------
/figures/kernel_size+stride.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4iacts/hexagdly/HEAD/figures/kernel_size+stride.png
--------------------------------------------------------------------------------
/figures/violating_symmetry.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4iacts/hexagdly/HEAD/figures/violating_symmetry.png
--------------------------------------------------------------------------------
/figures/explicit_next_neighbour_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4iacts/hexagdly/HEAD/figures/explicit_next_neighbour_conv.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 ai4iacts

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
22 | -------------------------------------------------------------------------------- /notebooks/addressing_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.colors as mcolors 3 | from matplotlib.patches import RegularPolygon 4 | from matplotlib.collections import PatchCollection 5 | 6 | class Detector: 7 | def __init__(self): 8 | # hexagon size 9 | self.a = 1 10 | self.r = .5*3**(.5)*self.a 11 | # half the number of rows and columns 12 | ny = 5 13 | nx = 5 14 | # grid of odd columns 15 | y1 = np.linspace(0,2*ny*self.r,ny+1) 16 | x1 = np.linspace(0,2*nx*1.5*self.a,nx+1) 17 | X1,Y1 = np.meshgrid(x1,y1) 18 | # grid of even columns 19 | y2 = np.linspace(self.r,(2*ny+1)*self.r,ny+1) 20 | x2 = np.linspace(1.5*self.a,(2*nx+1)*1.5*self.a,nx+1) 21 | X2,Y2 = np.meshgrid(x2,y2) 22 | # join grids 23 | x = np.concatenate([X1,X2]) 24 | y = np.concatenate([Y1,Y2]) 25 | # select pixels within a circular region 26 | xc = .5*(x.max()-x.min()) 27 | yc = .5*(y.max()-y.min()) 28 | x,y = x-xc, y-yc 29 | r = np.sqrt(x**2 + y**2) 30 | ind = np.where(r= inchannels: 58 | a = i 59 | b = 0 60 | else: 61 | a = 0 62 | b = i 63 | npixel = 0 64 | for x in range(np.shape(tensor[image_range[0] + a, channel_range[0] + b])[1]): 65 | for y in range( 66 | np.shape(tensor[image_range[0] + a, channel_range[0] + b])[0] 67 | ): 68 | if npixel not in mask: 69 | intensity = tensor[image_range[0] + a, channel_range[0] + b, y, x] 70 | hexagon = RegularPolygon( 71 | (x * np.sqrt(3) / 2, -(y + np.mod(x, 2) * 0.5)), 72 | 6, 73 | 0.577349, 74 | orientation=np.pi / 6, 75 | ) 76 | intensities[i].append(intensity) 77 | hexagons[i].append(hexagon) 78 | npixel += 1 79 | ax = fig.add_subplot(gs[i]) 80 | ax.set_xlim([-1, np.shape(tensor[image_range[0] + a, channel_range[0] + b])[0]]) 81 | ax.set_ylim( 82 | [ 83 | -1.15 * np.shape(tensor[image_range[0] + a, channel_range[0] + b])[1] 84 | - 1, 85 | 1, 86 | ] 87 | ) 88 | ax.set_axis_off() 89 | p = PatchCollection( 90 | np.array(hexagons[i]), cmap=cmap, alpha=0.9, edgecolors="k", linewidth=1 91 | ) 92 | p.set_array(np.array(np.array(intensities[i]))) 93 | ax.add_collection(p) 94 | ax.set_aspect("equal") 95 | plt.subplots_adjust(top=0.95, bottom=0.05) 96 | plt.tight_layout() 97 | 98 | 99 | def plot_squaretensor( 100 | tensor, 101 | image_range=(0, None), 102 | channel_range=(0, None), 103 | cmap="Greys", 104 | figname="figure", 105 | ): 106 | r""" Same as plot_hex_tensor, just that the tensor is plotted in squares 107 | in a cartesian grid. 108 | 109 | """ 110 | try: 111 | tensor = tensor.data.numpy() 112 | except Exception as e: 113 | print("Input not given as pytorch tensor! 
Continuing...") 114 | pass 115 | inshape = np.shape( 116 | tensor[image_range[0] : image_range[1], channel_range[0] : channel_range[1]] 117 | ) 118 | inexamples = inshape[0] 119 | inchannels = inshape[1] 120 | if inexamples != 1 and inchannels != 1: 121 | print("Choose one image and n channels or one channel an n images to display!") 122 | sys.exit() 123 | nimages = max(inexamples, inchannels) 124 | fig = plt.figure(figname, (6, 6)) 125 | fig.clear() 126 | nrows = int(np.ceil(np.sqrt(nimages))) 127 | gs = gridspec.GridSpec(nrows, nrows) 128 | gs.update(wspace=0.2, hspace=0) 129 | for i in range(nimages): 130 | if inexamples >= inchannels: 131 | a = i 132 | b = 0 133 | else: 134 | a = 0 135 | b = i 136 | npixel = 0 137 | ax = fig.add_subplot(gs[i]) 138 | ax.set_axis_off() 139 | ax.pcolor(tensor[a][b], cmap=cmap, edgecolors="k", linewidths=0.4) 140 | ax.invert_yaxis() 141 | ax.set_aspect("equal") 142 | ax.set_frame_on(True) 143 | plt.subplots_adjust(top=0.95, bottom=0.05) 144 | -------------------------------------------------------------------------------- /tests/test_MaxPool2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestMaxPool2d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_maxpool2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [6, 12, 16, 22, 26, 32, 36, 37], 17 | [7, 13, 17, 23, 27, 33, 37, 38], 18 | [8, 14, 18, 24, 28, 34, 38, 39], 19 | [9, 15, 19, 25, 29, 35, 39, 40], 20 | [10, 15, 20, 25, 30, 35, 40, 40], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_maxpool2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [12, 17, 22, 27, 32, 37, 37, 38], 29 | [13, 18, 23, 28, 33, 38, 38, 39], 30 | [14, 19, 24, 29, 34, 39, 39, 40], 31 | [15, 20, 25, 30, 35, 40, 40, 40], 32 | [15, 20, 25, 30, 35, 40, 40, 40], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_tensors(self, in_channels, kernel_size, stride): 68 | channel_dist = 1000 69 | 70 | # input tensor 71 | array = self.get_array() 72 | array = np.expand_dims( 73 | np.stack([j * channel_dist + array for j in range(in_channels)]), 0 74 | ) 75 | tensor = torch.FloatTensor(array) 76 | 77 | # expected output tensor 78 | if kernel_size == 1: 79 | pooled_array = self.get_array_maxpool2d_size1_stride1() 80 | elif kernel_size == 2: 81 | pooled_array = self.get_array_maxpool2d_size2_stride1() 82 | if stride == 2: 83 | pooled_array = self.get_array_stride_2(pooled_array) 84 | elif stride == 3: 85 | pooled_array = self.get_array_stride_3(pooled_array) 86 | pooled_array = 
np.expand_dims( 87 | np.stack( 88 | [ 89 | channel * channel_dist + pooled_array 90 | for channel in range(in_channels) 91 | ] 92 | ), 93 | 0, 94 | ) 95 | pooled_tensor = torch.FloatTensor(pooled_array) 96 | 97 | # output tensor of test method 98 | maxpool2d = hex.MaxPool2d(kernel_size, stride) 99 | 100 | return maxpool2d(tensor), pooled_tensor 101 | 102 | def test_in_channels_1_kernel_size_1_stride_1(self): 103 | in_channels = 1 104 | kernel_size = 1 105 | stride = 1 106 | 107 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 108 | 109 | assert torch.equal(test_ouput, expectation) 110 | 111 | def test_in_channels_1_kernel_size_1_stride_2(self): 112 | in_channels = 1 113 | kernel_size = 1 114 | stride = 2 115 | 116 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 117 | 118 | assert torch.equal(test_ouput, expectation) 119 | 120 | def test_in_channels_1_kernel_size_1_stride_3(self): 121 | in_channels = 1 122 | kernel_size = 1 123 | stride = 3 124 | 125 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 126 | 127 | assert torch.equal(test_ouput, expectation) 128 | 129 | def test_in_channels_1_kernel_size_2_stride_1(self): 130 | in_channels = 1 131 | kernel_size = 2 132 | stride = 1 133 | 134 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 135 | 136 | assert torch.equal(test_ouput, expectation) 137 | 138 | def test_in_channels_1_kernel_size_2_stride_2(self): 139 | in_channels = 1 140 | kernel_size = 2 141 | stride = 2 142 | 143 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 144 | 145 | assert torch.equal(test_ouput, expectation) 146 | 147 | def test_in_channels_1_kernel_size_2_stride_3(self): 148 | in_channels = 1 149 | kernel_size = 2 150 | stride = 3 151 | 152 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 153 | 154 | assert torch.equal(test_ouput, expectation) 155 | 156 | def test_in_channels_5_kernel_size_1_stride_1(self): 157 | in_channels = 5 158 | kernel_size = 1 159 | stride = 1 160 | 161 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 162 | 163 | assert torch.equal(test_ouput, expectation) 164 | 165 | def test_in_channels_5_kernel_size_1_stride_2(self): 166 | in_channels = 5 167 | kernel_size = 1 168 | stride = 2 169 | 170 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 171 | 172 | assert torch.equal(test_ouput, expectation) 173 | 174 | def test_in_channels_5_kernel_size_1_stride_3(self): 175 | in_channels = 5 176 | kernel_size = 1 177 | stride = 3 178 | 179 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 180 | 181 | assert torch.equal(test_ouput, expectation) 182 | 183 | def test_in_channels_5_kernel_size_2_stride_1(self): 184 | in_channels = 5 185 | kernel_size = 2 186 | stride = 1 187 | 188 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 189 | 190 | assert torch.equal(test_ouput, expectation) 191 | 192 | def test_in_channels_5_kernel_size_2_stride_2(self): 193 | in_channels = 5 194 | kernel_size = 2 195 | stride = 2 196 | 197 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 198 | 199 | assert torch.equal(test_ouput, expectation) 200 | 201 | def test_in_channels_5_kernel_size_2_stride_3(self): 202 | in_channels = 5 203 | kernel_size = 2 204 | stride = 3 205 | 206 | test_ouput, expectation = self.get_tensors(in_channels, kernel_size, stride) 207 | 208 | 
assert torch.equal(test_ouput, expectation)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HexagDLy - Processing Hexagonal Data with PyTorch

HexagDLy provides convolution and pooling methods for hexagonally sampled data within the deep learning framework [PyTorch](https://github.com/pytorch/pytorch).

- [Getting Started](#getting-started)
- [How to use HexagDLy](#how-to-use-hexagdly)
- [General Concept](#general-concept)
- [Disclaimer](#disclaimer)
- [Citing HexagDLy](#citation)


## Getting Started

There are different ways to get HexagDLy up and running on your system, as shown below. Basic examples for the application of HexagDLy are given as [notebooks](notebooks). Additionally, unit tests are provided in [tests](tests).

### Pip Installation

The suggested way to install HexagDLy is to set up a clean virtual Python environment (e.g. with conda, see [https://www.anaconda.com/](https://www.anaconda.com/)) and use the provided pip installer. To install only the basic functionality, use:

```
pip install hexagdly
```

To get all dependencies needed to run the provided [unit tests](tests) and [notebooks](notebooks), add the `dev` option:

```
pip install hexagdly[dev]
```


### Manual Installation

HexagDLy requires a working installation of [PyTorch](https://github.com/pytorch/pytorch). Please visit the PyTorch website http://pytorch.org/ or GitHub page https://github.com/pytorch/pytorch and follow the installation instructions.
If you have downloaded HexagDLy, just add the directory `hexagdly/src` to your system's `$PYTHONPATH`, e.g. by adding the following line to your `.bashrc` or `.bash_profile` (or wherever your paths can be set):

```
export PYTHONPATH='/path/to/hexagdly/src':$PYTHONPATH
```

## How to use HexagDLy

As HexagDLy is based on PyTorch, it is helpful to be familiar with PyTorch's functionalities and concepts.
Furthermore, before applying HexagDLy, make sure that the input data has the correct format. HexagDLy uses an addressing scheme to map data from a hexagonal grid to a torch tensor. An [example](notebooks/how_to_apply_adressing_scheme.ipynb) is provided that illustrates the steps to get the data into the correct layout.

If the data has this required layout and HexagDLy is installed, performing hexagonal convolutions is as easy as the following example:

```
import torch
import hexagdly

kernel_size, stride = 1, 4
in_channels, out_channels = 1, 3

hexconv = hexagdly.Conv2d(in_channels, out_channels, kernel_size, stride)
input = torch.rand(1,1,21,21)
output = hexconv(input)
```

In this example, a random input tensor of shape (1, 1, 21, 21) is convolved with a so-called next-neighbour hexagonal kernel (size = 1) with one input channel and three output channels, using a stride of four. The output is a tensor of shape (1, 3, 5, 6).
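To give an idea of how these layers compose, the following sketch (written for illustration, not part of the repository; the module name `TinyHexNet` is hypothetical) wraps `hexagdly.Conv2d` and `hexagdly.MaxPool2d` in an ordinary PyTorch module. It only assumes the positional constructor arguments used in this README and in [tests](tests), i.e. `Conv2d(in_channels, out_channels, kernel_size, stride)` and `MaxPool2d(kernel_size, stride)`:

```
import torch
import torch.nn as nn
import hexagdly

class TinyHexNet(nn.Module):
    """Illustrative two-layer hexagonal network (hypothetical example)."""

    def __init__(self):
        super(TinyHexNet, self).__init__()
        # next-neighbour hexagonal convolution: 1 input channel, 8 output channels, stride 1
        self.conv = hexagdly.Conv2d(1, 8, 1, 1)
        # hexagonal max pooling with a size-1 kernel and a stride of 2
        self.pool = hexagdly.MaxPool2d(1, 2)

    def forward(self, x):
        return self.pool(torch.relu(self.conv(x)))

net = TinyHexNet()
out = net(torch.rand(1, 1, 21, 21))
print(out.shape)  # 8 channels; the spatial size follows the hexagonal stride rules below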
HexagDLy is designed to conserve the hexagonal symmetry of the input. Therefore, a hexagonal kernel is always 6-fold symmetric and may only move along the symmetry axes of the grid in symmetric steps.
An automatic padding of the input is applied, depending on the kernel size, stride and dimensions of the input.
The image below shows examples of how kernels of different sizes and strides visit certain regions of an input. The orange cells mark the hexagonal kernel centered on the top left cell, the starting point of each operation. The square gridlines depict cells on which the kernel is centered by moving it with the given stride.

![kernel size+stride](figures/kernel_size+stride.png "Examples of kernels of different sizes and strides.")

**Please note**: Operations are only performed where the center point of a kernel is located within the input tensor. This could result in output columns of different lengths. In such cases the output will be sliced according to the shortest column. An example is the convolution with stride 3 in the center of the figure above. The red gridlines depict convolutions that are omitted in the output.

Examples for basic use-cases of HexagDLy are shown in the [notebooks](notebooks) folder.


## General Concept

As common deep learning frameworks are designed to process data arranged in square grids, it is not (yet) possible out-of-the-box to process data points that are arranged on a hexagonal grid.
To process hexagonally sampled data with frameworks like PyTorch, it is therefore necessary to translate the information from the hexagonal grid to a square grid tensor.
Such a conversion is however not trivial, as square and hexagonal grids exhibit different symmetries, i.e. 4-fold vs. 6-fold symmetry.

This problem is solved in HexagDLy by using an addressing scheme that maps the data from its original hexagonal grid to a square grid and by adapting the convolution operations to respect the symmetry of the hexagonal grid. The applied addressing scheme essentially aligns the columns of the hexagonal grid.
By applying a standard square-grid convolution to this data, the kernel disregards the original pixel-to-pixel neighbour relations and breaks the 6-fold symmetry. In order to perform a valid hexagonal convolution it is necessary to split the kernel into sub-kernels that, in combination, cover the true neighbours of a data point in the hexagonal grid. The concept is depicted in the image below:

![violating_symmetry](figures/violating_symmetry.png "Squeezing hexagonal data into a square grid and applying square convolution kernels disregards the symmetry of the hexagonal lattice. A valid hexagonal convolution can be performed by combining custom sub-kernels.")

Due to the alternating shift between the columns of the array of data points, the sub-kernels of a hexagonal convolution kernel have to shift accordingly, depending on whether the kernel is centered on an odd or an even column of the array.
A full hexagonal convolution with the smallest hexagonal kernel (size 1) that conserves the dimensions of the input can be broken down into a total of three sub-convolutions, performed by applying two different sub-kernels to three differently padded versions of the input. The resulting arrays are then merged and added to obtain the desired result.
The individual steps of this operation are depicted in the image below, where a toy input tensor is convolved with a hexagonal size-1 kernel (all weights set to 1, i.e. the convolution adds up all data points covered by the kernel):

![explicit_next_neighbour_conv](figures/explicit_next_neighbour_conv.png "Schematic description of the individual sub-convolutions and the combination of their outputs to perform a hexagonal convolution as provided by HexagDLy.")
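This decomposition is also exposed directly. The unit tests in [tests/test_Conv2d_CustomKernel.py](tests/test_Conv2d_CustomKernel.py) build the size-1 kernel from a 3×1 and a 2×2 array of ones and pass them to `Conv2d_CustomKernel`; the snippet below is a minimal sketch along those lines, using the same positional arguments (`sub_kernels, stride, bias`) as the tests:

```
import numpy as np
import torch
import hexagdly

in_channels = 1
# sub-kernels of the size-1 hexagonal kernel, all weights set to 1
sub_kernels = [
    np.ones((1, in_channels, 3, 1)),  # centre column: cell plus upper and lower neighbour
    np.ones((1, in_channels, 2, 2)),  # the four neighbours in the two shifted columns
]
conv = hexagdly.Conv2d_CustomKernel(sub_kernels, 1, None)  # stride 1, no bias

x = torch.rand(1, in_channels, 5, 8)
print(conv(x).shape)  # stride 1 conserves the spatial dimensions: (1, 1, 5, 8)
```

With all weights set to 1 this reproduces exactly the summing kernel from the figure above; arbitrary weights, e.g. for hard-coded pre-trained kernels, can be supplied in the same format.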
Following the same concept, it is feasible to construct larger hexagonal convolution kernels as well as pooling operations by increasing the strides of the sub-kernels and exchanging the convolution operations with nested pooling methods.


## Disclaimer

HexagDLy is built as an easy-to-use prototyping tool to design convolutional neural networks for hexagonally sampled data. The implemented methods aim for flexibility rather than performance.
Once a model is optimized, it is possible to hard-code the desired parameters like kernel size, stride and input dimensions to make the implementation faster.
Furthermore, the [General Concept](#general-concept) is not specific to PyTorch but can be adapted to other deep learning frameworks.


## Authors

* **Tim Lukas Holch**
* **Constantin Steppa**

See also the list of [contributors](https://github.com/ai4iacts/hexagdly/contributors) who participated in this project.


## License

This project is licensed under the MIT license - please consult the [LICENSE](LICENSE) file for details.


## Citation

We have published an open access paper about HexagDLy in [SoftwareX](https://www.sciencedirect.com/science/article/pii/S2352711018302723). If this work has helped your research, please cite us via:

```
@article{hexagdly_paper,
  title    = "HexagDLy—Processing hexagonally sampled data with CNNs in PyTorch",
  author   = "Constantin Steppa and Tim L. Holch",
  journal  = "SoftwareX",
  volume   = "9",
  pages    = "193 - 198",
  year     = "2019",
  issn     = "2352-7110",
  doi      = "https://doi.org/10.1016/j.softx.2019.02.010",
  url      = "https://www.sciencedirect.com/science/article/pii/S2352711018302723",
  keywords = "Convolutional neural networks, Hexagonal grid, PyTorch, Astroparticle physics",
  abstract = "HexagDLy is a Python-library extending the PyTorch deep learning framework with convolution and pooling operations on hexagonal grids. It aims to ease the access to convolutional neural networks for applications that rely on hexagonally sampled data as, for example, commonly found in ground-based astroparticle physics experiments."
}
```

HexagDLy was developed as part of a research study in the field of ground-based gamma-ray astronomy published in [Astroparticle Physics](https://doi.org/10.1016/j.astropartphys.2018.10.003).


## Acknowledgments

This project evolved by exploring new analysis techniques for Imaging Atmospheric Cherenkov Telescopes with the High Energy Stereoscopic System (H.E.S.S.). We would like to thank the members of the H.E.S.S. collaboration for their support.
139 | 140 | 141 | -------------------------------------------------------------------------------- /tests/test_Conv2d_Custom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestConv2d(object): 8 | def get_in_array(self): 9 | return np.array( 10 | [ 11 | [ 12 | [ 13 | [0, 0, 0, 0, 1, 0], 14 | [0, 0, 1, 0, 0, 0], 15 | [0, 0, 0, 0, 1, 0], 16 | [0, 0, 1, 0, 0, 1], 17 | ] 18 | ] 19 | ], 20 | dtype=np.float32, 21 | ) 22 | 23 | def get_kernel_1_stride_1_array(self): 24 | return np.array( 25 | [ 26 | [ 27 | [ 28 | [0, 1, 1, 2, 1, 1], 29 | [0, 1, 1, 2, 2, 1], 30 | [0, 1, 2, 2, 1, 2], 31 | [0, 1, 1, 1, 2, 1], 32 | ] 33 | ] 34 | ], 35 | dtype=np.float32, 36 | ) 37 | 38 | def get_kernel_1_stride_2_array(self): 39 | return np.array([[[[0, 1, 1], [0, 1, 1]]]], dtype=np.float32) 40 | 41 | def get_kernel_1_stride_3_array(self): 42 | return np.array([[[[0, 2]]]], dtype=np.float32) 43 | 44 | def get_kernel_2_stride_1_array(self): 45 | return np.array( 46 | [ 47 | [ 48 | [ 49 | [1, 1, 2, 3, 3, 2], 50 | [1, 2, 4, 4, 3, 3], 51 | [2, 2, 3, 4, 5, 2], 52 | [1, 1, 3, 3, 3, 2], 53 | ] 54 | ] 55 | ], 56 | dtype=np.float32, 57 | ) 58 | 59 | def get_kernel_2_stride_2_array(self): 60 | return np.array([[[[1, 4, 3], [2, 3, 5]]]], dtype=np.float32) 61 | 62 | def get_kernel_2_stride_3_array(self): 63 | return np.array([[[[1, 4]]]], dtype=np.float32) 64 | 65 | def get_tensors(self, in_channels, kernel_size, stride, bias_bool): 66 | channel_dist = 1000 67 | if bias_bool is False: 68 | bias_value = 0 69 | bias = None 70 | else: 71 | bias_value = 1.0 72 | bias = np.array([1]) 73 | 74 | # input tensor 75 | array = self.get_in_array() 76 | array = np.concatenate( 77 | [channel * channel_dist * array + array for channel in range(in_channels)], 78 | 1, 79 | ) 80 | tensor = torch.FloatTensor(array) 81 | 82 | # expected output tensor 83 | convolved_array = getattr( 84 | self, "get_kernel_" + str(kernel_size) + "_stride_" + str(stride) + "_array" 85 | )() 86 | convolved_array = np.sum( 87 | np.stack( 88 | [ 89 | (channel * channel_dist) * convolved_array + convolved_array 90 | for channel in range(in_channels) 91 | ] 92 | ), 93 | 0, 94 | ) 95 | convolved_tensor = torch.FloatTensor(convolved_array) + bias_value 96 | 97 | # output tensor of test method 98 | if kernel_size == 1: 99 | kernel = [np.ones((1, in_channels, 3, 1)), np.ones((1, in_channels, 2, 2))] 100 | elif kernel_size == 2: 101 | kernel = [ 102 | np.ones((1, in_channels, 5, 1)), 103 | np.ones((1, in_channels, 4, 2)), 104 | np.ones((1, in_channels, 3, 2)), 105 | ] 106 | conv2d = hex.Conv2d_CustomKernel(kernel, stride, bias) 107 | 108 | return conv2d(tensor), convolved_tensor 109 | 110 | def test_in_channels_1_kernel_size_1_stride_1_bias_False(self): 111 | in_channels = 1 112 | kernel_size = 1 113 | stride = 1 114 | bias = False 115 | 116 | test_ouput, expectation = self.get_tensors( 117 | in_channels, kernel_size, stride, bias 118 | ) 119 | 120 | assert torch.equal(test_ouput, expectation) 121 | 122 | def test_in_channels_1_kernel_size_1_stride_2_bias_False(self): 123 | in_channels = 1 124 | kernel_size = 1 125 | stride = 2 126 | bias = False 127 | 128 | test_ouput, expectation = self.get_tensors( 129 | in_channels, kernel_size, stride, bias 130 | ) 131 | 132 | assert torch.equal(test_ouput, expectation) 133 | 134 | def test_in_channels_1_kernel_size_1_stride_3_bias_False(self): 135 | in_channels = 1 136 | kernel_size = 1 137 | stride = 
3 138 | bias = False 139 | 140 | test_ouput, expectation = self.get_tensors( 141 | in_channels, kernel_size, stride, bias 142 | ) 143 | 144 | assert torch.equal(test_ouput, expectation) 145 | 146 | def test_in_channels_1_kernel_size_2_stride_1_bias_False(self): 147 | in_channels = 1 148 | kernel_size = 2 149 | stride = 1 150 | bias = False 151 | 152 | test_ouput, expectation = self.get_tensors( 153 | in_channels, kernel_size, stride, bias 154 | ) 155 | 156 | assert torch.equal(test_ouput, expectation) 157 | 158 | def test_in_channels_1_kernel_size_2_stride_2_bias_False(self): 159 | in_channels = 1 160 | kernel_size = 2 161 | stride = 2 162 | bias = False 163 | 164 | test_ouput, expectation = self.get_tensors( 165 | in_channels, kernel_size, stride, bias 166 | ) 167 | 168 | assert torch.equal(test_ouput, expectation) 169 | 170 | def test_in_channels_1_kernel_size_2_stride_3_bias_False(self): 171 | in_channels = 1 172 | kernel_size = 2 173 | stride = 3 174 | bias = False 175 | 176 | test_ouput, expectation = self.get_tensors( 177 | in_channels, kernel_size, stride, bias 178 | ) 179 | 180 | assert torch.equal(test_ouput, expectation) 181 | 182 | def test_in_channels_5_kernel_size_1_stride_1_bias_False(self): 183 | in_channels = 5 184 | kernel_size = 1 185 | stride = 1 186 | bias = False 187 | 188 | test_ouput, expectation = self.get_tensors( 189 | in_channels, kernel_size, stride, bias 190 | ) 191 | 192 | assert torch.equal(test_ouput, expectation) 193 | 194 | def test_in_channels_5_kernel_size_1_stride_2_bias_False(self): 195 | in_channels = 5 196 | kernel_size = 1 197 | stride = 2 198 | bias = False 199 | 200 | test_ouput, expectation = self.get_tensors( 201 | in_channels, kernel_size, stride, bias 202 | ) 203 | 204 | assert torch.equal(test_ouput, expectation) 205 | 206 | def test_in_channels_5_kernel_size_1_stride_3_bias_False(self): 207 | in_channels = 5 208 | kernel_size = 1 209 | stride = 3 210 | bias = False 211 | 212 | test_ouput, expectation = self.get_tensors( 213 | in_channels, kernel_size, stride, bias 214 | ) 215 | 216 | assert torch.equal(test_ouput, expectation) 217 | 218 | def test_in_channels_5_kernel_size_2_stride_1_bias_False(self): 219 | in_channels = 5 220 | kernel_size = 2 221 | stride = 1 222 | bias = False 223 | 224 | test_ouput, expectation = self.get_tensors( 225 | in_channels, kernel_size, stride, bias 226 | ) 227 | 228 | assert torch.equal(test_ouput, expectation) 229 | 230 | def test_in_channels_5_kernel_size_2_stride_2_bias_False(self): 231 | in_channels = 5 232 | kernel_size = 2 233 | stride = 2 234 | bias = False 235 | 236 | test_ouput, expectation = self.get_tensors( 237 | in_channels, kernel_size, stride, bias 238 | ) 239 | 240 | assert torch.equal(test_ouput, expectation) 241 | 242 | def test_in_channels_5_kernel_size_2_stride_3_bias_False(self): 243 | in_channels = 5 244 | kernel_size = 2 245 | stride = 3 246 | bias = False 247 | 248 | test_ouput, expectation = self.get_tensors( 249 | in_channels, kernel_size, stride, bias 250 | ) 251 | 252 | assert torch.equal(test_ouput, expectation) 253 | 254 | def test_in_channels_1_kernel_size_1_stride_1_bias_True(self): 255 | in_channels = 1 256 | kernel_size = 1 257 | stride = 1 258 | bias = True 259 | 260 | test_ouput, expectation = self.get_tensors( 261 | in_channels, kernel_size, stride, bias 262 | ) 263 | 264 | assert torch.equal(test_ouput, expectation) 265 | 266 | def test_in_channels_1_kernel_size_1_stride_2_bias_True(self): 267 | in_channels = 1 268 | kernel_size = 1 269 | stride = 2 270 | bias = True 271 | 
272 | test_ouput, expectation = self.get_tensors( 273 | in_channels, kernel_size, stride, bias 274 | ) 275 | 276 | assert torch.equal(test_ouput, expectation) 277 | 278 | def test_in_channels_1_kernel_size_1_stride_3_bias_True(self): 279 | in_channels = 1 280 | kernel_size = 1 281 | stride = 3 282 | bias = True 283 | 284 | test_ouput, expectation = self.get_tensors( 285 | in_channels, kernel_size, stride, bias 286 | ) 287 | 288 | assert torch.equal(test_ouput, expectation) 289 | 290 | def test_in_channels_1_kernel_size_2_stride_1_bias_True(self): 291 | in_channels = 1 292 | kernel_size = 2 293 | stride = 1 294 | bias = True 295 | 296 | test_ouput, expectation = self.get_tensors( 297 | in_channels, kernel_size, stride, bias 298 | ) 299 | 300 | assert torch.equal(test_ouput, expectation) 301 | 302 | def test_in_channels_1_kernel_size_2_stride_2_bias_True(self): 303 | in_channels = 1 304 | kernel_size = 2 305 | stride = 2 306 | bias = True 307 | 308 | test_ouput, expectation = self.get_tensors( 309 | in_channels, kernel_size, stride, bias 310 | ) 311 | 312 | assert torch.equal(test_ouput, expectation) 313 | 314 | def test_in_channels_1_kernel_size_2_stride_3_bias_True(self): 315 | in_channels = 1 316 | kernel_size = 2 317 | stride = 3 318 | bias = True 319 | 320 | test_ouput, expectation = self.get_tensors( 321 | in_channels, kernel_size, stride, bias 322 | ) 323 | 324 | assert torch.equal(test_ouput, expectation) 325 | 326 | def test_in_channels_5_kernel_size_1_stride_1_bias_True(self): 327 | in_channels = 5 328 | kernel_size = 1 329 | stride = 1 330 | bias = True 331 | 332 | test_ouput, expectation = self.get_tensors( 333 | in_channels, kernel_size, stride, bias 334 | ) 335 | 336 | assert torch.equal(test_ouput, expectation) 337 | 338 | def test_in_channels_5_kernel_size_1_stride_2_bias_True(self): 339 | in_channels = 5 340 | kernel_size = 1 341 | stride = 2 342 | bias = True 343 | 344 | test_ouput, expectation = self.get_tensors( 345 | in_channels, kernel_size, stride, bias 346 | ) 347 | 348 | assert torch.equal(test_ouput, expectation) 349 | 350 | def test_in_channels_5_kernel_size_1_stride_3_bias_True(self): 351 | in_channels = 5 352 | kernel_size = 1 353 | stride = 3 354 | bias = True 355 | 356 | test_ouput, expectation = self.get_tensors( 357 | in_channels, kernel_size, stride, bias 358 | ) 359 | 360 | assert torch.equal(test_ouput, expectation) 361 | 362 | def test_in_channels_5_kernel_size_2_stride_1_bias_True(self): 363 | in_channels = 5 364 | kernel_size = 2 365 | stride = 1 366 | bias = True 367 | 368 | test_ouput, expectation = self.get_tensors( 369 | in_channels, kernel_size, stride, bias 370 | ) 371 | 372 | assert torch.equal(test_ouput, expectation) 373 | 374 | def test_in_channels_5_kernel_size_2_stride_2_bias_True(self): 375 | in_channels = 5 376 | kernel_size = 2 377 | stride = 2 378 | bias = True 379 | 380 | test_ouput, expectation = self.get_tensors( 381 | in_channels, kernel_size, stride, bias 382 | ) 383 | 384 | assert torch.equal(test_ouput, expectation) 385 | 386 | def test_in_channels_5_kernel_size_2_stride_3_bias_True(self): 387 | in_channels = 5 388 | kernel_size = 2 389 | stride = 3 390 | bias = True 391 | 392 | test_ouput, expectation = self.get_tensors( 393 | in_channels, kernel_size, stride, bias 394 | ) 395 | 396 | assert torch.equal(test_ouput, expectation) 397 | -------------------------------------------------------------------------------- /tests/test_MaxPool3d.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestMaxPool3d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_maxpool2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [6, 12, 16, 22, 26, 32, 36, 37], 17 | [7, 13, 17, 23, 27, 33, 37, 38], 18 | [8, 14, 18, 24, 28, 34, 38, 39], 19 | [9, 15, 19, 25, 29, 35, 39, 40], 20 | [10, 15, 20, 25, 30, 35, 40, 40], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_maxpool2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [12, 17, 22, 27, 32, 37, 37, 38], 29 | [13, 18, 23, 28, 33, 38, 38, 39], 30 | [14, 19, 24, 29, 34, 39, 39, 40], 31 | [15, 20, 25, 30, 35, 40, 40, 40], 32 | [15, 20, 25, 30, 35, 40, 40, 40], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_tensors( 68 | self, 69 | in_channels, 70 | depth, 71 | kernel_size_depth, 72 | kernel_size_hex, 73 | stride_depth, 74 | stride_hex, 75 | ): 76 | channel_dist = 1000 77 | depth_dist = 40 78 | depth_steps = int(np.ceil((depth - kernel_size_depth + 1) / stride_depth)) 79 | 80 | # input tensor 81 | array = self.get_array() 82 | array = np.expand_dims( 83 | np.stack( 84 | [ 85 | j * channel_dist 86 | + np.stack([i * depth_dist + array for i in range(depth)]) 87 | for j in range(in_channels) 88 | ] 89 | ), 90 | 0, 91 | ) 92 | tensor = torch.FloatTensor(array) 93 | 94 | # expected output tensor 95 | if kernel_size_hex == 1: 96 | pool2d_array = self.get_array_maxpool2d_size1_stride1() 97 | elif kernel_size_hex == 2: 98 | pool2d_array = self.get_array_maxpool2d_size2_stride1() 99 | if stride_hex == 2: 100 | pool2d_array = self.get_array_stride_2(pool2d_array) 101 | elif stride_hex == 3: 102 | pool2d_array = self.get_array_stride_3(pool2d_array) 103 | pooled_array = np.expand_dims( 104 | np.stack( 105 | [ 106 | channel * channel_dist 107 | + np.stack( 108 | [ 109 | (dstep * stride_depth + kernel_size_depth - 1) * depth_dist 110 | + pool2d_array 111 | for dstep in range(depth_steps) 112 | ] 113 | ) 114 | for channel in range(in_channels) 115 | ] 116 | ), 117 | 0, 118 | ) 119 | pooled_tensor = torch.FloatTensor(pooled_array) 120 | 121 | # output tensor of test method 122 | maxpool3d = hex.MaxPool3d( 123 | (kernel_size_depth, kernel_size_hex), (stride_depth, stride_hex) 124 | ) 125 | 126 | return maxpool3d(tensor), pooled_tensor 127 | 128 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_1(self): 129 | in_channels = 1 130 | depth = 1 131 | kernel_size_depth = 1 132 | kernel_size_hex = 1 133 | stride_depth = 1 134 | stride_hex = 1 135 | 136 | test_ouput, expectation = self.get_tensors( 137 | in_channels, 138 | depth, 139 | 
kernel_size_depth, 140 | kernel_size_hex, 141 | stride_depth, 142 | stride_hex, 143 | ) 144 | 145 | assert torch.equal(test_ouput, expectation) 146 | 147 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_2(self): 148 | in_channels = 1 149 | depth = 1 150 | kernel_size_depth = 1 151 | kernel_size_hex = 1 152 | stride_depth = 1 153 | stride_hex = 2 154 | 155 | test_ouput, expectation = self.get_tensors( 156 | in_channels, 157 | depth, 158 | kernel_size_depth, 159 | kernel_size_hex, 160 | stride_depth, 161 | stride_hex, 162 | ) 163 | 164 | assert torch.equal(test_ouput, expectation) 165 | 166 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_3(self): 167 | in_channels = 1 168 | depth = 1 169 | kernel_size_depth = 1 170 | kernel_size_hex = 1 171 | stride_depth = 1 172 | stride_hex = 3 173 | 174 | test_ouput, expectation = self.get_tensors( 175 | in_channels, 176 | depth, 177 | kernel_size_depth, 178 | kernel_size_hex, 179 | stride_depth, 180 | stride_hex, 181 | ) 182 | 183 | assert torch.equal(test_ouput, expectation) 184 | 185 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_1(self): 186 | in_channels = 1 187 | depth = 1 188 | kernel_size_depth = 1 189 | kernel_size_hex = 2 190 | stride_depth = 1 191 | stride_hex = 1 192 | 193 | test_ouput, expectation = self.get_tensors( 194 | in_channels, 195 | depth, 196 | kernel_size_depth, 197 | kernel_size_hex, 198 | stride_depth, 199 | stride_hex, 200 | ) 201 | 202 | assert torch.equal(test_ouput, expectation) 203 | 204 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_2(self): 205 | in_channels = 1 206 | depth = 1 207 | kernel_size_depth = 1 208 | kernel_size_hex = 2 209 | stride_depth = 1 210 | stride_hex = 2 211 | 212 | test_ouput, expectation = self.get_tensors( 213 | in_channels, 214 | depth, 215 | kernel_size_depth, 216 | kernel_size_hex, 217 | stride_depth, 218 | stride_hex, 219 | ) 220 | 221 | assert torch.equal(test_ouput, expectation) 222 | 223 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_3(self): 224 | in_channels = 1 225 | depth = 1 226 | kernel_size_depth = 1 227 | kernel_size_hex = 2 228 | stride_depth = 1 229 | stride_hex = 3 230 | 231 | test_ouput, expectation = self.get_tensors( 232 | in_channels, 233 | depth, 234 | kernel_size_depth, 235 | kernel_size_hex, 236 | stride_depth, 237 | stride_hex, 238 | ) 239 | 240 | assert torch.equal(test_ouput, expectation) 241 | 242 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_1_1(self): 243 | in_channels = 1 244 | depth = 9 245 | kernel_size_depth = 1 246 | kernel_size_hex = 1 247 | stride_depth = 1 248 | stride_hex = 1 249 | 250 | test_ouput, expectation = self.get_tensors( 251 | in_channels, 252 | depth, 253 | kernel_size_depth, 254 | kernel_size_hex, 255 | stride_depth, 256 | stride_hex, 257 | ) 258 | 259 | assert torch.equal(test_ouput, expectation) 260 | 261 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_2_1(self): 262 | in_channels = 1 263 | depth = 9 264 | kernel_size_depth = 1 265 | kernel_size_hex = 1 266 | stride_depth = 2 267 | stride_hex = 1 268 | 269 | test_ouput, expectation = self.get_tensors( 270 | in_channels, 271 | depth, 272 | kernel_size_depth, 273 | kernel_size_hex, 274 | stride_depth, 275 | stride_hex, 276 | ) 277 | 278 | assert torch.equal(test_ouput, expectation) 279 | 280 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_3_1(self): 281 | in_channels = 1 282 | depth = 9 283 | kernel_size_depth = 1 284 | kernel_size_hex = 1 285 | stride_depth = 3 286 | stride_hex = 1 287 | 288 | test_ouput, expectation = 
self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_2_1_stride_1_1(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 2
        kernel_size_hex = 1
        stride_depth = 1
        stride_hex = 1

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_1(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 2
        kernel_size_hex = 1
        stride_depth = 2
        stride_hex = 1

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_2(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 2
        kernel_size_hex = 1
        stride_depth = 2
        stride_hex = 2

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_1(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 7
        kernel_size_hex = 2
        stride_depth = 1
        stride_hex = 1

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_2(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 7
        kernel_size_hex = 2
        stride_depth = 1
        stride_hex = 2

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_1_depth_9_kernel_size_7_2_stride_2_2(self):
        in_channels = 1
        depth = 9
        kernel_size_depth = 7
        kernel_size_hex = 2
        stride_depth = 2
        stride_hex = 2

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)

    def test_in_channels_5_depth_9_kernel_size_3_2_stride_1_1(self):
        in_channels = 5
        depth = 9
        kernel_size_depth = 3
        kernel_size_hex = 2
        stride_depth = 1
        stride_hex = 1

        test_ouput, expectation = self.get_tensors(
            in_channels,
            depth,
            kernel_size_depth,
            kernel_size_hex,
            stride_depth,
            stride_hex,
        )

        assert torch.equal(test_ouput, expectation)
--------------------------------------------------------------------------------
/tests/test_Conv2d.py:
--------------------------------------------------------------------------------
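# Unit tests for hexagdly.Conv2d: a fixed 5x8 test pattern is convolved for
# various channel counts, kernel sizes, strides and bias settings, and the
# module output is compared element-wise against hand-computed expectation arrays.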
import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestConv2d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_conv2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [9, 39, 45, 99, 85, 159, 125, 136], 17 | [19, 51, 82, 121, 152, 191, 222, 176], 18 | [24, 58, 89, 128, 159, 198, 229, 181], 19 | [29, 65, 96, 135, 166, 205, 236, 186], 20 | [28, 39, 87, 79, 147, 119, 207, 114], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_conv2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [42, 96, 128, 219, 238, 349, 265, 260], 29 | [67, 141, 194, 312, 354, 492, 388, 361], 30 | [84, 162, 243, 346, 433, 536, 494, 408], 31 | [90, 145, 246, 302, 426, 462, 474, 343], 32 | [68, 104, 184, 213, 314, 323, 355, 245], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_n_neighbors_size1(self): 68 | return np.array( 69 | [ 70 | [3, 6, 4, 6, 4, 6, 4, 4], 71 | [5, 7, 7, 7, 7, 7, 7, 5], 72 | [5, 7, 7, 7, 7, 7, 7, 5], 73 | [5, 7, 7, 7, 7, 7, 7, 5], 74 | [4, 4, 6, 4, 6, 4, 6, 3], 75 | ], 76 | dtype=np.float32, 77 | ) 78 | 79 | def get_n_neighbors_size2(self): 80 | return np.array( 81 | [ 82 | [7, 11, 11, 13, 11, 13, 9, 8], 83 | [10, 15, 16, 18, 16, 18, 13, 11], 84 | [12, 16, 19, 19, 19, 19, 16, 12], 85 | [11, 13, 18, 16, 18, 16, 15, 10], 86 | [8, 9, 13, 11, 13, 11, 11, 7], 87 | ], 88 | dtype=np.float32, 89 | ) 90 | 91 | def get_tensors(self, in_channels, kernel_size, stride, bias): 92 | channel_dist = 1000 93 | if bias is False: 94 | bias_value = 0 95 | else: 96 | bias_value = 1.0 97 | 98 | # input tensor 99 | array = self.get_array() 100 | array = np.expand_dims( 101 | np.stack([j * channel_dist + array for j in range(in_channels)]), 0 102 | ) 103 | tensor = torch.FloatTensor(array) 104 | 105 | # expected output tensor 106 | if kernel_size == 1: 107 | conv2d_array = self.get_array_conv2d_size1_stride1() 108 | n_neighbours = self.get_n_neighbors_size1() 109 | elif kernel_size == 2: 110 | conv2d_array = self.get_array_conv2d_size2_stride1() 111 | n_neighbours = self.get_n_neighbors_size2() 112 | convolved_array = np.sum( 113 | np.stack( 114 | [ 115 | (channel * channel_dist) * n_neighbours + conv2d_array 116 | for channel in range(in_channels) 117 | ] 118 | ), 119 | 0, 120 | ) 121 | if stride == 2: 122 | convolved_array = self.get_array_stride_2(convolved_array) 123 | elif stride == 3: 124 | convolved_array = self.get_array_stride_3(convolved_array) 125 | convolved_array = np.expand_dims(np.expand_dims(convolved_array, 0), 0) 126 | convolved_tensor = torch.FloatTensor(convolved_array) + bias_value 127 | 128 | # output tensor of test method 129 
| conv2d = hex.Conv2d(in_channels, 1, kernel_size, stride, bias, True) 130 | 131 | return conv2d(tensor), convolved_tensor 132 | 133 | def test_in_channels_1_kernel_size_1_stride_1_bias_False(self): 134 | in_channels = 1 135 | kernel_size = 1 136 | stride = 1 137 | bias = False 138 | 139 | test_ouput, expectation = self.get_tensors( 140 | in_channels, kernel_size, stride, bias 141 | ) 142 | 143 | assert torch.equal(test_ouput, expectation) 144 | 145 | def test_in_channels_1_kernel_size_1_stride_2_bias_False(self): 146 | in_channels = 1 147 | kernel_size = 1 148 | stride = 2 149 | bias = False 150 | 151 | test_ouput, expectation = self.get_tensors( 152 | in_channels, kernel_size, stride, bias 153 | ) 154 | 155 | assert torch.equal(test_ouput, expectation) 156 | 157 | def test_in_channels_1_kernel_size_1_stride_3_bias_False(self): 158 | in_channels = 1 159 | kernel_size = 1 160 | stride = 3 161 | bias = False 162 | 163 | test_ouput, expectation = self.get_tensors( 164 | in_channels, kernel_size, stride, bias 165 | ) 166 | 167 | assert torch.equal(test_ouput, expectation) 168 | 169 | def test_in_channels_1_kernel_size_2_stride_1_bias_False(self): 170 | in_channels = 1 171 | kernel_size = 2 172 | stride = 1 173 | bias = False 174 | 175 | test_ouput, expectation = self.get_tensors( 176 | in_channels, kernel_size, stride, bias 177 | ) 178 | 179 | assert torch.equal(test_ouput, expectation) 180 | 181 | def test_in_channels_1_kernel_size_2_stride_2_bias_False(self): 182 | in_channels = 1 183 | kernel_size = 2 184 | stride = 2 185 | bias = False 186 | 187 | test_ouput, expectation = self.get_tensors( 188 | in_channels, kernel_size, stride, bias 189 | ) 190 | 191 | assert torch.equal(test_ouput, expectation) 192 | 193 | def test_in_channels_1_kernel_size_2_stride_3_bias_False(self): 194 | in_channels = 1 195 | kernel_size = 2 196 | stride = 3 197 | bias = False 198 | 199 | test_ouput, expectation = self.get_tensors( 200 | in_channels, kernel_size, stride, bias 201 | ) 202 | 203 | assert torch.equal(test_ouput, expectation) 204 | 205 | def test_in_channels_5_kernel_size_1_stride_1_bias_False(self): 206 | in_channels = 5 207 | kernel_size = 1 208 | stride = 1 209 | bias = False 210 | 211 | test_ouput, expectation = self.get_tensors( 212 | in_channels, kernel_size, stride, bias 213 | ) 214 | 215 | assert torch.equal(test_ouput, expectation) 216 | 217 | def test_in_channels_5_kernel_size_1_stride_2_bias_False(self): 218 | in_channels = 5 219 | kernel_size = 1 220 | stride = 2 221 | bias = False 222 | 223 | test_ouput, expectation = self.get_tensors( 224 | in_channels, kernel_size, stride, bias 225 | ) 226 | 227 | assert torch.equal(test_ouput, expectation) 228 | 229 | def test_in_channels_5_kernel_size_1_stride_3_bias_False(self): 230 | in_channels = 5 231 | kernel_size = 1 232 | stride = 3 233 | bias = False 234 | 235 | test_ouput, expectation = self.get_tensors( 236 | in_channels, kernel_size, stride, bias 237 | ) 238 | 239 | assert torch.equal(test_ouput, expectation) 240 | 241 | def test_in_channels_5_kernel_size_2_stride_1_bias_False(self): 242 | in_channels = 5 243 | kernel_size = 2 244 | stride = 1 245 | bias = False 246 | 247 | test_ouput, expectation = self.get_tensors( 248 | in_channels, kernel_size, stride, bias 249 | ) 250 | 251 | assert torch.equal(test_ouput, expectation) 252 | 253 | def test_in_channels_5_kernel_size_2_stride_2_bias_False(self): 254 | in_channels = 5 255 | kernel_size = 2 256 | stride = 2 257 | bias = False 258 | 259 | test_ouput, expectation = self.get_tensors( 260 | 
in_channels, kernel_size, stride, bias 261 | ) 262 | 263 | assert torch.equal(test_ouput, expectation) 264 | 265 | def test_in_channels_5_kernel_size_2_stride_3_bias_False(self): 266 | in_channels = 5 267 | kernel_size = 2 268 | stride = 3 269 | bias = False 270 | 271 | test_ouput, expectation = self.get_tensors( 272 | in_channels, kernel_size, stride, bias 273 | ) 274 | 275 | assert torch.equal(test_ouput, expectation) 276 | 277 | def test_in_channels_1_kernel_size_1_stride_1_bias_True(self): 278 | in_channels = 1 279 | kernel_size = 1 280 | stride = 1 281 | bias = True 282 | 283 | test_ouput, expectation = self.get_tensors( 284 | in_channels, kernel_size, stride, bias 285 | ) 286 | 287 | assert torch.equal(test_ouput, expectation) 288 | 289 | def test_in_channels_1_kernel_size_1_stride_2_bias_True(self): 290 | in_channels = 1 291 | kernel_size = 1 292 | stride = 2 293 | bias = True 294 | 295 | test_ouput, expectation = self.get_tensors( 296 | in_channels, kernel_size, stride, bias 297 | ) 298 | 299 | assert torch.equal(test_ouput, expectation) 300 | 301 | def test_in_channels_1_kernel_size_1_stride_3_bias_True(self): 302 | in_channels = 1 303 | kernel_size = 1 304 | stride = 3 305 | bias = True 306 | 307 | test_ouput, expectation = self.get_tensors( 308 | in_channels, kernel_size, stride, bias 309 | ) 310 | 311 | assert torch.equal(test_ouput, expectation) 312 | 313 | def test_in_channels_1_kernel_size_2_stride_1_bias_True(self): 314 | in_channels = 1 315 | kernel_size = 2 316 | stride = 1 317 | bias = True 318 | 319 | test_ouput, expectation = self.get_tensors( 320 | in_channels, kernel_size, stride, bias 321 | ) 322 | 323 | assert torch.equal(test_ouput, expectation) 324 | 325 | def test_in_channels_1_kernel_size_2_stride_2_bias_True(self): 326 | in_channels = 1 327 | kernel_size = 2 328 | stride = 2 329 | bias = True 330 | 331 | test_ouput, expectation = self.get_tensors( 332 | in_channels, kernel_size, stride, bias 333 | ) 334 | 335 | assert torch.equal(test_ouput, expectation) 336 | 337 | def test_in_channels_1_kernel_size_2_stride_3_bias_True(self): 338 | in_channels = 1 339 | kernel_size = 2 340 | stride = 3 341 | bias = True 342 | 343 | test_ouput, expectation = self.get_tensors( 344 | in_channels, kernel_size, stride, bias 345 | ) 346 | 347 | assert torch.equal(test_ouput, expectation) 348 | 349 | def test_in_channels_5_kernel_size_1_stride_1_bias_True(self): 350 | in_channels = 5 351 | kernel_size = 1 352 | stride = 1 353 | bias = True 354 | 355 | test_ouput, expectation = self.get_tensors( 356 | in_channels, kernel_size, stride, bias 357 | ) 358 | 359 | assert torch.equal(test_ouput, expectation) 360 | 361 | def test_in_channels_5_kernel_size_1_stride_2_bias_True(self): 362 | in_channels = 5 363 | kernel_size = 1 364 | stride = 2 365 | bias = True 366 | 367 | test_ouput, expectation = self.get_tensors( 368 | in_channels, kernel_size, stride, bias 369 | ) 370 | 371 | assert torch.equal(test_ouput, expectation) 372 | 373 | def test_in_channels_5_kernel_size_1_stride_3_bias_True(self): 374 | in_channels = 5 375 | kernel_size = 1 376 | stride = 3 377 | bias = True 378 | 379 | test_ouput, expectation = self.get_tensors( 380 | in_channels, kernel_size, stride, bias 381 | ) 382 | 383 | assert torch.equal(test_ouput, expectation) 384 | 385 | def test_in_channels_5_kernel_size_2_stride_1_bias_True(self): 386 | in_channels = 5 387 | kernel_size = 2 388 | stride = 1 389 | bias = True 390 | 391 | test_ouput, expectation = self.get_tensors( 392 | in_channels, kernel_size, stride, bias 393 
| ) 394 | 395 | assert torch.equal(test_ouput, expectation) 396 | 397 | def test_in_channels_5_kernel_size_2_stride_2_bias_True(self): 398 | in_channels = 5 399 | kernel_size = 2 400 | stride = 2 401 | bias = True 402 | 403 | test_ouput, expectation = self.get_tensors( 404 | in_channels, kernel_size, stride, bias 405 | ) 406 | 407 | assert torch.equal(test_ouput, expectation) 408 | 409 | def test_in_channels_5_kernel_size_2_stride_3_bias_True(self): 410 | in_channels = 5 411 | kernel_size = 2 412 | stride = 3 413 | bias = True 414 | 415 | test_ouput, expectation = self.get_tensors( 416 | in_channels, kernel_size, stride, bias 417 | ) 418 | 419 | assert torch.equal(test_ouput, expectation) 420 | -------------------------------------------------------------------------------- /tests/test_Conv2d_CustomKernel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestConv2d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_conv2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [9, 39, 45, 99, 85, 159, 125, 136], 17 | [19, 51, 82, 121, 152, 191, 222, 176], 18 | [24, 58, 89, 128, 159, 198, 229, 181], 19 | [29, 65, 96, 135, 166, 205, 236, 186], 20 | [28, 39, 87, 79, 147, 119, 207, 114], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_conv2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [42, 96, 128, 219, 238, 349, 265, 260], 29 | [67, 141, 194, 312, 354, 492, 388, 361], 30 | [84, 162, 243, 346, 433, 536, 494, 408], 31 | [90, 145, 246, 302, 426, 462, 474, 343], 32 | [68, 104, 184, 213, 314, 323, 355, 245], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_n_neighbors_size1(self): 68 | return np.array( 69 | [ 70 | [3, 6, 4, 6, 4, 6, 4, 4], 71 | [5, 7, 7, 7, 7, 7, 7, 5], 72 | [5, 7, 7, 7, 7, 7, 7, 5], 73 | [5, 7, 7, 7, 7, 7, 7, 5], 74 | [4, 4, 6, 4, 6, 4, 6, 3], 75 | ], 76 | dtype=np.float32, 77 | ) 78 | 79 | def get_n_neighbors_size2(self): 80 | return np.array( 81 | [ 82 | [7, 11, 11, 13, 11, 13, 9, 8], 83 | [10, 15, 16, 18, 16, 18, 13, 11], 84 | [12, 16, 19, 19, 19, 19, 16, 12], 85 | [11, 13, 18, 16, 18, 16, 15, 10], 86 | [8, 9, 13, 11, 13, 11, 11, 7], 87 | ], 88 | dtype=np.float32, 89 | ) 90 | 91 | def get_tensors(self, in_channels, kernel_size, stride, bias_bool): 92 | channel_dist = 1000 93 | if bias_bool is False: 94 | bias_value = 0 95 | bias = None 96 | else: 97 | bias_value = 1.0 98 | bias = np.array([1]) 99 | 100 | # input tensor 101 | array = self.get_array() 102 | array = np.expand_dims( 103 | np.stack([j * channel_dist + 
array for j in range(in_channels)]), 0 104 | ) 105 | tensor = torch.FloatTensor(array) 106 | 107 | # expected output tensor 108 | if kernel_size == 1: 109 | conv2d_array = self.get_array_conv2d_size1_stride1() 110 | n_neighbours = self.get_n_neighbors_size1() 111 | elif kernel_size == 2: 112 | conv2d_array = self.get_array_conv2d_size2_stride1() 113 | n_neighbours = self.get_n_neighbors_size2() 114 | convolved_array = np.sum( 115 | np.stack( 116 | [ 117 | (channel * channel_dist) * n_neighbours + conv2d_array 118 | for channel in range(in_channels) 119 | ] 120 | ), 121 | 0, 122 | ) 123 | if stride == 2: 124 | convolved_array = self.get_array_stride_2(convolved_array) 125 | elif stride == 3: 126 | convolved_array = self.get_array_stride_3(convolved_array) 127 | convolved_array = np.expand_dims(np.expand_dims(convolved_array, 0), 0) 128 | convolved_tensor = torch.FloatTensor(convolved_array) + bias_value 129 | 130 | # output tensor of test method 131 | if kernel_size == 1: 132 | kernel = [np.ones((1, in_channels, 3, 1)), np.ones((1, in_channels, 2, 2))] 133 | elif kernel_size == 2: 134 | kernel = [ 135 | np.ones((1, in_channels, 5, 1)), 136 | np.ones((1, in_channels, 4, 2)), 137 | np.ones((1, in_channels, 3, 2)), 138 | ] 139 | conv2d = hex.Conv2d_CustomKernel(kernel, stride, bias) 140 | 141 | return conv2d(tensor), convolved_tensor 142 | 143 | def test_in_channels_1_kernel_size_1_stride_1_bias_False(self): 144 | in_channels = 1 145 | kernel_size = 1 146 | stride = 1 147 | bias = False 148 | 149 | test_ouput, expectation = self.get_tensors( 150 | in_channels, kernel_size, stride, bias 151 | ) 152 | 153 | assert torch.equal(test_ouput, expectation) 154 | 155 | def test_in_channels_1_kernel_size_1_stride_2_bias_False(self): 156 | in_channels = 1 157 | kernel_size = 1 158 | stride = 2 159 | bias = False 160 | 161 | test_ouput, expectation = self.get_tensors( 162 | in_channels, kernel_size, stride, bias 163 | ) 164 | 165 | assert torch.equal(test_ouput, expectation) 166 | 167 | def test_in_channels_1_kernel_size_1_stride_3_bias_False(self): 168 | in_channels = 1 169 | kernel_size = 1 170 | stride = 3 171 | bias = False 172 | 173 | test_ouput, expectation = self.get_tensors( 174 | in_channels, kernel_size, stride, bias 175 | ) 176 | 177 | assert torch.equal(test_ouput, expectation) 178 | 179 | def test_in_channels_1_kernel_size_2_stride_1_bias_False(self): 180 | in_channels = 1 181 | kernel_size = 2 182 | stride = 1 183 | bias = False 184 | 185 | test_ouput, expectation = self.get_tensors( 186 | in_channels, kernel_size, stride, bias 187 | ) 188 | 189 | assert torch.equal(test_ouput, expectation) 190 | 191 | def test_in_channels_1_kernel_size_2_stride_2_bias_False(self): 192 | in_channels = 1 193 | kernel_size = 2 194 | stride = 2 195 | bias = False 196 | 197 | test_ouput, expectation = self.get_tensors( 198 | in_channels, kernel_size, stride, bias 199 | ) 200 | 201 | assert torch.equal(test_ouput, expectation) 202 | 203 | def test_in_channels_1_kernel_size_2_stride_3_bias_False(self): 204 | in_channels = 1 205 | kernel_size = 2 206 | stride = 3 207 | bias = False 208 | 209 | test_ouput, expectation = self.get_tensors( 210 | in_channels, kernel_size, stride, bias 211 | ) 212 | 213 | assert torch.equal(test_ouput, expectation) 214 | 215 | def test_in_channels_5_kernel_size_1_stride_1_bias_False(self): 216 | in_channels = 5 217 | kernel_size = 1 218 | stride = 1 219 | bias = False 220 | 221 | test_ouput, expectation = self.get_tensors( 222 | in_channels, kernel_size, stride, bias 223 | ) 224 | 225 | 
assert torch.equal(test_ouput, expectation) 226 | 227 | def test_in_channels_5_kernel_size_1_stride_2_bias_False(self): 228 | in_channels = 5 229 | kernel_size = 1 230 | stride = 2 231 | bias = False 232 | 233 | test_ouput, expectation = self.get_tensors( 234 | in_channels, kernel_size, stride, bias 235 | ) 236 | 237 | assert torch.equal(test_ouput, expectation) 238 | 239 | def test_in_channels_5_kernel_size_1_stride_3_bias_False(self): 240 | in_channels = 5 241 | kernel_size = 1 242 | stride = 3 243 | bias = False 244 | 245 | test_ouput, expectation = self.get_tensors( 246 | in_channels, kernel_size, stride, bias 247 | ) 248 | 249 | assert torch.equal(test_ouput, expectation) 250 | 251 | def test_in_channels_5_kernel_size_2_stride_1_bias_False(self): 252 | in_channels = 5 253 | kernel_size = 2 254 | stride = 1 255 | bias = False 256 | 257 | test_ouput, expectation = self.get_tensors( 258 | in_channels, kernel_size, stride, bias 259 | ) 260 | 261 | assert torch.equal(test_ouput, expectation) 262 | 263 | def test_in_channels_5_kernel_size_2_stride_2_bias_False(self): 264 | in_channels = 5 265 | kernel_size = 2 266 | stride = 2 267 | bias = False 268 | 269 | test_ouput, expectation = self.get_tensors( 270 | in_channels, kernel_size, stride, bias 271 | ) 272 | 273 | assert torch.equal(test_ouput, expectation) 274 | 275 | def test_in_channels_5_kernel_size_2_stride_3_bias_False(self): 276 | in_channels = 5 277 | kernel_size = 2 278 | stride = 3 279 | bias = False 280 | 281 | test_ouput, expectation = self.get_tensors( 282 | in_channels, kernel_size, stride, bias 283 | ) 284 | 285 | assert torch.equal(test_ouput, expectation) 286 | 287 | def test_in_channels_1_kernel_size_1_stride_1_bias_True(self): 288 | in_channels = 1 289 | kernel_size = 1 290 | stride = 1 291 | bias = True 292 | 293 | test_ouput, expectation = self.get_tensors( 294 | in_channels, kernel_size, stride, bias 295 | ) 296 | 297 | assert torch.equal(test_ouput, expectation) 298 | 299 | def test_in_channels_1_kernel_size_1_stride_2_bias_True(self): 300 | in_channels = 1 301 | kernel_size = 1 302 | stride = 2 303 | bias = True 304 | 305 | test_ouput, expectation = self.get_tensors( 306 | in_channels, kernel_size, stride, bias 307 | ) 308 | 309 | assert torch.equal(test_ouput, expectation) 310 | 311 | def test_in_channels_1_kernel_size_1_stride_3_bias_True(self): 312 | in_channels = 1 313 | kernel_size = 1 314 | stride = 3 315 | bias = True 316 | 317 | test_ouput, expectation = self.get_tensors( 318 | in_channels, kernel_size, stride, bias 319 | ) 320 | 321 | assert torch.equal(test_ouput, expectation) 322 | 323 | def test_in_channels_1_kernel_size_2_stride_1_bias_True(self): 324 | in_channels = 1 325 | kernel_size = 2 326 | stride = 1 327 | bias = True 328 | 329 | test_ouput, expectation = self.get_tensors( 330 | in_channels, kernel_size, stride, bias 331 | ) 332 | 333 | assert torch.equal(test_ouput, expectation) 334 | 335 | def test_in_channels_1_kernel_size_2_stride_2_bias_True(self): 336 | in_channels = 1 337 | kernel_size = 2 338 | stride = 2 339 | bias = True 340 | 341 | test_ouput, expectation = self.get_tensors( 342 | in_channels, kernel_size, stride, bias 343 | ) 344 | 345 | assert torch.equal(test_ouput, expectation) 346 | 347 | def test_in_channels_1_kernel_size_2_stride_3_bias_True(self): 348 | in_channels = 1 349 | kernel_size = 2 350 | stride = 3 351 | bias = True 352 | 353 | test_ouput, expectation = self.get_tensors( 354 | in_channels, kernel_size, stride, bias 355 | ) 356 | 357 | assert torch.equal(test_ouput, 
expectation) 358 | 359 | def test_in_channels_5_kernel_size_1_stride_1_bias_True(self): 360 | in_channels = 5 361 | kernel_size = 1 362 | stride = 1 363 | bias = True 364 | 365 | test_ouput, expectation = self.get_tensors( 366 | in_channels, kernel_size, stride, bias 367 | ) 368 | 369 | assert torch.equal(test_ouput, expectation) 370 | 371 | def test_in_channels_5_kernel_size_1_stride_2_bias_True(self): 372 | in_channels = 5 373 | kernel_size = 1 374 | stride = 2 375 | bias = True 376 | 377 | test_ouput, expectation = self.get_tensors( 378 | in_channels, kernel_size, stride, bias 379 | ) 380 | 381 | assert torch.equal(test_ouput, expectation) 382 | 383 | def test_in_channels_5_kernel_size_1_stride_3_bias_True(self): 384 | in_channels = 5 385 | kernel_size = 1 386 | stride = 3 387 | bias = True 388 | 389 | test_ouput, expectation = self.get_tensors( 390 | in_channels, kernel_size, stride, bias 391 | ) 392 | 393 | assert torch.equal(test_ouput, expectation) 394 | 395 | def test_in_channels_5_kernel_size_2_stride_1_bias_True(self): 396 | in_channels = 5 397 | kernel_size = 2 398 | stride = 1 399 | bias = True 400 | 401 | test_ouput, expectation = self.get_tensors( 402 | in_channels, kernel_size, stride, bias 403 | ) 404 | 405 | assert torch.equal(test_ouput, expectation) 406 | 407 | def test_in_channels_5_kernel_size_2_stride_2_bias_True(self): 408 | in_channels = 5 409 | kernel_size = 2 410 | stride = 2 411 | bias = True 412 | 413 | test_ouput, expectation = self.get_tensors( 414 | in_channels, kernel_size, stride, bias 415 | ) 416 | 417 | assert torch.equal(test_ouput, expectation) 418 | 419 | def test_in_channels_5_kernel_size_2_stride_3_bias_True(self): 420 | in_channels = 5 421 | kernel_size = 2 422 | stride = 3 423 | bias = True 424 | 425 | test_ouput, expectation = self.get_tensors( 426 | in_channels, kernel_size, stride, bias 427 | ) 428 | 429 | assert torch.equal(test_ouput, expectation) 430 | -------------------------------------------------------------------------------- /tests/test_Conv3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestConv3d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_conv2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [9, 39, 45, 99, 85, 159, 125, 136], 17 | [19, 51, 82, 121, 152, 191, 222, 176], 18 | [24, 58, 89, 128, 159, 198, 229, 181], 19 | [29, 65, 96, 135, 166, 205, 236, 186], 20 | [28, 39, 87, 79, 147, 119, 207, 114], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_conv2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [42, 96, 128, 219, 238, 349, 265, 260], 29 | [67, 141, 194, 312, 354, 492, 388, 361], 30 | [84, 162, 243, 346, 433, 536, 494, 408], 31 | [90, 145, 246, 302, 426, 462, 474, 343], 32 | [68, 104, 184, 213, 314, 323, 355, 245], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = 
np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_n_neighbors_size1(self): 68 | return np.array( 69 | [ 70 | [3, 6, 4, 6, 4, 6, 4, 4], 71 | [5, 7, 7, 7, 7, 7, 7, 5], 72 | [5, 7, 7, 7, 7, 7, 7, 5], 73 | [5, 7, 7, 7, 7, 7, 7, 5], 74 | [4, 4, 6, 4, 6, 4, 6, 3], 75 | ], 76 | dtype=np.float32, 77 | ) 78 | 79 | def get_n_neighbors_size2(self): 80 | return np.array( 81 | [ 82 | [7, 11, 11, 13, 11, 13, 9, 8], 83 | [10, 15, 16, 18, 16, 18, 13, 11], 84 | [12, 16, 19, 19, 19, 19, 16, 12], 85 | [11, 13, 18, 16, 18, 16, 15, 10], 86 | [8, 9, 13, 11, 13, 11, 11, 7], 87 | ], 88 | dtype=np.float32, 89 | ) 90 | 91 | def get_tensors( 92 | self, 93 | in_channels, 94 | depth, 95 | kernel_size_depth, 96 | kernel_size_hex, 97 | stride_depth, 98 | stride_hex, 99 | bias, 100 | ): 101 | channel_dist = 1000 102 | depth_dist = 40 103 | depth_steps = int(np.ceil((depth - kernel_size_depth + 1) / stride_depth)) 104 | if bias is False: 105 | bias_value = 0 106 | else: 107 | bias_value = 1.0 108 | 109 | # input tensor 110 | array = self.get_array() 111 | array = np.expand_dims( 112 | np.stack( 113 | [ 114 | j * channel_dist 115 | + np.stack([i * depth_dist + array for i in range(depth)]) 116 | for j in range(in_channels) 117 | ] 118 | ), 119 | 0, 120 | ) 121 | tensor = torch.FloatTensor(array) 122 | 123 | # expected output tensor 124 | if kernel_size_hex == 1: 125 | conv2d_array = self.get_array_conv2d_size1_stride1() 126 | n_neighbours = self.get_n_neighbors_size1() 127 | elif kernel_size_hex == 2: 128 | conv2d_array = self.get_array_conv2d_size2_stride1() 129 | n_neighbours = self.get_n_neighbors_size2() 130 | convolved_array = [] 131 | for dstep in range(depth_steps): 132 | layer_array = np.sum( 133 | np.stack( 134 | [ 135 | ( 136 | channel * channel_dist 137 | + ((dstep * stride_depth) + dsize) * depth_dist 138 | ) 139 | * n_neighbours 140 | + conv2d_array 141 | for dsize in range(kernel_size_depth) 142 | for channel in range(in_channels) 143 | ] 144 | ), 145 | 0, 146 | ) 147 | if stride_hex == 2: 148 | layer_array = self.get_array_stride_2(layer_array) 149 | elif stride_hex == 3: 150 | layer_array = self.get_array_stride_3(layer_array) 151 | convolved_array.append(layer_array) 152 | convolved_array = np.expand_dims( 153 | np.expand_dims(np.stack(convolved_array), 0), 0 154 | ) 155 | convolved_tensor = torch.FloatTensor(convolved_array) + bias_value 156 | 157 | # output tensor of test method 158 | conv3d = hex.Conv3d( 159 | in_channels, 160 | 1, 161 | (kernel_size_depth, kernel_size_hex), 162 | (stride_depth, stride_hex), 163 | bias, 164 | True, 165 | ) 166 | 167 | return conv3d(tensor), convolved_tensor 168 | 169 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_1_bias_False(self): 170 | in_channels = 1 171 | depth = 1 172 | kernel_size_depth = 1 173 | kernel_size_hex = 1 174 | stride_depth = 1 175 | stride_hex = 1 176 | bias = False 177 | 178 | test_ouput, expectation = self.get_tensors( 179 | in_channels, 180 | depth, 181 | kernel_size_depth, 182 | kernel_size_hex, 183 | stride_depth, 184 | stride_hex, 185 | bias, 186 | ) 187 | 188 | assert torch.equal(test_ouput, expectation) 189 | 190 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_2_bias_False(self): 191 | in_channels = 1 192 | depth = 1 193 | kernel_size_depth = 1 194 | 
kernel_size_hex = 1 195 | stride_depth = 1 196 | stride_hex = 2 197 | bias = False 198 | 199 | test_ouput, expectation = self.get_tensors( 200 | in_channels, 201 | depth, 202 | kernel_size_depth, 203 | kernel_size_hex, 204 | stride_depth, 205 | stride_hex, 206 | bias, 207 | ) 208 | 209 | assert torch.equal(test_ouput, expectation) 210 | 211 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_3_bias_False(self): 212 | in_channels = 1 213 | depth = 1 214 | kernel_size_depth = 1 215 | kernel_size_hex = 1 216 | stride_depth = 1 217 | stride_hex = 3 218 | bias = False 219 | 220 | test_ouput, expectation = self.get_tensors( 221 | in_channels, 222 | depth, 223 | kernel_size_depth, 224 | kernel_size_hex, 225 | stride_depth, 226 | stride_hex, 227 | bias, 228 | ) 229 | 230 | assert torch.equal(test_ouput, expectation) 231 | 232 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_1_bias_False(self): 233 | in_channels = 1 234 | depth = 1 235 | kernel_size_depth = 1 236 | kernel_size_hex = 2 237 | stride_depth = 1 238 | stride_hex = 1 239 | bias = False 240 | 241 | test_ouput, expectation = self.get_tensors( 242 | in_channels, 243 | depth, 244 | kernel_size_depth, 245 | kernel_size_hex, 246 | stride_depth, 247 | stride_hex, 248 | bias, 249 | ) 250 | 251 | assert torch.equal(test_ouput, expectation) 252 | 253 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_2_bias_False(self): 254 | in_channels = 1 255 | depth = 1 256 | kernel_size_depth = 1 257 | kernel_size_hex = 2 258 | stride_depth = 1 259 | stride_hex = 2 260 | bias = False 261 | 262 | test_ouput, expectation = self.get_tensors( 263 | in_channels, 264 | depth, 265 | kernel_size_depth, 266 | kernel_size_hex, 267 | stride_depth, 268 | stride_hex, 269 | bias, 270 | ) 271 | 272 | assert torch.equal(test_ouput, expectation) 273 | 274 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_3_bias_False(self): 275 | in_channels = 1 276 | depth = 1 277 | kernel_size_depth = 1 278 | kernel_size_hex = 2 279 | stride_depth = 1 280 | stride_hex = 3 281 | bias = False 282 | 283 | test_ouput, expectation = self.get_tensors( 284 | in_channels, 285 | depth, 286 | kernel_size_depth, 287 | kernel_size_hex, 288 | stride_depth, 289 | stride_hex, 290 | bias, 291 | ) 292 | 293 | assert torch.equal(test_ouput, expectation) 294 | 295 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_1_1_bias_False(self): 296 | in_channels = 1 297 | depth = 9 298 | kernel_size_depth = 1 299 | kernel_size_hex = 1 300 | stride_depth = 1 301 | stride_hex = 1 302 | bias = False 303 | 304 | test_ouput, expectation = self.get_tensors( 305 | in_channels, 306 | depth, 307 | kernel_size_depth, 308 | kernel_size_hex, 309 | stride_depth, 310 | stride_hex, 311 | bias, 312 | ) 313 | 314 | assert torch.equal(test_ouput, expectation) 315 | 316 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_2_1_bias_False(self): 317 | in_channels = 1 318 | depth = 9 319 | kernel_size_depth = 1 320 | kernel_size_hex = 1 321 | stride_depth = 2 322 | stride_hex = 1 323 | bias = False 324 | 325 | test_ouput, expectation = self.get_tensors( 326 | in_channels, 327 | depth, 328 | kernel_size_depth, 329 | kernel_size_hex, 330 | stride_depth, 331 | stride_hex, 332 | bias, 333 | ) 334 | 335 | assert torch.equal(test_ouput, expectation) 336 | 337 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_3_1_bias_False(self): 338 | in_channels = 1 339 | depth = 9 340 | kernel_size_depth = 1 341 | kernel_size_hex = 1 342 | stride_depth = 3 343 | stride_hex = 1 344 | bias = False 345 | 346 | 
test_ouput, expectation = self.get_tensors( 347 | in_channels, 348 | depth, 349 | kernel_size_depth, 350 | kernel_size_hex, 351 | stride_depth, 352 | stride_hex, 353 | bias, 354 | ) 355 | 356 | assert torch.equal(test_ouput, expectation) 357 | 358 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_1_1_bias_False(self): 359 | in_channels = 1 360 | depth = 9 361 | kernel_size_depth = 2 362 | kernel_size_hex = 1 363 | stride_depth = 1 364 | stride_hex = 1 365 | bias = False 366 | 367 | test_ouput, expectation = self.get_tensors( 368 | in_channels, 369 | depth, 370 | kernel_size_depth, 371 | kernel_size_hex, 372 | stride_depth, 373 | stride_hex, 374 | bias, 375 | ) 376 | 377 | assert torch.equal(test_ouput, expectation) 378 | 379 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_1_bias_False(self): 380 | in_channels = 1 381 | depth = 9 382 | kernel_size_depth = 2 383 | kernel_size_hex = 1 384 | stride_depth = 2 385 | stride_hex = 1 386 | bias = False 387 | 388 | test_ouput, expectation = self.get_tensors( 389 | in_channels, 390 | depth, 391 | kernel_size_depth, 392 | kernel_size_hex, 393 | stride_depth, 394 | stride_hex, 395 | bias, 396 | ) 397 | 398 | assert torch.equal(test_ouput, expectation) 399 | 400 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_2_bias_False(self): 401 | in_channels = 1 402 | depth = 9 403 | kernel_size_depth = 2 404 | kernel_size_hex = 1 405 | stride_depth = 2 406 | stride_hex = 2 407 | bias = False 408 | 409 | test_ouput, expectation = self.get_tensors( 410 | in_channels, 411 | depth, 412 | kernel_size_depth, 413 | kernel_size_hex, 414 | stride_depth, 415 | stride_hex, 416 | bias, 417 | ) 418 | 419 | assert torch.equal(test_ouput, expectation) 420 | 421 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_1_bias_False(self): 422 | in_channels = 1 423 | depth = 9 424 | kernel_size_depth = 7 425 | kernel_size_hex = 2 426 | stride_depth = 1 427 | stride_hex = 1 428 | bias = False 429 | 430 | test_ouput, expectation = self.get_tensors( 431 | in_channels, 432 | depth, 433 | kernel_size_depth, 434 | kernel_size_hex, 435 | stride_depth, 436 | stride_hex, 437 | bias, 438 | ) 439 | 440 | assert torch.equal(test_ouput, expectation) 441 | 442 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_2_bias_False(self): 443 | in_channels = 1 444 | depth = 9 445 | kernel_size_depth = 7 446 | kernel_size_hex = 2 447 | stride_depth = 1 448 | stride_hex = 2 449 | bias = False 450 | 451 | test_ouput, expectation = self.get_tensors( 452 | in_channels, 453 | depth, 454 | kernel_size_depth, 455 | kernel_size_hex, 456 | stride_depth, 457 | stride_hex, 458 | bias, 459 | ) 460 | 461 | assert torch.equal(test_ouput, expectation) 462 | 463 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_2_2_bias_False(self): 464 | in_channels = 1 465 | depth = 9 466 | kernel_size_depth = 7 467 | kernel_size_hex = 2 468 | stride_depth = 2 469 | stride_hex = 2 470 | bias = False 471 | 472 | test_ouput, expectation = self.get_tensors( 473 | in_channels, 474 | depth, 475 | kernel_size_depth, 476 | kernel_size_hex, 477 | stride_depth, 478 | stride_hex, 479 | bias, 480 | ) 481 | 482 | assert torch.equal(test_ouput, expectation) 483 | 484 | def test_in_channels_5_depth_9_kernel_size_3_2_stride_1_1_bias_False(self): 485 | in_channels = 5 486 | depth = 9 487 | kernel_size_depth = 3 488 | kernel_size_hex = 2 489 | stride_depth = 1 490 | stride_hex = 1 491 | bias = False 492 | 493 | test_ouput, expectation = self.get_tensors( 494 | in_channels, 495 | depth, 496 | kernel_size_depth, 
497 | kernel_size_hex, 498 | stride_depth, 499 | stride_hex, 500 | bias, 501 | ) 502 | 503 | assert torch.equal(test_ouput, expectation) 504 | 505 | def test_in_channels_5_depth_9_kernel_size_3_2_stride_1_1_bias_True(self): 506 | in_channels = 5 507 | depth = 9 508 | kernel_size_depth = 3 509 | kernel_size_hex = 2 510 | stride_depth = 1 511 | stride_hex = 1 512 | bias = True 513 | 514 | test_ouput, expectation = self.get_tensors( 515 | in_channels, 516 | depth, 517 | kernel_size_depth, 518 | kernel_size_hex, 519 | stride_depth, 520 | stride_hex, 521 | bias, 522 | ) 523 | 524 | assert torch.equal(test_ouput, expectation) 525 | -------------------------------------------------------------------------------- /tests/test_Conv3d_CustomKernel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hexagdly as hex 4 | import pytest 5 | 6 | 7 | class TestConv3d(object): 8 | def get_array(self): 9 | return np.array( 10 | [[j * 5 + 1 + i for j in range(8)] for i in range(5)], dtype=np.float32 11 | ) 12 | 13 | def get_array_conv2d_size1_stride1(self): 14 | return np.array( 15 | [ 16 | [9, 39, 45, 99, 85, 159, 125, 136], 17 | [19, 51, 82, 121, 152, 191, 222, 176], 18 | [24, 58, 89, 128, 159, 198, 229, 181], 19 | [29, 65, 96, 135, 166, 205, 236, 186], 20 | [28, 39, 87, 79, 147, 119, 207, 114], 21 | ], 22 | dtype=np.float32, 23 | ) 24 | 25 | def get_array_conv2d_size2_stride1(self): 26 | return np.array( 27 | [ 28 | [42, 96, 128, 219, 238, 349, 265, 260], 29 | [67, 141, 194, 312, 354, 492, 388, 361], 30 | [84, 162, 243, 346, 433, 536, 494, 408], 31 | [90, 145, 246, 302, 426, 462, 474, 343], 32 | [68, 104, 184, 213, 314, 323, 355, 245], 33 | ], 34 | dtype=np.float32, 35 | ) 36 | 37 | def get_array_stride_2(self, array_stride_1): 38 | array_stride_2 = np.zeros((2, 4), dtype=np.float32) 39 | stride_2_pos = [ 40 | (0, 0, 0, 0), 41 | (0, 1, 1, 2), 42 | (0, 2, 0, 4), 43 | (0, 3, 1, 6), 44 | (1, 0, 2, 0), 45 | (1, 1, 3, 2), 46 | (1, 2, 2, 4), 47 | (1, 3, 3, 6), 48 | ] 49 | for pos in stride_2_pos: 50 | array_stride_2[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 51 | return array_stride_2 52 | 53 | def get_array_stride_3(self, array_stride_1): 54 | array_stride_3 = np.zeros((2, 3), dtype=np.float32) 55 | stride_3_pos = [ 56 | (0, 0, 0, 0), 57 | (0, 1, 1, 3), 58 | (0, 2, 0, 6), 59 | (1, 0, 3, 0), 60 | (1, 1, 4, 3), 61 | (1, 2, 3, 6), 62 | ] 63 | for pos in stride_3_pos: 64 | array_stride_3[pos[0], pos[1]] = array_stride_1[pos[2], pos[3]] 65 | return array_stride_3 66 | 67 | def get_n_neighbors_size1(self): 68 | return np.array( 69 | [ 70 | [3, 6, 4, 6, 4, 6, 4, 4], 71 | [5, 7, 7, 7, 7, 7, 7, 5], 72 | [5, 7, 7, 7, 7, 7, 7, 5], 73 | [5, 7, 7, 7, 7, 7, 7, 5], 74 | [4, 4, 6, 4, 6, 4, 6, 3], 75 | ], 76 | dtype=np.float32, 77 | ) 78 | 79 | def get_n_neighbors_size2(self): 80 | return np.array( 81 | [ 82 | [7, 11, 11, 13, 11, 13, 9, 8], 83 | [10, 15, 16, 18, 16, 18, 13, 11], 84 | [12, 16, 19, 19, 19, 19, 16, 12], 85 | [11, 13, 18, 16, 18, 16, 15, 10], 86 | [8, 9, 13, 11, 13, 11, 11, 7], 87 | ], 88 | dtype=np.float32, 89 | ) 90 | 91 | def get_tensors( 92 | self, 93 | in_channels, 94 | depth, 95 | kernel_size_depth, 96 | kernel_size_hex, 97 | stride_depth, 98 | stride_hex, 99 | bias_bool, 100 | ): 101 | channel_dist = 1000 102 | depth_dist = 40 103 | depth_steps = int(np.ceil((depth - kernel_size_depth + 1) / stride_depth)) 104 | if bias_bool is False: 105 | bias_value = 0 106 | bias = None 107 | else: 108 | bias_value = 1.0 109 | bias = 
np.array([1]) 110 | 111 | # input tensor 112 | array = self.get_array() 113 | array = np.expand_dims( 114 | np.stack( 115 | [ 116 | j * channel_dist 117 | + np.stack([i * depth_dist + array for i in range(depth)]) 118 | for j in range(in_channels) 119 | ] 120 | ), 121 | 0, 122 | ) 123 | tensor = torch.FloatTensor(array) 124 | 125 | # expected output tensor 126 | if kernel_size_hex == 1: 127 | conv2d_array = self.get_array_conv2d_size1_stride1() 128 | n_neighbours = self.get_n_neighbors_size1() 129 | elif kernel_size_hex == 2: 130 | conv2d_array = self.get_array_conv2d_size2_stride1() 131 | n_neighbours = self.get_n_neighbors_size2() 132 | convolved_array = [] 133 | for dstep in range(depth_steps): 134 | layer_array = np.sum( 135 | np.stack( 136 | [ 137 | ( 138 | channel * channel_dist 139 | + ((dstep * stride_depth) + dsize) * depth_dist 140 | ) 141 | * n_neighbours 142 | + conv2d_array 143 | for dsize in range(kernel_size_depth) 144 | for channel in range(in_channels) 145 | ] 146 | ), 147 | 0, 148 | ) 149 | if stride_hex == 2: 150 | layer_array = self.get_array_stride_2(layer_array) 151 | elif stride_hex == 3: 152 | layer_array = self.get_array_stride_3(layer_array) 153 | convolved_array.append(layer_array) 154 | convolved_array = np.expand_dims( 155 | np.expand_dims(np.stack(convolved_array), 0), 0 156 | ) 157 | convolved_tensor = torch.FloatTensor(convolved_array) + bias_value 158 | 159 | # output tensor of test method 160 | if kernel_size_hex == 1: 161 | kernel = [ 162 | np.ones((1, in_channels, kernel_size_depth, 3, 1)), 163 | np.ones((1, in_channels, kernel_size_depth, 2, 2)), 164 | ] 165 | elif kernel_size_hex == 2: 166 | kernel = [ 167 | np.ones((1, in_channels, kernel_size_depth, 5, 1)), 168 | np.ones((1, in_channels, kernel_size_depth, 4, 2)), 169 | np.ones((1, in_channels, kernel_size_depth, 3, 2)), 170 | ] 171 | 172 | conv3d = hex.Conv3d_CustomKernel(kernel, (stride_depth, stride_hex), bias) 173 | 174 | return conv3d(tensor), convolved_tensor 175 | 176 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_1_bias_False(self): 177 | in_channels = 1 178 | depth = 1 179 | kernel_size_depth = 1 180 | kernel_size_hex = 1 181 | stride_depth = 1 182 | stride_hex = 1 183 | bias = False 184 | 185 | test_ouput, expectation = self.get_tensors( 186 | in_channels, 187 | depth, 188 | kernel_size_depth, 189 | kernel_size_hex, 190 | stride_depth, 191 | stride_hex, 192 | bias, 193 | ) 194 | 195 | assert torch.equal(test_ouput, expectation) 196 | 197 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_2_bias_False(self): 198 | in_channels = 1 199 | depth = 1 200 | kernel_size_depth = 1 201 | kernel_size_hex = 1 202 | stride_depth = 1 203 | stride_hex = 2 204 | bias = False 205 | 206 | test_ouput, expectation = self.get_tensors( 207 | in_channels, 208 | depth, 209 | kernel_size_depth, 210 | kernel_size_hex, 211 | stride_depth, 212 | stride_hex, 213 | bias, 214 | ) 215 | 216 | assert torch.equal(test_ouput, expectation) 217 | 218 | def test_in_channels_1_depth_1_kernel_size_1_1_stride_1_3_bias_False(self): 219 | in_channels = 1 220 | depth = 1 221 | kernel_size_depth = 1 222 | kernel_size_hex = 1 223 | stride_depth = 1 224 | stride_hex = 3 225 | bias = False 226 | 227 | test_ouput, expectation = self.get_tensors( 228 | in_channels, 229 | depth, 230 | kernel_size_depth, 231 | kernel_size_hex, 232 | stride_depth, 233 | stride_hex, 234 | bias, 235 | ) 236 | 237 | assert torch.equal(test_ouput, expectation) 238 | 239 | def 
test_in_channels_1_depth_1_kernel_size_1_2_stride_1_1_bias_False(self): 240 | in_channels = 1 241 | depth = 1 242 | kernel_size_depth = 1 243 | kernel_size_hex = 2 244 | stride_depth = 1 245 | stride_hex = 1 246 | bias = False 247 | 248 | test_ouput, expectation = self.get_tensors( 249 | in_channels, 250 | depth, 251 | kernel_size_depth, 252 | kernel_size_hex, 253 | stride_depth, 254 | stride_hex, 255 | bias, 256 | ) 257 | 258 | assert torch.equal(test_ouput, expectation) 259 | 260 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_2_bias_False(self): 261 | in_channels = 1 262 | depth = 1 263 | kernel_size_depth = 1 264 | kernel_size_hex = 2 265 | stride_depth = 1 266 | stride_hex = 2 267 | bias = False 268 | 269 | test_ouput, expectation = self.get_tensors( 270 | in_channels, 271 | depth, 272 | kernel_size_depth, 273 | kernel_size_hex, 274 | stride_depth, 275 | stride_hex, 276 | bias, 277 | ) 278 | 279 | assert torch.equal(test_ouput, expectation) 280 | 281 | def test_in_channels_1_depth_1_kernel_size_1_2_stride_1_3_bias_False(self): 282 | in_channels = 1 283 | depth = 1 284 | kernel_size_depth = 1 285 | kernel_size_hex = 2 286 | stride_depth = 1 287 | stride_hex = 3 288 | bias = False 289 | 290 | test_ouput, expectation = self.get_tensors( 291 | in_channels, 292 | depth, 293 | kernel_size_depth, 294 | kernel_size_hex, 295 | stride_depth, 296 | stride_hex, 297 | bias, 298 | ) 299 | 300 | assert torch.equal(test_ouput, expectation) 301 | 302 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_1_1_bias_False(self): 303 | in_channels = 1 304 | depth = 9 305 | kernel_size_depth = 1 306 | kernel_size_hex = 1 307 | stride_depth = 1 308 | stride_hex = 1 309 | bias = False 310 | 311 | test_ouput, expectation = self.get_tensors( 312 | in_channels, 313 | depth, 314 | kernel_size_depth, 315 | kernel_size_hex, 316 | stride_depth, 317 | stride_hex, 318 | bias, 319 | ) 320 | 321 | assert torch.equal(test_ouput, expectation) 322 | 323 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_2_1_bias_False(self): 324 | in_channels = 1 325 | depth = 9 326 | kernel_size_depth = 1 327 | kernel_size_hex = 1 328 | stride_depth = 2 329 | stride_hex = 1 330 | bias = False 331 | 332 | test_ouput, expectation = self.get_tensors( 333 | in_channels, 334 | depth, 335 | kernel_size_depth, 336 | kernel_size_hex, 337 | stride_depth, 338 | stride_hex, 339 | bias, 340 | ) 341 | 342 | assert torch.equal(test_ouput, expectation) 343 | 344 | def test_in_channels_1_depth_9_kernel_size_1_1_stride_3_1_bias_False(self): 345 | in_channels = 1 346 | depth = 9 347 | kernel_size_depth = 1 348 | kernel_size_hex = 1 349 | stride_depth = 3 350 | stride_hex = 1 351 | bias = False 352 | 353 | test_ouput, expectation = self.get_tensors( 354 | in_channels, 355 | depth, 356 | kernel_size_depth, 357 | kernel_size_hex, 358 | stride_depth, 359 | stride_hex, 360 | bias, 361 | ) 362 | 363 | assert torch.equal(test_ouput, expectation) 364 | 365 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_1_1_bias_False(self): 366 | in_channels = 1 367 | depth = 9 368 | kernel_size_depth = 2 369 | kernel_size_hex = 1 370 | stride_depth = 1 371 | stride_hex = 1 372 | bias = False 373 | 374 | test_ouput, expectation = self.get_tensors( 375 | in_channels, 376 | depth, 377 | kernel_size_depth, 378 | kernel_size_hex, 379 | stride_depth, 380 | stride_hex, 381 | bias, 382 | ) 383 | 384 | assert torch.equal(test_ouput, expectation) 385 | 386 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_1_bias_False(self): 387 | in_channels = 1 388 | 
depth = 9 389 | kernel_size_depth = 2 390 | kernel_size_hex = 1 391 | stride_depth = 2 392 | stride_hex = 1 393 | bias = False 394 | 395 | test_ouput, expectation = self.get_tensors( 396 | in_channels, 397 | depth, 398 | kernel_size_depth, 399 | kernel_size_hex, 400 | stride_depth, 401 | stride_hex, 402 | bias, 403 | ) 404 | 405 | assert torch.equal(test_ouput, expectation) 406 | 407 | def test_in_channels_1_depth_9_kernel_size_2_1_stride_2_2_bias_False(self): 408 | in_channels = 1 409 | depth = 9 410 | kernel_size_depth = 2 411 | kernel_size_hex = 1 412 | stride_depth = 2 413 | stride_hex = 2 414 | bias = False 415 | 416 | test_ouput, expectation = self.get_tensors( 417 | in_channels, 418 | depth, 419 | kernel_size_depth, 420 | kernel_size_hex, 421 | stride_depth, 422 | stride_hex, 423 | bias, 424 | ) 425 | 426 | assert torch.equal(test_ouput, expectation) 427 | 428 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_1_bias_False(self): 429 | in_channels = 1 430 | depth = 9 431 | kernel_size_depth = 7 432 | kernel_size_hex = 2 433 | stride_depth = 1 434 | stride_hex = 1 435 | bias = False 436 | 437 | test_ouput, expectation = self.get_tensors( 438 | in_channels, 439 | depth, 440 | kernel_size_depth, 441 | kernel_size_hex, 442 | stride_depth, 443 | stride_hex, 444 | bias, 445 | ) 446 | 447 | assert torch.equal(test_ouput, expectation) 448 | 449 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_1_2_bias_False(self): 450 | in_channels = 1 451 | depth = 9 452 | kernel_size_depth = 7 453 | kernel_size_hex = 2 454 | stride_depth = 1 455 | stride_hex = 2 456 | bias = False 457 | 458 | test_ouput, expectation = self.get_tensors( 459 | in_channels, 460 | depth, 461 | kernel_size_depth, 462 | kernel_size_hex, 463 | stride_depth, 464 | stride_hex, 465 | bias, 466 | ) 467 | 468 | assert torch.equal(test_ouput, expectation) 469 | 470 | def test_in_channels_1_depth_9_kernel_size_7_2_stride_2_2_bias_False(self): 471 | in_channels = 1 472 | depth = 9 473 | kernel_size_depth = 7 474 | kernel_size_hex = 2 475 | stride_depth = 2 476 | stride_hex = 2 477 | bias = False 478 | 479 | test_ouput, expectation = self.get_tensors( 480 | in_channels, 481 | depth, 482 | kernel_size_depth, 483 | kernel_size_hex, 484 | stride_depth, 485 | stride_hex, 486 | bias, 487 | ) 488 | 489 | assert torch.equal(test_ouput, expectation) 490 | 491 | def test_in_channels_5_depth_9_kernel_size_3_2_stride_1_1_bias_False(self): 492 | in_channels = 5 493 | depth = 9 494 | kernel_size_depth = 3 495 | kernel_size_hex = 2 496 | stride_depth = 1 497 | stride_hex = 1 498 | bias = False 499 | 500 | test_ouput, expectation = self.get_tensors( 501 | in_channels, 502 | depth, 503 | kernel_size_depth, 504 | kernel_size_hex, 505 | stride_depth, 506 | stride_hex, 507 | bias, 508 | ) 509 | 510 | assert torch.equal(test_ouput, expectation) 511 | 512 | def test_in_channels_5_depth_9_kernel_size_3_2_stride_1_1_bias_True(self): 513 | in_channels = 5 514 | depth = 9 515 | kernel_size_depth = 3 516 | kernel_size_hex = 2 517 | stride_depth = 1 518 | stride_hex = 1 519 | bias = True 520 | 521 | test_ouput, expectation = self.get_tensors( 522 | in_channels, 523 | depth, 524 | kernel_size_depth, 525 | kernel_size_hex, 526 | stride_depth, 527 | stride_hex, 528 | bias, 529 | ) 530 | 531 | assert torch.equal(test_ouput, expectation) 532 | -------------------------------------------------------------------------------- /notebooks/example_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | HexagDLy 
utilities for illustrative examples. 3 | 4 | """ 5 | 6 | import numpy as np 7 | import numpy.linalg as LA 8 | from scipy.interpolate import griddata 9 | import torch 10 | import torch.utils.data 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | import torch.nn.functional as F 15 | import torch.optim.lr_scheduler as scheduler 16 | import os 17 | import matplotlib.pyplot as plt 18 | import time 19 | import h5py 20 | 21 | def put_shape(nx, ny, cx, cy, params): 22 | d = np.zeros((nx, ny)) 23 | i = np.indices((nx, ny)) 24 | i[0] = i[0] - cx 25 | i[1] = i[1] - cy 26 | i = i.astype(float) 27 | i[0] *= 1.73205 / 2 28 | if np.mod(cx, 2) == 0: 29 | i[1][np.mod(cx + 1, 2) :: 2] += 0.5 30 | else: 31 | i[1][np.mod(cx + 1, 2) :: 2] -= 0.5 32 | di = i[0] ** 2 + i[1] ** 2 33 | for t1, t2 in params: 34 | di = np.where(np.logical_and(di >= t2, di <= t1), 1, di) 35 | di = np.where(di > 1.1, 0, di) 36 | return di.transpose() 37 | 38 | 39 | class toy_data: 40 | r"""Object that contains a set of toy images of randomly scattered 41 | hexagonal shapes of a certain kind. 42 | 43 | Args: 44 | shape: str, choose from ... 45 | nx: int, dimension in x 46 | ny: int, dimension in y 47 | nchannels: int, number of input channels ('colour' channels) 48 | nexamples: int, number of images 49 | px: int, center row for shape 50 | py: int, center column for shape 51 | 52 | """ 53 | 54 | def __init__( 55 | self, 56 | shape, 57 | nx=16, 58 | ny=16, 59 | nchannels=1, 60 | nexamples=1, 61 | noisy=None, 62 | px=None, 63 | py=None, 64 | ): 65 | self.shapes = { 66 | "small_hexagon": [(1, 0)], 67 | "medium_hexagon": [(4, 0)], 68 | "snowflake_1": [(3, 0)], 69 | "snowflake_2": [(1, 0), (4.1, 3.9)], 70 | "snowflake_3": [(7, 3)], 71 | "snowflake_4": [(7, 0)], 72 | "double_hex": [(10, 5)], 73 | } 74 | self.nx = nx 75 | self.ny = ny 76 | if noisy: 77 | self.image_data = np.random.normal(0, noisy, (nexamples, nchannels, ny, nx)) 78 | else: 79 | self.image_data = np.zeros((nexamples, nchannels, ny, nx)) 80 | for ie, example in enumerate(self.image_data): 81 | for ic, channel in enumerate(example): 82 | if px is None and py is None: 83 | cx, cy = int(ny * np.random.random()), int(nx * np.random.random()) 84 | else: 85 | cx, cy = px, py 86 | face = put_shape(self.nx, self.ny, cx, cy, self.shapes[shape]) 87 | self.image_data[ie, ic, :, :] += face 88 | 89 | def to_h5(self, filename): 90 | with h5py.File(filename + ".h5", "w") as f: 91 | f.create_dataset("image_data", data=self.image_data) 92 | 93 | def to_torch_tensor(self): 94 | return torch.Tensor(self.image_data) 95 | 96 | 97 | ################################################################### 98 | 99 | 100 | class Shape(object): 101 | def __init__(self, nx, ny, scale=3, rotation=False): 102 | self.nx = nx 103 | self.ny = ny 104 | self.X = np.zeros(self.nx * self.ny) 105 | self.Y = np.zeros(self.nx * self.ny) 106 | i = 0 107 | for x in range(self.nx): 108 | for y in range(self.ny): 109 | self.X[i], self.Y[i] = x * np.sqrt(3) / 2, -(y + np.mod(x, 2) * 0.5) 110 | i += 1 111 | self.xmin = np.min(self.X) 112 | self.xmax = np.max(self.X) 113 | self.ymin = np.min(self.Y) 114 | self.ymax = np.max(self.Y) 115 | self.P = np.stack([self.X.flatten(), self.Y.flatten()], axis=1) 116 | self.size = 0.5 117 | self.scale = scale 118 | self.rotation = rotation 119 | 120 | def polar_to_cartesian(self, r, alpha): 121 | x = r * np.cos(alpha) 122 | y = r * np.sin(alpha) 123 | return np.array([x, y]) 124 | 125 | def image_from_points(self, point_list_1, point_list_2): 126 | ind = 
np.full(len(self.P), False) 127 | for p1, p2 in zip(point_list_1, point_list_2): 128 | pa = p2 - p1 129 | alpha = np.arctan2(pa[1], pa[0]) 130 | pb = self.P - p1 131 | beta = np.arctan2(pb[:, 1], pb[:, 0]) 132 | vlen = LA.norm(pb, axis=1) 133 | dist = np.abs(self.polar_to_cartesian(vlen, beta - alpha)[1]) 134 | 135 | tmp = np.where(dist < self.size, True, False) 136 | xmin = np.min([p1[0], p2[0]]) 137 | xmax = np.max([p1[0], p2[0]]) 138 | if np.abs(xmax - xmin) > 1e-12: 139 | xborder1 = np.where(self.P[:, 0] < xmin, False, True) 140 | xborder2 = np.where(self.P[:, 0] > xmax, False, True) 141 | xborder = np.logical_and(xborder1, xborder2) 142 | else: 143 | xborder = np.full(len(self.P), True) 144 | 145 | ymin = np.min([p1[1], p2[1]]) 146 | ymax = np.max([p1[1], p2[1]]) 147 | if np.abs(ymax - ymin) > 1e-12: 148 | yborder1 = np.where(self.P[:, 1] < ymin, False, True) 149 | yborder2 = np.where(self.P[:, 1] > ymax, False, True) 150 | yborder = np.logical_and(yborder1, yborder2) 151 | else: 152 | yborder = np.full(len(self.P), True) 153 | 154 | border = np.logical_and(xborder, yborder) 155 | tmp = np.logical_and(tmp, border) 156 | ind = np.logical_or(ind, tmp) 157 | return np.where(ind, 1, 0) 158 | 159 | def point_list_for_triangle(self, centre, rotation=0.0): 160 | a1, a2, a3 = -np.pi / 6, np.pi / 2, np.pi * 7 / 6 161 | P1 = self.polar_to_cartesian(self.scale, a1 + rotation) + centre 162 | P2 = self.polar_to_cartesian(self.scale, a2 + rotation) + centre 163 | P3 = self.polar_to_cartesian(self.scale, a3 + rotation) + centre 164 | return [P1, P2, P3], [P2, P3, P1] 165 | 166 | def point_list_for_square(self, centre, rotation=0.0): 167 | a1, a2, a3, a4 = np.pi / 4, np.pi * 3 / 4, -np.pi * 3 / 4, -np.pi / 4 168 | P1 = self.polar_to_cartesian(self.scale, a1 + rotation) + centre 169 | P2 = self.polar_to_cartesian(self.scale, a2 + rotation) + centre 170 | P3 = self.polar_to_cartesian(self.scale, a3 + rotation) + centre 171 | P4 = self.polar_to_cartesian(self.scale, a4 + rotation) + centre 172 | return [P1, P2, P3, P4], [P2, P3, P4, P1] 173 | 174 | def image_triangle(self, centre, rotation): 175 | p1, p2 = self.point_list_for_triangle(centre, rotation) 176 | return self.image_from_points(p1, p2) 177 | 178 | def image_square(self, centre, rotation): 179 | p1, p2 = self.point_list_for_square(centre, rotation) 180 | return self.image_from_points(p1, p2) 181 | 182 | def image_circle(self, centre): 183 | dist = np.abs(np.linalg.norm(self.P - centre, axis=1) - self.scale) 184 | return np.where(dist < self.size, 1, 0) 185 | 186 | def __call__(self, shape="circle"): 187 | x = self.xmin + (self.xmax - self.xmin) * np.random.rand() 188 | y = self.ymin + (self.ymax - self.ymin) * np.random.rand() 189 | if self.rotation: 190 | r = 2 * np.pi * np.random.rand() 191 | else: 192 | r = 0.0 193 | if shape == "circle": 194 | centre = np.array([[x, y]]) 195 | return self.image_circle(centre).reshape((self.nx, self.ny)).T 196 | elif shape == "triangle": 197 | centre = np.array([x, y]) 198 | return ( 199 | self.image_triangle(centre, r + np.pi / 7.5) 200 | .reshape((self.nx, self.ny)) 201 | .T 202 | ) 203 | elif shape == "square": 204 | centre = np.array([x, y]) 205 | return ( 206 | self.image_square(centre, r + np.pi / 3).reshape((self.nx, self.ny)).T 207 | ) 208 | else: 209 | return None 210 | 211 | 212 | class toy_data2: 213 | r"""Object that contains a set of toy images of randomly scattered 214 | hexagonal shapes of a certain kind. 215 | 216 | Args: 217 | shape: str, choose from ... 
218 | nx: int, dimension in x 219 | ny: int, dimension in y 220 | nchannels: int, number of input channels ('colour' channels) 221 | nexamples: int, number of images 222 | noisy: float, standard deviation of gaussian noise 223 | added to the images (default: None) 224 | 225 | """ 226 | 227 | def __init__(self, shape, nx=16, ny=16, nchannels=1, nexamples=1, noisy=None): 228 | self.nx = nx 229 | self.ny = ny 230 | self.shape = Shape(nx, ny, (nx + ny) / 6, True) 231 | if noisy: 232 | self.image_data = np.random.normal(0, noisy, (nexamples, nchannels, ny, nx)) 233 | else: 234 | self.image_data = np.zeros((nexamples, nchannels, ny, nx)) 235 | for ie, example in enumerate(self.image_data): 236 | for ic, channel in enumerate(example): 237 | self.image_data[ie, ic, :, :] += self.shape(shape) 238 | 239 | def to_h5(self, filename): 240 | with h5py.File(filename + ".h5", "w") as f: 241 | f.create_dataset("image_data", data=self.image_data) 242 | 243 | def to_torch_tensor(self): 244 | return torch.Tensor(self.image_data) 245 | 246 | 247 | class toy_dataset: 248 | r"""Object that creates a data set containing different shapes 249 | 250 | Args: 251 | shapes: list of strings with names of different shapes 252 | nperclass: int, number of images of each shape 253 | nx: int, number of columns of pixels 254 | ny: int, number of rows of pixels 255 | nchannels: int, number of channels for each image 256 | 257 | """ 258 | 259 | def __init__(self, shapes, nperclass, nx=16, ny=16, nchannels=1, noisy=None): 260 | self.shapes = shapes 261 | self.image_data = np.zeros((len(shapes) * nperclass, nchannels, ny, nx)) 262 | self.labels = np.zeros(len(shapes) * nperclass) 263 | self.nx = nx 264 | self.ny = ny 265 | self.nchannels = nchannels 266 | self.nperclass = nperclass 267 | self.noisy = noisy 268 | self.square_image_data = None 269 | self.square_benchmark = None 270 | 271 | def create(self): 272 | d = [ 273 | toy_data( 274 | shape, self.nx, self.ny, self.nchannels, self.nperclass, self.noisy 275 | ) 276 | for shape in self.shapes 277 | ] 278 | indices = np.arange(len(self.shapes) * self.nperclass) 279 | np.random.shuffle(indices) 280 | icount = 0 281 | for s, label in zip(d, np.arange(len(self.shapes), dtype=int)): 282 | for image in s.image_data: 283 | for ic, c in enumerate(image): 284 | self.image_data[indices[icount], ic] = c 285 | self.labels[indices[icount]] = int(label) 286 | icount += 1 287 | 288 | def convert_to_square(self, scale=1, method="linear"): 289 | t0 = time.time() 290 | 291 | X = np.zeros(self.nx * self.ny) 292 | Y = np.zeros(self.nx * self.ny) 293 | i = 0 294 | for x in range(self.nx): 295 | for y in range(self.ny): 296 | X[i], Y[i] = x * np.sqrt(3) / 2, -(y + np.mod(x, 2) * 0.5) 297 | i += 1 298 | 299 | grid_x, grid_y = np.meshgrid( 300 | np.linspace(0, max(X), scale * self.nx), 301 | np.linspace(0, min(Y), scale * self.ny), 302 | ) 303 | 304 | self.square_image_data = np.zeros( 305 | ( 306 | len(self.shapes) * self.nperclass, 307 | self.nchannels, 308 | scale * self.ny, 309 | scale * self.nx, 310 | ) 311 | ) 312 | for ie, example in enumerate(self.image_data): 313 | for ic, image in enumerate(example): 314 | Z = image[:].flatten("F") 315 | tmp = griddata((X, Y), Z, (grid_x, grid_y), method=method) 316 | tmp -= np.nan_to_num(tmp).min() 317 | tmp /= np.nan_to_num(tmp).max() 318 | tmp = np.nan_to_num(tmp) 319 | self.square_image_data[ie, ic, :, :] += tmp 320 | self.square_benchmark = time.time() - t0 321 | 322 | def to_torch_tensor(self, sampling="hexagon"): 323 | if sampling == "square": 324 | return 
torch.Tensor(self.square_image_data) 325 | else: 326 | return torch.Tensor(self.image_data) 327 | 328 | def to_dataloader(self, batchsize=8, shuffle=True, sampling="hexagon"): 329 | if sampling == "square": 330 | assert ( 331 | self.square_image_data is not None 332 | ), "No square images, please convert first!" 333 | image_data = self.square_image_data 334 | else: 335 | image_data = self.image_data 336 | data, label = torch.from_numpy(image_data), torch.from_numpy(self.labels) 337 | tensor_dataset = torch.utils.data.TensorDataset(data, label) 338 | dataloader = torch.utils.data.DataLoader( 339 | tensor_dataset, 340 | batch_size=batchsize, 341 | shuffle=shuffle, 342 | num_workers=max(1, os.sysconf("SC_NPROCESSORS_ONLN") // 2), 343 | ) 344 | return dataloader 345 | 346 | 347 | class model: 348 | r"""A toy model CNN 349 | 350 | Args: 351 | train_dataloader: pytorch dataloader with training data 352 | val_dataloader: pytorch dataloader with validation data 353 | net: CNN model 354 | epochs: int, number of epochs to train 355 | 356 | """ 357 | 358 | def __init__(self, train_dataloader, val_dataloader, net, epochs=10): 359 | self.train_dataloader = train_dataloader 360 | self.val_dataloader = val_dataloader 361 | self.net = net 362 | self.epochs = epochs 363 | 364 | def train(self, lr=0.005): 365 | nbts = 16 366 | criterion = nn.CrossEntropyLoss() 367 | optimizer = optim.SGD( 368 | self.net.parameters(), lr=lr, momentum=0.9, weight_decay=0.004 369 | ) 370 | self.tepoch = [] 371 | self.tloss = [] 372 | self.taccu = [] 373 | self.tlr = [] 374 | self.vepoch = [] 375 | self.vloss = [] 376 | self.vaccu = [] 377 | self.train_time = 0 378 | self.scheduler = scheduler.ReduceLROnPlateau( 379 | optimizer, 380 | mode="max", 381 | factor=0.5, 382 | patience=10, 383 | verbose=False, 384 | threshold=1, 385 | threshold_mode="abs", 386 | min_lr=1e-10, 387 | ) 388 | for epoch in range(self.epochs): 389 | print("Epoch %d" % (epoch + 1)) 390 | if torch.cuda.is_available(): 391 | self.net = self.net.cuda() 392 | for dataloader, net_phase, phase in zip( 393 | [self.train_dataloader, self.train_dataloader, self.val_dataloader], 394 | ["train", "eval", "eval"], 395 | ["training", "train_lc", "val_lc"], 396 | ): 397 | if net_phase == "train": 398 | t0 = time.time() 399 | num_batches = len(dataloader) 400 | running_loss = 0.0 401 | total = 0.0 402 | correct = 0.0 403 | batch_counter = 0.0 404 | getattr(self.net, net_phase)() 405 | for i, data in enumerate(dataloader, 0): 406 | inputs, labels = data 407 | inputs, labels = Variable(inputs).float(), Variable(labels).long() 408 | if torch.cuda.is_available(): 409 | inputs, labels = inputs.cuda(), labels.cuda() 410 | optimizer.zero_grad() 411 | outputs = self.net(inputs) 412 | tloss = criterion(outputs, labels) 413 | tloss.backward() 414 | optimizer.step() 415 | running_loss += tloss.item() 416 | total += outputs.data.size()[0] 417 | _, predicted = torch.max(outputs.data, 1) 418 | correct += (predicted == labels.data).sum() 419 | if i % nbts == nbts - 1: 420 | current_epoch = epoch + (batch_counter + 1) / num_batches 421 | current_lr = optimizer.param_groups[0]["lr"] 422 | mean_loss = running_loss / nbts 423 | mean_accuracy = 100 * correct.float() / total 424 | print( 425 | "epoch: %d (%.3f) %s - %5d batches -> mean loss: %.3f, lr: %.3f, mean acc.: %.2f %%" 426 | % ( 427 | epoch + 1, 428 | current_epoch, 429 | phase, 430 | i + 1, 431 | mean_loss, 432 | current_lr, 433 | mean_accuracy, 434 | ) 435 | ) 436 | running_loss = 0.0 437 | total = 0.0 438 | correct = 0.0 439 | if 
phase == "train_lc": 440 | self.tepoch.append(current_epoch) 441 | self.tloss.append(mean_loss) 442 | self.taccu.append(mean_accuracy) 443 | self.tlr.append(current_lr) 444 | elif phase == "val_lc": 445 | self.vepoch.append(current_epoch) 446 | self.vloss.append(mean_loss) 447 | self.vaccu.append(mean_accuracy) 448 | self.scheduler.step(mean_accuracy) 449 | batch_counter += 1.0 450 | batch_counter = 0.0 451 | if net_phase == "train": 452 | self.train_time += time.time() - t0 453 | self.train_time /= self.epochs 454 | 455 | def save_current(self): 456 | torch.save( 457 | self.net.state_dict(), 458 | str(self.net.__class__.__name__) + "_" + str(self.epochs) + ".ptmodel", 459 | ) 460 | 461 | def load(self, filename): 462 | self.net.load_state_dict(torch.load(filename)) 463 | 464 | def get_lc(self): 465 | return ( 466 | np.array(self.tepoch), 467 | np.array(self.tloss), 468 | np.array(self.taccu), 469 | np.array(self.vepoch), 470 | np.array(self.vloss), 471 | np.array(self.vaccu), 472 | np.array(self.train_time), 473 | ) 474 | 475 | def plot_lc(self, scale_to_time=False): 476 | fig = plt.figure("learning_curves", (7, 7)) 477 | axa = fig.add_subplot(311) 478 | axb = fig.add_subplot(312) 479 | axc = fig.add_subplot(313) 480 | tx_axis = np.array(self.tepoch) 481 | vx_axis = np.array(self.vepoch) 482 | if scale_to_time: 483 | tx_axis *= self.train_time 484 | vx_axis *= self.train_time 485 | axa.plot(vx_axis, self.vaccu, "-", lw=1) 486 | axa.set_ylabel("accuracy [%]", size=15) 487 | axa.tick_params( 488 | axis="both", 489 | which="both", 490 | labelsize=10, 491 | bottom=False, 492 | top=False, 493 | labelbottom=False, 494 | ) 495 | 496 | axb.plot(vx_axis, self.vloss, "-", label=self.net.name, lw=1) 497 | axb.legend() 498 | axb.set_ylabel("loss", size=15) 499 | axb.tick_params( 500 | axis="both", 501 | which="both", 502 | labelsize=10, 503 | bottom=False, 504 | top=False, 505 | labelbottom=False, 506 | ) 507 | 508 | axc.plot(tx_axis, self.tlr, lw=1) 509 | axc.set_yscale("log") 510 | axc.set_ylabel("learning rate", size=15) 511 | if scale_to_time: 512 | axc.set_xlabel("train time [s]", size=15) 513 | else: 514 | axc.set_xlabel("# Epochs", size=15) 515 | axc.tick_params( 516 | axis="both", 517 | which="both", 518 | labelsize=10, 519 | bottom=True, 520 | top=True, 521 | labelbottom=True, 522 | ) 523 | fig.canvas.draw() 524 | plt.show() 525 | -------------------------------------------------------------------------------- /src/hexagdly.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains utilities to set up hexagonal convolution and pooling 3 | kernels in PyTorch. The size of the input is abitrary, whereas the layout 4 | from top to bottom (along tensor index 2) has to be of zig-zag-edge shape 5 | and from left to right (along tensor index 3) of armchair-edge shape as 6 | shown below. 7 | __ __ __ __ __ __ 8 | /11\__/31\__ . . . |11|21|31|41| . . . 9 | \__/21\__/41\ |__|__|__|__| 10 | /12\__/32\__/ . . . _______|\ |12|22|32|42| . . . 11 | \__/22\__/42\ | \ |__|__|__|__| 12 | \__/ \__/ |_______ / 13 | . . . . . |/ . . . . . 14 | . . . . . . . . . . 15 | . . . . . . . . . . 
16 | 17 | For more information visit https://github.com/ai4iacts/hexagdly 18 | 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from torch.nn.parameter import Parameter 25 | import numpy as np 26 | 27 | 28 | class HexBase: 29 | def __init__(self): 30 | super(HexBase, self).__init__() 31 | self.hexbase_size = None 32 | self.depth_size = None 33 | self.hexbase_stride = None 34 | self.depth_stride = None 35 | self.input_size_is_known = False 36 | self.odd_columns_slices = [] 37 | self.odd_columns_pads = [] 38 | self.even_columns_slices = [] 39 | self.even_columns_pads = [] 40 | self.dimensions = None 41 | self.combine = None 42 | self.process = None 43 | self.kwargs = dict() 44 | 45 | def shape_for_odd_columns(self, input_size, kernel_number): 46 | slices = [None, None, None, None] 47 | pads = [0, 0, 0, 0] 48 | # left 49 | pads[0] = kernel_number 50 | # right 51 | pads[1] = max( 52 | 0, kernel_number - ((input_size[-1] - 1) % (2 * self.hexbase_stride)) 53 | ) 54 | # top 55 | pads[2] = self.hexbase_size - int(kernel_number / 2) 56 | # bottom 57 | constraint = ( 58 | input_size[-2] 59 | - 1 60 | - int( 61 | (input_size[-2] - 1 - int(self.hexbase_stride / 2)) 62 | / self.hexbase_stride 63 | ) 64 | * self.hexbase_stride 65 | ) 66 | bottom = (self.hexbase_size - int((kernel_number + 1) / 2)) - constraint 67 | if bottom >= 0: 68 | pads[3] = bottom 69 | else: 70 | slices[1] = bottom 71 | 72 | return slices, pads 73 | 74 | def shape_for_even_columns(self, input_size, kernel_number): 75 | slices = [None, None, None, None] 76 | pads = [0, 0, 0, 0] 77 | # left 78 | left = kernel_number - self.hexbase_stride 79 | if left >= 0: 80 | pads[0] = left 81 | else: 82 | slices[2] = -left 83 | # right 84 | pads[1] = max( 85 | 0, 86 | kernel_number 87 | - ((input_size[-1] - 1 - self.hexbase_stride) % (2 * self.hexbase_stride)), 88 | ) 89 | # top 90 | top_shift = -(kernel_number % 2) if (self.hexbase_stride % 2) == 1 else 0 91 | top = ( 92 | (self.hexbase_size - int(kernel_number / 2)) 93 | + top_shift 94 | - int(self.hexbase_stride / 2) 95 | ) 96 | if top >= 0: 97 | pads[2] = top 98 | else: 99 | slices[0] = -top 100 | # bottom 101 | bottom_shift = 0 if (self.hexbase_stride % 2) == 1 else -(kernel_number % 2) 102 | pads[3] = max( 103 | 0, 104 | self.hexbase_size 105 | - int(kernel_number / 2) 106 | + bottom_shift 107 | - ( 108 | (input_size[-2] - int(self.hexbase_stride / 2) - 1) 109 | % self.hexbase_stride 110 | ), 111 | ) 112 | 113 | return slices, pads 114 | 115 | def get_padded_input(self, input, pads): 116 | if self.dimensions == 2: 117 | return nn.ZeroPad2d(tuple(pads))(input) 118 | elif self.dimensions == 3: 119 | return nn.ConstantPad3d(tuple(pads + [0, 0]), 0)(input) 120 | 121 | def get_sliced_input(self, input, slices): 122 | if self.dimensions == 2: 123 | return input[:, :, slices[0] : slices[1], slices[2] : slices[3]] 124 | elif self.dimensions == 3: 125 | return input[:, :, :, slices[0] : slices[1], slices[2] : slices[3]] 126 | 127 | def get_dilation(self, dilation_2d): 128 | if self.dimensions == 2: 129 | return dilation_2d 130 | elif self.dimensions == 3: 131 | return tuple([1] + list(dilation_2d)) 132 | 133 | def get_stride(self): 134 | if self.dimensions == 2: 135 | return (self.hexbase_stride, 2 * self.hexbase_stride) 136 | elif self.dimensions == 3: 137 | return (self.depth_stride, self.hexbase_stride, 2 * self.hexbase_stride) 138 | 139 | def get_ordered_output(self, input, order): 140 | if self.dimensions == 2: 141 | return input[:, :, :, 
order] 142 | elif self.dimensions == 3: 143 | return input[:, :, :, :, order] 144 | 145 | # general implementation of an operation with a hexagonal kernel 146 | def operation_with_arbitrary_stride(self, input): 147 | assert ( 148 | input.size(-2) - (self.hexbase_stride // 2) >= 0 149 | ), "Too few rows to apply hex conv with the stride that is set" 150 | odd_columns = None 151 | even_columns = None 152 | 153 | for i in range(self.hexbase_size + 1): 154 | dilation_base = (1, 1) if i == 0 else (1, 2 * i) 155 | 156 | if not self.input_size_is_known: 157 | slices, pads = self.shape_for_odd_columns(input.size(), i) 158 | self.odd_columns_slices.append(slices) 159 | self.odd_columns_pads.append(pads) 160 | slices, pads = self.shape_for_even_columns(input.size(), i) 161 | self.even_columns_slices.append(slices) 162 | self.even_columns_pads.append(pads) 163 | if i == self.hexbase_size: 164 | self.input_size_is_known = True 165 | 166 | if odd_columns is None: 167 | odd_columns = self.process( 168 | self.get_padded_input( 169 | self.get_sliced_input(input, self.odd_columns_slices[i]), 170 | self.odd_columns_pads[i], 171 | ), 172 | getattr(self, "kernel" + str(i)), 173 | dilation=self.get_dilation(dilation_base), 174 | stride=self.get_stride(), 175 | **self.kwargs 176 | ) 177 | else: 178 | odd_columns = self.combine( 179 | odd_columns, 180 | self.process( 181 | self.get_padded_input( 182 | self.get_sliced_input(input, self.odd_columns_slices[i]), 183 | self.odd_columns_pads[i], 184 | ), 185 | getattr(self, "kernel" + str(i)), 186 | dilation=self.get_dilation(dilation_base), 187 | stride=self.get_stride(), 188 | ), 189 | ) 190 | 191 | if even_columns is None: 192 | even_columns = self.process( 193 | self.get_padded_input( 194 | self.get_sliced_input(input, self.even_columns_slices[i]), 195 | self.even_columns_pads[i], 196 | ), 197 | getattr(self, "kernel" + str(i)), 198 | dilation=self.get_dilation(dilation_base), 199 | stride=self.get_stride(), 200 | **self.kwargs 201 | ) 202 | else: 203 | even_columns = self.combine( 204 | even_columns, 205 | self.process( 206 | self.get_padded_input( 207 | self.get_sliced_input(input, self.even_columns_slices[i]), 208 | self.even_columns_pads[i], 209 | ), 210 | getattr(self, "kernel" + str(i)), 211 | dilation=self.get_dilation(dilation_base), 212 | stride=self.get_stride(), 213 | ), 214 | ) 215 | 216 | concatenated_columns = torch.cat( 217 | (odd_columns, even_columns), 1 + self.dimensions 218 | ) 219 | 220 | n_odd_columns = odd_columns.size(-1) 221 | n_even_columns = even_columns.size(-1) 222 | if n_odd_columns == n_even_columns: 223 | order = [ 224 | int(i + x * n_even_columns) 225 | for i in range(n_even_columns) 226 | for x in range(2) 227 | ] 228 | else: 229 | order = [ 230 | int(i + x * n_odd_columns) 231 | for i in range(n_even_columns) 232 | for x in range(2) 233 | ] 234 | order.append(n_even_columns) 235 | 236 | return self.get_ordered_output(concatenated_columns, order) 237 | 238 | # a slightly faster, case-specific implementation of the hexagonal convolution 239 | def operation_with_single_hexbase_stride(self, input): 240 | columns_mod2 = input.size(-1) % 2 241 | odd_kernels_odd_columns = [] 242 | odd_kernels_even_columns = [] 243 | even_kernels_all_columns = [] 244 | 245 | even_kernels_all_columns = self.process( 246 | self.get_padded_input(input, [0, 0, self.hexbase_size, self.hexbase_size]), 247 | self.kernel0, 248 | stride=(1, 1) if self.dimensions == 2 else (self.depth_stride, 1, 1), 249 | **self.kwargs 250 | ) 251 | if self.hexbase_size >= 1: 252 | 

    # a slightly faster, case-specific implementation of the hexagonal
    # operation for a hexagonal base stride of one
    def operation_with_single_hexbase_stride(self, input):
        columns_mod2 = input.size(-1) % 2
        odd_kernels_odd_columns = []
        odd_kernels_even_columns = []

        # kernel0 (the central column) covers all columns at once; the bias
        # in self.kwargs therefore enters the result exactly once, here
        even_kernels_all_columns = self.process(
            self.get_padded_input(input, [0, 0, self.hexbase_size, self.hexbase_size]),
            self.kernel0,
            stride=(1, 1) if self.dimensions == 2 else (self.depth_stride, 1, 1),
            **self.kwargs
        )
        if self.hexbase_size >= 1:
            odd_kernels_odd_columns = self.process(
                self.get_padded_input(
                    input, [1, columns_mod2, self.hexbase_size, self.hexbase_size - 1]
                ),
                self.kernel1,
                dilation=self.get_dilation((1, 2)),
                stride=self.get_stride(),
            )
            odd_kernels_even_columns = self.process(
                self.get_padded_input(
                    input,
                    [0, 1 - columns_mod2, self.hexbase_size - 1, self.hexbase_size],
                ),
                self.kernel1,
                dilation=self.get_dilation((1, 2)),
                stride=self.get_stride(),
            )

        if self.hexbase_size > 1:
            for i in range(2, self.hexbase_size + 1):
                if i % 2 == 0:
                    even_kernels_all_columns = self.combine(
                        even_kernels_all_columns,
                        self.process(
                            self.get_padded_input(
                                input,
                                [
                                    i,
                                    i,
                                    self.hexbase_size - int(i / 2),
                                    self.hexbase_size - int(i / 2),
                                ],
                            ),
                            getattr(self, "kernel" + str(i)),
                            dilation=self.get_dilation((1, 2 * i)),
                            stride=(1, 1)
                            if self.dimensions == 2
                            else (self.depth_stride, 1, 1),
                        ),
                    )
                else:
                    x = self.hexbase_size + int((1 - i) / 2)
                    odd_kernels_odd_columns = self.combine(
                        odd_kernels_odd_columns,
                        self.process(
                            self.get_padded_input(
                                input, [i, i - 1 + columns_mod2, x, x - 1]
                            ),
                            getattr(self, "kernel" + str(i)),
                            dilation=self.get_dilation((1, 2 * i)),
                            stride=self.get_stride(),
                        ),
                    )
                    odd_kernels_even_columns = self.combine(
                        odd_kernels_even_columns,
                        self.process(
                            self.get_padded_input(
                                input, [i - 1, i - columns_mod2, x - 1, x]
                            ),
                            getattr(self, "kernel" + str(i)),
                            dilation=self.get_dilation((1, 2 * i)),
                            stride=self.get_stride(),
                        ),
                    )

        odd_kernels_concatenated_columns = torch.cat(
            (odd_kernels_odd_columns, odd_kernels_even_columns), 1 + self.dimensions
        )

        n_odd_columns = odd_kernels_odd_columns.size(-1)
        n_even_columns = odd_kernels_even_columns.size(-1)
        if n_odd_columns == n_even_columns:
            order = [
                int(i + x * n_even_columns)
                for i in range(n_even_columns)
                for x in range(2)
            ]
        else:
            order = [
                int(i + x * n_odd_columns)
                for i in range(n_even_columns)
                for x in range(2)
            ]
            order.append(n_even_columns)

        return self.combine(
            even_kernels_all_columns,
            self.get_ordered_output(odd_kernels_concatenated_columns, order),
        )
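
# The subclasses below only have to set self.process and self.combine: the
# convolution layers pair F.conv2d / F.conv3d with torch.add, while the
# pooling layers pair F.max_pool2d / F.max_pool3d with torch.max, so the
# hexagonal addressing logic above is written only once and shared by all.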


class Conv2d(HexBase, nn.Module):
    r"""Applies a 2D hexagonal convolution

    Args:
        in_channels:    int: number of input channels
        out_channels:   int: number of output channels
        kernel_size:    int: number of layers with neighbouring pixels
                             covered by the convolution kernel
        stride:         int: length of strides
        bias:           bool: add a bias if True (default)
        debug:          bool: switch to debug mode
                              False: weights are initialised with
                                     kaiming normal, bias with 0.01 (default)
                              True:  weights / bias are set to 1

    Examples::

        >>> conv2d = hexagdly.Conv2d(1, 3, 2, 1)
        >>> input = torch.randn(1, 1, 4, 2)
        >>> output = conv2d(input)
        >>> print(output)
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=1, stride=1, bias=True, debug=False
    ):
        super(Conv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hexbase_size = kernel_size
        self.hexbase_stride = stride
        self.debug = debug
        self.bias = bias
        self.dimensions = 2
        self.process = F.conv2d
        self.combine = torch.add

        # one sub-kernel per group of columns: sub-kernel i covers the two
        # columns at distance i from the central column (one column for i = 0)
        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                Parameter(
                    torch.Tensor(
                        out_channels,
                        in_channels,
                        1 + 2 * self.hexbase_size - i,
                        1 if i == 0 else 2,
                    )
                ),
            )
        if self.bias:
            self.bias_tensor = Parameter(torch.Tensor(out_channels))
            self.kwargs = {"bias": self.bias_tensor}
        else:
            self.kwargs = {"bias": None}
        self.init_parameters(self.debug)

    def init_parameters(self, debug):
        if debug:
            for i in range(self.hexbase_size + 1):
                nn.init.constant_(getattr(self, "kernel" + str(i)), 1)
            if self.bias:
                nn.init.constant_(self.kwargs["bias"], 1.0)
        else:
            for i in range(self.hexbase_size + 1):
                nn.init.kaiming_normal_(getattr(self, "kernel" + str(i)))
            if self.bias:
                nn.init.constant_(self.kwargs["bias"], 0.01)

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={hexbase_size}"
            ", stride={hexbase_stride}"
        )
        if self.bias is False:
            s += ", bias=False"
        if self.debug is True:
            s += ", debug=True"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
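
# Shapes of the sub-kernels created by Conv2d.__init__ (illustration with
# hypothetical values in_channels=1, out_channels=3, kernel_size=2):
#
#     kernel0: (3, 1, 5, 1)  # central column, 1 + 2 * kernel_size rows
#     kernel1: (3, 1, 4, 2)  # columns at distance 1
#     kernel2: (3, 1, 3, 2)  # columns at distance 2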


class Conv2d_CustomKernel(HexBase, nn.Module):
    r"""Applies a 2D hexagonal convolution with custom kernels

    Args:
        sub_kernels:    list: list containing the sub-kernels as numpy arrays
        stride:         int: length of strides
        bias:           array: numpy array with biases (default: None)
        requires_grad:  bool: trainable parameters if True (default: False)
        debug:          bool: if True, a kernel of size one with all values
                              set to 1 and no bias is applied
                              (default: False)

    Examples::

        Given in the online repository https://github.com/ai4iacts/hexagdly
    """

    def __init__(
        self, sub_kernels=[], stride=1, bias=None, requires_grad=False, debug=False
    ):
        super(Conv2d_CustomKernel, self).__init__()
        self.sub_kernels = sub_kernels
        self.bias_array = bias
        self.hexbase_stride = stride
        self.requires_grad = requires_grad
        self.debug = debug
        self.dimensions = 2
        self.process = F.conv2d
        self.combine = torch.add

        self.init_parameters(self.debug)

    def init_parameters(self, debug):
        if debug or len(self.sub_kernels) == 0:
            print(
                "The debug kernel is used for {name}!".format(
                    name=self.__class__.__name__
                )
            )
            self.sub_kernels = [
                np.array([[[[1], [1], [1]]]]),
                np.array([[[[1, 1], [1, 1]]]]),
            ]
        self.hexbase_size = len(self.sub_kernels) - 1
        self.check_sub_kernels()

        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                Parameter(
                    torch.from_numpy(self.sub_kernels[i]).type(torch.FloatTensor),
                    requires_grad=self.requires_grad,
                ),
            )

        if not debug and self.bias_array is not None:
            self.check_bias()
            self.bias_tensor = Parameter(
                torch.from_numpy(self.bias_array).type(torch.FloatTensor),
                requires_grad=self.requires_grad,
            )
            self.kwargs = {"bias": self.bias_tensor}
            self.bias = True
        else:
            self.bias = False
            if self.bias_array is not None:
                print(
                    "{name}: Bias is not used in debug mode!".format(
                        name=self.__class__.__name__
                    )
                )

    def check_sub_kernels(self):
        for i in range(self.hexbase_size + 1):
            assert (
                type(self.sub_kernels[i]).__module__ == np.__name__
            ), "sub-kernels must be given as numpy arrays"
            assert (
                len(self.sub_kernels[i].shape) == 4
            ), "sub-kernels must be of rank 4 for a 2d convolution"
            if i == 0:
                assert (
                    self.sub_kernels[i].shape[3] == 1
                ), "first sub-kernel must have only 1 column"
                assert (
                    self.sub_kernels[i].shape[2] == 2 * self.hexbase_size + 1
                ), "first sub-kernel must have 2 * (kernel size) + 1 rows"
                self.out_channels = self.sub_kernels[i].shape[0]
                self.in_channels = self.sub_kernels[i].shape[1]
            else:
                assert (
                    self.sub_kernels[i].shape[3] == 2
                ), "sub-kernel {}: all but the first sub-kernel must have 2 columns".format(
                    i
                )
                assert (
                    self.sub_kernels[i].shape[2] == 2 * self.hexbase_size + 1 - i
                ), "sub-kernel {} must have 2 * (kernel size) + 1 - {} rows".format(
                    i, i
                )
                assert (
                    self.sub_kernels[i].shape[0] == self.out_channels
                ), "sub-kernel {}: out channels are not consistent".format(i)
                assert (
                    self.sub_kernels[i].shape[1] == self.in_channels
                ), "sub-kernel {}: in channels are not consistent".format(i)

    def check_bias(self):
        assert (
            type(self.bias_array).__module__ == np.__name__
        ), "bias must be given as a numpy array"
        assert len(self.bias_array.shape) == 1, "bias must be of rank 1"
        assert (
            self.bias_array.shape[0] == self.out_channels
        ), "bias must have length equal to number of out channels"

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={hexbase_size}"
            ", stride={hexbase_stride}"
        )
        if self.bias is False:
            s += ", bias=False"
        if self.debug is True:
            s += ", debug=True"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
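
# A minimal usage sketch for Conv2d_CustomKernel (hypothetical values; the
# shapes follow from check_sub_kernels for kernel size 1 with one input and
# one output channel):
#
#     >>> sub_kernels = [
#     ...     np.ones((1, 1, 3, 1)),  # central column
#     ...     np.ones((1, 1, 2, 2)),  # neighbouring columns
#     ... ]
#     >>> conv = Conv2d_CustomKernel(sub_kernels, stride=1, bias=np.array([0.5]))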


class Conv3d(HexBase, nn.Module):
    r"""Applies a 3D hexagonal convolution

    Args:
        in_channels:    int: number of input channels
        out_channels:   int: number of output channels
        kernel_size:    int, tuple: number of layers with neighbouring pixels
                             covered by the convolution kernel
                             int: same number of layers in all dimensions
                             tuple of two ints:
                                 1st int: layers in depth
                                 2nd int: layers in the hexagonal base
        stride:         int, tuple: length of strides
                             int: same length of strides in each dimension
                             tuple of two ints:
                                 1st int: length of strides in depth
                                 2nd int: length of strides in the hexagonal base
        bias:           bool: add a bias if True (default)
        debug:          bool: switch to debug mode
                              False: weights are initialised with
                                     kaiming normal, bias with 0.01 (default)
                              True:  weights / bias are set to 1

    Examples::

        >>> conv3d = hexagdly.Conv3d(1, 3, (2, 2), (1, 1))
        >>> input = torch.randn(1, 1, 6, 5, 4)
        >>> output = conv3d(input)
        >>> print(output)
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=1, stride=1, bias=True, debug=False
    ):
        super(Conv3d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if isinstance(kernel_size, int):
            self.hexbase_size = kernel_size
            self.depth_size = kernel_size
        elif isinstance(kernel_size, tuple):
            assert len(kernel_size) == 2, "Need a tuple of two ints to set kernel size"
            self.hexbase_size = kernel_size[1]
            self.depth_size = kernel_size[0]
        if isinstance(stride, int):
            self.hexbase_stride = stride
            self.depth_stride = stride
        elif isinstance(stride, tuple):
            assert len(stride) == 2, "Need a tuple of two ints to set stride"
            self.hexbase_stride = stride[1]
            self.depth_stride = stride[0]
        self.debug = debug
        self.bias = bias
        self.dimensions = 3
        self.process = F.conv3d
        self.combine = torch.add

        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                Parameter(
                    torch.Tensor(
                        out_channels,
                        in_channels,
                        self.depth_size,
                        1 + 2 * self.hexbase_size - i,
                        1 if i == 0 else 2,
                    )
                ),
            )
        if self.bias:
            self.bias_tensor = Parameter(torch.Tensor(out_channels))
            self.kwargs = {"bias": self.bias_tensor}
        else:
            self.kwargs = {"bias": None}

        self.init_parameters(self.debug)

    def init_parameters(self, debug):
        if debug:
            for i in range(self.hexbase_size + 1):
                nn.init.constant_(getattr(self, "kernel" + str(i)), 1)
            if self.bias:
                nn.init.constant_(self.kwargs["bias"], 1.0)
        else:
            for i in range(self.hexbase_size + 1):
                nn.init.kaiming_normal_(getattr(self, "kernel" + str(i)))
            if self.bias:
                nn.init.constant_(self.kwargs["bias"], 0.01)

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size=({depth_size}, {hexbase_size})"
            ", stride=({depth_stride}, {hexbase_stride})"
        )
        if self.bias is False:
            s += ", bias=False"
        if self.debug is True:
            s += ", debug=True"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
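
# Note: like F.conv3d, Conv3d expects a rank-5 input of shape
# (batch, channels, depth, rows, columns). The hexagonal addressing acts on
# the last two axes, while depth is treated as a regular cartesian dimension.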


class Conv3d_CustomKernel(HexBase, nn.Module):
    r"""Applies a 3D hexagonal convolution with custom kernels

    Args:
        sub_kernels:    list: list containing the sub-kernels as numpy arrays
        stride:         int, tuple: length of strides
                             int: same length of strides in each dimension
                             tuple of two ints:
                                 1st int: length of strides in depth
                                 2nd int: length of strides in the hexagonal base
        bias:           array: numpy array with biases (default: None)
        requires_grad:  bool: trainable parameters if True (default: False)
        debug:          bool: if True, a kernel of size one with all values
                              set to 1 and no bias is applied
                              (default: False)

    Examples::

        Given in the online repository https://github.com/ai4iacts/hexagdly
    """

    def __init__(
        self, sub_kernels=[], stride=1, bias=None, requires_grad=False, debug=False
    ):
        super(Conv3d_CustomKernel, self).__init__()
        self.sub_kernels = sub_kernels
        self.bias_array = bias
        if isinstance(stride, int):
            self.hexbase_stride = stride
            self.depth_stride = stride
        elif isinstance(stride, tuple):
            assert len(stride) == 2, "Need a tuple of two ints to set stride"
            self.hexbase_stride = stride[1]
            self.depth_stride = stride[0]
        self.requires_grad = requires_grad
        self.debug = debug
        self.dimensions = 3
        self.process = F.conv3d
        self.combine = torch.add

        self.init_parameters(self.debug)

    def init_parameters(self, debug):
        if debug or len(self.sub_kernels) == 0:
            print(
                "The debug kernel is used for {name}!".format(
                    name=self.__class__.__name__
                )
            )
            self.sub_kernels = [
                np.array([[[[[1], [1], [1]]]]]),
                np.array([[[[[1, 1], [1, 1]]]]]),
            ]
        self.hexbase_size = len(self.sub_kernels) - 1
        self.check_sub_kernels()

        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                Parameter(
                    torch.from_numpy(self.sub_kernels[i]).type(torch.FloatTensor),
                    requires_grad=self.requires_grad,
                ),
            )

        if not debug and self.bias_array is not None:
            self.check_bias()
            self.bias_tensor = Parameter(
                torch.from_numpy(self.bias_array).type(torch.FloatTensor),
                requires_grad=self.requires_grad,
            )
            self.kwargs = {"bias": self.bias_tensor}
            self.bias = True
        else:
            self.bias = False
            print("No bias is used for {name}!".format(name=self.__class__.__name__))

    def check_sub_kernels(self):
        for i in range(self.hexbase_size + 1):
            assert (
                type(self.sub_kernels[i]).__module__ == np.__name__
            ), "sub-kernels must be given as numpy arrays"
            assert (
                len(self.sub_kernels[i].shape) == 5
            ), "sub-kernels must be of rank 5 for a 3d convolution"
            if i == 0:
                assert (
                    self.sub_kernels[i].shape[4] == 1
                ), "first sub-kernel must have only 1 column"
                assert (
                    self.sub_kernels[i].shape[3] == 2 * self.hexbase_size + 1
                ), "first sub-kernel must have 2 * (kernel size) + 1 rows"
                self.out_channels = self.sub_kernels[i].shape[0]
                self.in_channels = self.sub_kernels[i].shape[1]
                self.depth_size = self.sub_kernels[i].shape[2]
            else:
                assert (
                    self.sub_kernels[i].shape[4] == 2
                ), "sub-kernel {}: all but the first sub-kernel must have 2 columns".format(
                    i
                )
                assert (
                    self.sub_kernels[i].shape[3] == 2 * self.hexbase_size + 1 - i
                ), "sub-kernel {} must have 2 * (kernel size) + 1 - {} rows".format(
                    i, i
                )
                assert (
                    self.sub_kernels[i].shape[0] == self.out_channels
                ), "sub-kernel {}: out channels are not consistent".format(i)
                assert (
                    self.sub_kernels[i].shape[1] == self.in_channels
                ), "sub-kernel {}: in channels are not consistent".format(i)
                assert (
                    self.sub_kernels[i].shape[2] == self.depth_size
                ), "sub-kernel {}: depths are not consistent".format(i)

    def check_bias(self):
        assert (
            type(self.bias_array).__module__ == np.__name__
        ), "bias must be given as a numpy array"
        assert len(self.bias_array.shape) == 1, "bias must be of rank 1"
        assert (
            self.bias_array.shape[0] == self.out_channels
        ), "bias must have length equal to number of out channels"

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size=({depth_size}, {hexbase_size})"
            ", stride=({depth_stride}, {hexbase_stride})"
        )
        if self.bias is False:
            s += ", bias=False"
        if self.debug is True:
            s += ", debug=True"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
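
# A minimal usage sketch for Conv3d_CustomKernel (hypothetical values; rank-5
# sub-kernels of shape (out, in, depth, rows, columns) for hexagonal kernel
# size 1 and depth size 1):
#
#     >>> sub_kernels = [
#     ...     np.ones((1, 1, 1, 3, 1)),  # central column
#     ...     np.ones((1, 1, 1, 2, 2)),  # neighbouring columns
#     ... ]
#     >>> conv = Conv3d_CustomKernel(sub_kernels, stride=(1, 1))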


class MaxPool2d(HexBase, nn.Module):
    r"""Applies a 2D hexagonal max pooling

    Args:
        kernel_size:    int: number of layers with neighbouring pixels
                             covered by the pooling kernel
        stride:         int: length of strides

    Examples::

        >>> maxpool2d = hexagdly.MaxPool2d(1, 2)
        >>> input = torch.randn(1, 1, 4, 2)
        >>> output = maxpool2d(input)
        >>> print(output)
    """

    def __init__(self, kernel_size=1, stride=1):
        super(MaxPool2d, self).__init__()
        self.hexbase_size = kernel_size
        self.hexbase_stride = stride
        self.dimensions = 2
        self.process = F.max_pool2d
        self.combine = torch.max

        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                (1 + 2 * self.hexbase_size - i, 1 if i == 0 else 2),
            )

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = "{name}(kernel_size={hexbase_size}, stride={hexbase_stride})"
        return s.format(name=self.__class__.__name__, **self.__dict__)
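
# Note on the pooling layers: here the "kernel" attributes are plain
# (rows, columns) window shapes handed to F.max_pool2d as its kernel_size
# argument, and torch.max merges the partial results element-wise, so the
# shared hexagonal addressing yields a max pooling instead of a convolution.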


class MaxPool3d(HexBase, nn.Module):
    r"""Applies a 3D hexagonal max pooling

    Args:
        kernel_size:    int, tuple: number of layers with neighbouring pixels
                             covered by the pooling kernel
                             int: same number of layers in all dimensions
                             tuple of two ints:
                                 1st int: layers in depth
                                 2nd int: layers in the hexagonal base
        stride:         int, tuple: length of strides
                             int: same length of strides in each dimension
                             tuple of two ints:
                                 1st int: length of strides in depth
                                 2nd int: length of strides in the hexagonal base

    Examples::

        >>> maxpool3d = hexagdly.MaxPool3d((1, 1), (2, 2))
        >>> input = torch.randn(1, 1, 6, 5, 4)
        >>> output = maxpool3d(input)
        >>> print(output)
    """

    def __init__(self, kernel_size=1, stride=1):
        super(MaxPool3d, self).__init__()
        if isinstance(kernel_size, int):
            self.hexbase_size = kernel_size
            self.depth_size = kernel_size
        elif isinstance(kernel_size, tuple):
            assert len(kernel_size) == 2, "Need a tuple of two ints to set kernel size"
            self.hexbase_size = kernel_size[1]
            self.depth_size = kernel_size[0]
        if isinstance(stride, int):
            self.hexbase_stride = stride
            self.depth_stride = stride
        elif isinstance(stride, tuple):
            assert len(stride) == 2, "Need a tuple of two ints to set stride"
            self.hexbase_stride = stride[1]
            self.depth_stride = stride[0]
        self.dimensions = 3
        self.process = F.max_pool3d
        self.combine = torch.max

        for i in range(self.hexbase_size + 1):
            setattr(
                self,
                "kernel" + str(i),
                (self.depth_size, 1 + 2 * self.hexbase_size - i, 1 if i == 0 else 2),
            )

    def forward(self, input):
        if self.hexbase_stride == 1:
            return self.operation_with_single_hexbase_stride(input)
        else:
            return self.operation_with_arbitrary_stride(input)

    def __repr__(self):
        s = (
            "{name}(kernel_size=({depth_size}, {hexbase_size})"
            ", stride=({depth_stride}, {hexbase_stride}))"
        )
        return s.format(name=self.__class__.__name__, **self.__dict__)
--------------------------------------------------------------------------------